aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Duarte <licorne@google.com>2021-12-22 13:03:59 +0000
committerDavid Duarte <licorne@google.com>2022-01-10 12:59:06 +0000
commit2fed870203ce3567f58e3506d769d2edc80184c6 (patch)
tree220547fe0be73abab673e6beae38a10c7efafe5a
parentb18f8bfb854667b3c11a2b4a3f771b78fa01119a (diff)
downloadmmi2grpc-2fed870203ce3567f58e3506d769d2edc80184c6.tar.gz
protoc-gen-custom_grpc: Support stream input mode
Change-Id: I2b215bf9f14c2b25fa75cb1272b836122629c0eb
-rwxr-xr-xprotoc-gen-custom_grpc27
1 files changed, 17 insertions, 10 deletions
diff --git a/protoc-gen-custom_grpc b/protoc-gen-custom_grpc
index 3a0c8be..eb5ec11 100755
--- a/protoc-gen-custom_grpc
+++ b/protoc-gen-custom_grpc
@@ -34,16 +34,23 @@ def generate_method(imports, file, service, method):
output_type = import_type(imports, method.output_type)
if input_mode == 'stream':
- raise Error("TODO: stream as input type")
-
- return (
- f'def {method.name}(self, wait_for_ready=None, **kwargs):\n'
- f' return self.channel.{input_mode}_{output_mode}(\n'
- f" '/{file.package}.{service.name}/{method.name}',\n"
- f' request_serializer={input_type}.SerializeToString,\n'
- f' response_deserializer={output_type}.FromString\n'
- f' )({input_type}(**kwargs), wait_for_ready=wait_for_ready)'
- ).split('\n')
+ return (
+ f'def {method.name}(self, iterator, **kwargs):\n'
+ f' return self.channel.{input_mode}_{output_mode}(\n'
+ f" '/{file.package}.{service.name}/{method.name}',\n"
+ f' request_serializer={input_type}.SerializeToString,\n'
+ f' response_deserializer={output_type}.FromString\n'
+ f' )(iterator, **kwargs)'
+ ).split('\n')
+ else:
+ return (
+ f'def {method.name}(self, wait_for_ready=None, **kwargs):\n'
+ f' return self.channel.{input_mode}_{output_mode}(\n'
+ f" '/{file.package}.{service.name}/{method.name}',\n"
+ f' request_serializer={input_type}.SerializeToString,\n'
+ f' response_deserializer={output_type}.FromString\n'
+ f' )({input_type}(**kwargs), wait_for_ready=wait_for_ready)'
+ ).split('\n')
def generate_service(imports, file, service):