From 2e8ffbc9411e08f50515421ca69acf5c1a4089b6 Mon Sep 17 00:00:00 2001 From: uael Date: Thu, 16 Feb 2023 01:34:35 +0000 Subject: gen: completely rework the generator Change-Id: I4722fbf0515f240c8c425728e7abc1e12ecf1aa3 --- python/_build/protoc-gen-custom_grpc | 374 ++++++++++++++++++----------------- python/pyproject.toml | 4 +- 2 files changed, 198 insertions(+), 180 deletions(-) diff --git a/python/_build/protoc-gen-custom_grpc b/python/_build/protoc-gen-custom_grpc index 1b306cd..37b31b7 100755 --- a/python/_build/protoc-gen-custom_grpc +++ b/python/_build/protoc-gen-custom_grpc @@ -54,33 +54,30 @@ def find_type(package: str, type_name: str) -> Tuple[FileDescriptorProto, Union[ raise Exception(f'Type {package}.{type_name} not found') -def import_type(imports: List[str], file: FileDescriptorProto, type: str, use_stubs: bool=True) -> Tuple[str, Union[DescriptorProto, EnumDescriptorProto]]: +def add_import(imports: List[str], import_str: str) -> None: + if not import_str in imports: + imports.append(import_str) + + +def import_type(imports: List[str], type: str, local: Optional[FileDescriptorProto]) -> Tuple[str, Union[DescriptorProto, EnumDescriptorProto], str]: package = type[1:type.rindex('.')] type_name = type[type.rindex('.')+1:] - type_file, desc = find_type(package, type_name) - if use_stubs and type_file == file: - return f'{type_name}', desc - suffix = '_pb2' - if use_stubs and next((True for x in _REQUEST.file_to_generate if x == type_file.name), False): - suffix = '_grpc' - python_path = type_file.name.replace('.proto', '').replace('/', '.') + file, desc = find_type(package, type_name) + if file == local: + return f'{type_name}', desc, '' + python_path = file.name.replace('.proto', '').replace('/', '.') module_path = python_path[:python_path.rindex('.')] - module_name = python_path[python_path.rindex('.')+1:] + suffix - if not f'from {module_path} import {module_name}' in imports: - imports.append(f'from {module_path} import {module_name}') - return f'{module_name}.{type_name}', desc - - -def generate_enum(imports: List[str], file: FileDescriptorProto, enum: EnumDescriptorProto) -> List[str]: - return [ - f'class {enum.name}(enum.IntEnum):', - *[f' {value.name} = {value.number}' for value in enum.value], - '' - ] + module_name = python_path[python_path.rindex('.')+1:] + '_pb2' + add_import(imports, f'from {module_path} import {module_name}') + dft_import = '' + if isinstance(desc, EnumDescriptorProto): + dft_import = f'from {module_path}.{module_name} import {desc.value[0].name}' + return f'{module_name}.{type_name}', desc, dft_import -def generate_type(imports: List[str], file: FileDescriptorProto, parent: DescriptorProto, field: FieldDescriptorProto) -> Tuple[str, str]: +def collect_type(imports: List[str], parent: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[str, str, str]: dft: str + dft_import: str = '' if field.type == FieldDescriptor.TYPE_BYTES: type = 'bytes' dft = 'b\'\'' @@ -118,78 +115,94 @@ def generate_type(imports: List[str], file: FileDescriptorProto, parent: Descrip if nested_type.name == type: assert nested_type.options.map_entry assert field.label == FieldDescriptor.LABEL_REPEATED - key_type, _ = generate_type(imports, file, nested_type, nested_type.field[0]) - val_type, _ = generate_type(imports, file, nested_type, nested_type.field[1]) - return f'Dict[{key_type}, {val_type}]', '{}' - type, desc = import_type(imports, file, field.type_name) + key_type, _, _ = collect_type(imports, nested_type, nested_type.field[0], local) + val_type, _, _ = collect_type(imports, nested_type, nested_type.field[1], local) + add_import(imports, 'from typing import Dict') + return f'Dict[{key_type}, {val_type}]', '{}', '' + type, desc, enum_dft = import_type(imports, field.type_name, local) if isinstance(desc, EnumDescriptorProto): - dft = f'{type}.{desc.value[0].name}' + dft_import = enum_dft + dft = desc.value[0].name else: dft = f'{type}()' - type = f"'{type}'" else: raise Exception(f'TODO: {field}') if field.label == FieldDescriptor.LABEL_REPEATED: + add_import(imports, 'from typing import List') type = f'List[{type}]' dft = '[]' - return type, dft + return type, dft, dft_import -def generate_field(imports: List[str], file: FileDescriptorProto, message: DescriptorProto, field: FieldDescriptorProto) -> Tuple[Optional[int], str, str, str]: - type, dft = generate_type(imports, file, message, field) +def collect_field(imports: List[str], message: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[Optional[int], str, str, str, str]: + type, dft, dft_import = collect_type(imports, message, field, local) oneof_index = field.oneof_index if 'oneof_index' in f'{field}' else None - return oneof_index, field.name, type, dft - + return oneof_index, field.name, type, dft, dft_import -# TODO(uael): refactor the use of this global. -_MESSAGES: Dict[str, Tuple[List[str], List[str]]] = {} - -# TODO(uael): refactor the next function to be readable. -def generate_message(imports: List[str], overrides: list[str], file: FileDescriptorProto, message: DescriptorProto) -> List[str]: +def collect_message(imports: List[str], message: DescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[ + List[Tuple[str, str, str]], + Dict[str, list[Tuple[str, str]]], +]: + fields: List[Tuple[str, str, str]] = [] oneof: Dict[str, list[Tuple[str, str]]] = {} - pb2 = os.path.basename(file.name).replace('.proto', '_pb2') - base = f'{pb2}.{message.name}' - message_lines: List[str] = [ - f'@dataclass', - f'class {message.name}(Message):', - ] - - parameters: List[str] = [] - parameters_name: List[str] = [] for field in message.field: - idx, name, type, dft = generate_field(imports, file, message, field) - parameters_name.append(f'{name}={name}') + idx, name, type, dft, dft_import = collect_field(imports, message, field, local) if idx is not None: oneof_name = message.oneof_decl[idx].name oneof.setdefault(oneof_name, []) oneof[oneof_name].append((name, type)) else: - parameters.append(f'{name}: {type} = {dft}') - dft = dft.replace('[]', 'field(default_factory=list)').replace('{}', 'field(default_factory=dict)') - message_lines.append(f' {name}: {type} = {dft}') + add_import(imports, dft_import) + fields.append((name, type, dft)) for oneof_name, oneof_fields in oneof.items(): - if len(message_lines) > 2: message_lines.append('') - message_lines.append(f' # Oneof `{oneof_name}` variants.') for name, type in oneof_fields: - parameters.append(f'{name}: Optional[{type}] = None') - message_lines.append(f' {name}: Optional[{type}] = None') + add_import(imports, 'from typing import Optional') + fields.append((name, f'Optional[{type}]', 'None')) - for oneof_name, oneof_fields in oneof.items(): - literals: str = ', '.join((f'Literal[\'{name}\']' for name, _ in oneof_fields)) + return fields, oneof + + +def generate_enum(imports: List[str], file: FileDescriptorProto, enum: EnumDescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]: + res.append(CodeGeneratorResponse.File( + name=file.name.replace('.proto', '_pb2.py'), + insertion_point=f'module_scope', + content=f'class {enum.name}: ...\n\n' + )) + add_import(imports, 'from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper') + return [ + f'class {enum.name}(int, EnumTypeWrapper):', + f' pass', + f'', + *[f'{value.name}: {enum.name}' for value in enum.value], + '' + ] - overrides.extend([ - f'def _{message.name}_{oneof_name}_variant(self: {message.name}) -> Optional[str]:', - f' return self.WhichOneof(\'{oneof_name}\') # type: ignore', - '', - f'setattr({base}, \'{oneof_name}_variant\', _{message.name}_{oneof_name}_variant)', - '' - ]) +def generate_message(imports: List[str], file: FileDescriptorProto, message: DescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]: + nested_message_lines: List[str] = [] + message_lines: List[str] = [f'class {message.name}(Message):'] + + add_import(imports, 'from google.protobuf.message import Message') + fields, oneof = collect_message(imports, message, file) + + for (name, type, _) in fields: + message_lines.append(f' {name}: {type}') + + args = ', '.join([f'{name}: {type} = {dft}' for name, type, dft in fields]) + if args: args = ', ' + args + message_lines.extend([ + f'', + f' def __init__(self{args}) -> None: ...', + f'' + ]) + + for oneof_name, oneof_fields in oneof.items(): + literals: str = ', '.join((f'Literal[\'{name}\']' for name, _ in oneof_fields)) types: Set[str] = set((type for _, type in oneof_fields)) if len(types) == 1: type = 'Optional[' + types.pop() + ']' @@ -197,85 +210,88 @@ def generate_message(imports: List[str], overrides: list[str], file: FileDescrip types.add('None') type = 'Union[' + ', '.join(types) + ']' - message_lines.extend([ - '', - ' @property', - f' def {oneof_name}(self) -> {type}: ...' - ]) - - overrides.extend([ - f'def _{message.name}_{oneof_name}(self: {message.name}) -> {type}:', - f' variant: Optional[str] = self.{oneof_name}_variant()', - f' if variant is None: return None', - '\n'.join([f' if variant == \'{name}\': return unwrap(self.{name})' for name, _ in oneof_fields]), - f' raise Exception(\'Field `{oneof_name}` not found.\')', - '', - f'setattr({base}, \'{oneof_name}\', property(_{message.name}_{oneof_name}))', - '' + nested_message_lines.extend([ + f'class {message.name}_{oneof_name}_dict(TypedDict, total=False):', + '\n'.join([f' {name}: {type}' for name, type in oneof_fields]), + f'', ]) + add_import(imports, 'from typing import Union') + add_import(imports, 'from typing_extensions import TypedDict') + add_import(imports, 'from typing_extensions import Literal') message_lines.extend([ + f' @property', + f' def {oneof_name}(self) -> {type}: ...' f'', - f' class {oneof_name}_dict(TypedDict, total=False):', - '\n'.join([f' {name}: {type}' for name, type in oneof_fields]), + f' def {oneof_name}_variant(self) -> Union[{literals}, None]: ...' f'', - f' def {oneof_name}_variant(self) -> Union[{literals}, None]: ...' + f' def {oneof_name}_asdict(self) -> {message.name}_{oneof_name}_dict: ...', f'', - f' def {oneof_name}_asdict(self) -> \'{message.name}.{oneof_name}_dict\': ...' ]) - overrides.extend([ - f'def _{message.name}_{oneof_name}_asdict(self: {message.name}) -> \'{message.name}.{oneof_name}_dict\':', - f' variant: Optional[str] = self.{oneof_name}_variant()', - f' if variant is None: return {{}}', - '\n'.join([f' if variant == \'{name}\': return {{\'{name}\': unwrap(self.{name})}}' for name, _ in oneof_fields]), - f' raise Exception(\'Field `{oneof_name}` not found.\')', - '', - f'setattr({base}, \'{oneof_name}_asdict\', _{message.name}_{oneof_name}_asdict)', - '' - ]) + return_variant = '\n '.join([f'if variant == \'{name}\': return unwrap(self.{name})' for name, _ in oneof_fields]) + return_asdict = '\n '.join([f'if variant == \'{name}\': return {{\'{name}\': unwrap(self.{name})}} # type: ignore' for name, _ in oneof_fields]) + if return_variant: return_variant += '\n ' + if return_asdict: return_asdict += '\n ' - _MESSAGES[message.name] = parameters, parameters_name + res.append(CodeGeneratorResponse.File( + name=file.name.replace('.proto', '_pb2.py'), + insertion_point=f'module_scope', + content=f""" +def _{message.name}_{oneof_name}(self: {message.name}): + variant = self.{oneof_name}_variant() + if variant is None: return None + {return_variant}raise Exception('Field `{oneof_name}` not found.') - if len(message_lines) == 2: - message_lines.append(' pass') - message_lines.append('') +def _{message.name}_{oneof_name}_variant(self: {message.name}): + return self.WhichOneof('{oneof_name}') # type: ignore - message_lines.extend([ - f'setattr({message.name}, \'__new__\', lambda _, *args, **kwargs: {base}(*args, **kwargs)) # type: ignore', - '' - ]) +def _{message.name}_{oneof_name}_asdict(self: {message.name}): + variant = self.{oneof_name}_variant() + if variant is None: return {{}} + {return_asdict}raise Exception('Field `{oneof_name}` not found.') + +setattr({message.name}, '{oneof_name}', property(_{message.name}_{oneof_name})) +setattr({message.name}, '{oneof_name}_variant', _{message.name}_{oneof_name}_variant) +setattr({message.name}, '{oneof_name}_asdict', _{message.name}_{oneof_name}_asdict) +""")) - return message_lines + return message_lines + nested_message_lines def generate_service_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, method: MethodDescriptorProto, sync: bool = True) -> List[str]: input_mode = 'stream' if method.client_streaming else 'unary' output_mode = 'stream' if method.server_streaming else 'unary' - input_type, _ = import_type(imports, file, method.input_type) - output_type, _ = import_type(imports, file, method.output_type) + input_type, input_msg, _ = import_type(imports, method.input_type, None) + output_type, _, _ = import_type(imports, method.output_type, None) - input_type_pb2, _ = import_type(imports, file, method.input_type, False) - output_type_pb2, _ = import_type(imports, file, method.output_type, False) - - iterator_type = 'Iterator' if sync else 'AsyncIterator' - - parameters, parameters_name = _MESSAGES.get(input_type, ([], [])) - args = ', '.join(parameters) - args_name = ', '.join(parameters_name) - - if args: args = ', ' + args + input_type_pb2, _, _ = import_type(imports, method.input_type, None) + output_type_pb2, _, _ = import_type(imports, method.output_type, None) if output_mode == 'stream': if input_mode == 'stream': output_type_hint = f'StreamStream[{input_type}, {output_type}]' + if sync: + add_import(imports, f'from {file.package}._utils import Sender') + add_import(imports, f'from {file.package}._utils import Stream') + add_import(imports, f'from {file.package}._utils import StreamStream') + else: + add_import(imports, f'from {file.package}._utils import AioSender as Sender') + add_import(imports, f'from {file.package}._utils import AioStream as Stream') + add_import(imports, f'from {file.package}._utils import AioStreamStream as StreamStream') else: output_type_hint = f'Stream[{output_type}]' + if sync: + add_import(imports, f'from {file.package}._utils import Stream') + else: + add_import(imports, f'from {file.package}._utils import AioStream as Stream') else: output_type_hint = output_type if sync else f'Awaitable[{output_type}]' + if not sync: add_import(imports, f'from typing import Awaitable') if input_mode == 'stream' and output_mode == 'stream': + add_import(imports, f'from typing import Optional') return ( f'def {method.name}(self, timeout: Optional[float] = None) -> {output_type_hint}:\n' f' tx: Sender[{input_type}] = Sender()\n' @@ -287,6 +303,9 @@ def generate_service_method(imports: List[str], file: FileDescriptorProto, servi f' return StreamStream(tx, rx)' ).split('\n') if input_mode == 'stream': + iterator_type = 'Iterator' if sync else 'AsyncIterator' + add_import(imports, f'from typing import {iterator_type}') + add_import(imports, f'from typing import Optional') return ( f'def {method.name}(self, iterator: {iterator_type}[{input_type}], timeout: Optional[float] = None) -> {output_type_hint}:\n' f' return self.channel.{input_mode}_{output_mode}( # type: ignore\n' @@ -296,6 +315,12 @@ def generate_service_method(imports: List[str], file: FileDescriptorProto, servi f' )(iterator)' ).split('\n') else: + add_import(imports, f'from typing import Optional') + assert isinstance(input_msg, DescriptorProto) + input_fields, _ = collect_message(imports, input_msg, None) + args = ', '.join([f'{name}: {type} = {dft}' for name, type, dft in input_fields]) + args_name = ', '.join([f'{name}={name}' for name, _, _ in input_fields]) + if args: args = ', ' + args return ( f'def {method.name}(self{args}, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = None) -> {output_type_hint}:\n' f' return self.channel.{input_mode}_{output_mode}( # type: ignore\n' @@ -317,7 +342,7 @@ def generate_service(imports: List[str], file: FileDescriptorProto, service: Ser f'class {service.name}:\n' f' channel: {channel_type}\n' f'\n' - f' def __init__(self, channel: {channel_type}):\n' + f' def __init__(self, channel: {channel_type}) -> None:\n' f' self.channel = channel\n' f'\n' f' {methods}\n' @@ -328,17 +353,25 @@ def generate_servicer_method(imports: List[str], method: MethodDescriptorProto, input_mode = 'stream' if method.client_streaming else 'unary' output_mode = 'stream' if method.server_streaming else 'unary' - input_type, _ = import_type(imports, file, method.input_type) - output_type, _ = import_type(imports, file, method.output_type) + input_type, _, _ = import_type(imports, method.input_type, None) + output_type, _, _ = import_type(imports, method.output_type, None) output_type_hint = output_type if output_mode == 'stream': - output_type_hint = f'Generator[{output_type}, None, None]' if sync else f'AsyncGenerator[{output_type}, None]' + if sync: + output_type_hint = f'Generator[{output_type}, None, None]' + add_import(imports, f'from typing import Generator') + else: + output_type_hint = f'AsyncGenerator[{output_type}, None]' + add_import(imports, f'from typing import AsyncGenerator') + + iterator_type = 'Iterator' if sync else 'AsyncIterator' if input_mode == 'stream': - input_stream_type = ('Iterator' if sync else 'AsyncIterator') + f'[{input_type}]' + iterator_type = 'Iterator' if sync else 'AsyncIterator' + add_import(imports, f'from typing import {iterator_type}') lines = (('' if sync else 'async ') + ( - f'def {method.name}(self, request: {input_stream_type}, context: grpc.ServicerContext) -> {output_type_hint}:\n' + f'def {method.name}(self, request: {iterator_type}[{input_type}], context: grpc.ServicerContext) -> {output_type_hint}:\n' f' context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore\n' f' context.set_details("Method not implemented!") # type: ignore\n' f' raise NotImplementedError("Method not implemented!")' @@ -373,8 +406,8 @@ def generate_rpc_method_handler(imports: List[str], method: MethodDescriptorProt input_mode = 'stream' if method.client_streaming else 'unary' output_mode = 'stream' if method.server_streaming else 'unary' - input_type, _ = import_type(imports, file, method.input_type, False) - output_type, _ = import_type(imports, file, method.output_type, False) + input_type, _, _ = import_type(imports, method.input_type, None) + output_type, _, _ = import_type(imports, method.output_type, None) return ( f"'{method.name}': grpc.{input_mode}_{output_mode}_rpc_method_handler( # type: ignore\n" @@ -418,50 +451,22 @@ _HEADER = '''# Copyright 2022 Google LLC # limitations under the License. """Generated python gRPC interfaces.""" - -__version__ = "0.0.1" - -import enum -import grpc - -from dataclasses import dataclass, field -from typing import Dict, Generator, Optional, List, Literal, Union, Iterator, AsyncGenerator, AsyncIterator, Awaitable, TypeVar, TypedDict - -from google.protobuf.message import Message ''' -_UTILS_PY = '''# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Generated python gRPC helpers.""" +_UTILS_PY = f'''{_HEADER} import asyncio import queue import grpc +import sys -from typing import Any, AsyncIterable, AsyncIterator, Generic, Iterator, Optional, TypeVar +from typing import Any, AsyncIterable, AsyncIterator, Generic, Iterator, TypeVar _T_co = TypeVar('_T_co', covariant=True) _T = TypeVar('_T') -def unwrap(optional: Optional[_T]) -> _T: - assert optional - return optional - - class Stream(Iterator[_T_co], grpc.RpcContext): ... @@ -469,7 +474,10 @@ class AioStream(AsyncIterable[_T_co], grpc.RpcContext): ... class Sender(Iterator[_T]): - _inner: queue.Queue[_T] + if sys.version_info >= (3, 8): + _inner: queue.Queue[_T] + else: + _inner: queue.Queue def __init__(self) -> None: self._inner = queue.Queue() @@ -485,7 +493,10 @@ class Sender(Iterator[_T]): class AioSender(AsyncIterator[_T]): - _inner: asyncio.Queue[_T] + if sys.version_info >= (3, 8): + _inner: asyncio.Queue[_T] + else: + _inner: asyncio.Queue def __init__(self) -> None: self._inner = asyncio.Queue() @@ -574,27 +585,35 @@ _UTILS_FILES: Set[str] = set() for file_name in _REQUEST.file_to_generate: file: FileDescriptorProto = next(filter(lambda x: x.name == file_name, _REQUEST.proto_file)) - overrides: List[str] = [] - imports: List[str] = [] + _FILES.append(CodeGeneratorResponse.File( + name=file.name.replace('.proto', '_pb2.py'), + insertion_point=f'module_scope', + content='def unwrap(x):\n assert x\n return x\n' + )) + + pyi_imports: List[str] = [] + grpc_imports: List[str] = ['import grpc'] + grpc_aio_imports: List[str] = ['import grpc', 'import grpc.aio'] + + enums = '\n'.join(sum([generate_enum(pyi_imports, file, enum, _FILES) for enum in file.enum_type], [])) + messages = '\n'.join(sum([generate_message(pyi_imports, file, message, _FILES) for message in file.message_type], [])) - enums = '\n'.join(sum([generate_enum(imports, file, enum) for enum in file.enum_type], [])) - messages = '\n'.join(sum([generate_message(imports, overrides, file, message) for message in file.message_type], [])) - services = '\n'.join(sum([generate_service(imports, file, service) for service in file.service], [])) - aio_services = '\n'.join(sum([generate_service(imports, file, service, False) for service in file.service], [])) - servicers = '\n'.join(sum([generate_servicer(imports, file, service) for service in file.service], [])) - aio_servicers = '\n'.join(sum([generate_servicer(imports, file, service, False) for service in file.service], [])) - add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(imports, file, service) for service in file.service], [])) - aio_add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(imports, file, service, False) for service in file.service], [])) + services = '\n'.join(sum([generate_service(grpc_imports, file, service) for service in file.service], [])) + aio_services = '\n'.join(sum([generate_service(grpc_aio_imports, file, service, False) for service in file.service], [])) - imports.sort() + servicers = '\n'.join(sum([generate_servicer(grpc_imports, file, service) for service in file.service], [])) + aio_servicers = '\n'.join(sum([generate_servicer(grpc_aio_imports, file, service, False) for service in file.service], [])) - imports_str: str = '\n'.join(imports) - overrides_str: str = '\n'.join(overrides) + add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(grpc_imports, file, service) for service in file.service], [])) + aio_add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(grpc_aio_imports, file, service, False) for service in file.service], [])) - package = file_name.replace('.proto', '_grpc').replace('/', '.') + pyi_imports.sort() + grpc_imports.sort() + grpc_aio_imports.sort() - enum_import_str = f'\nfrom {package} import ' + ', '.join([e.name for e in file.enum_type]) if len(file.enum_type) else '' - message_import_str = f'\nfrom {package} import ' + ', '.join([m.name for m in file.message_type]) if len(file.message_type) else '' + pyi_imports_str: str = '\n'.join(pyi_imports) + grpc_imports_str: str = '\n'.join(grpc_imports) + grpc_aio_imports_str: str = '\n'.join(grpc_aio_imports) utils_filename = file_name.replace(os.path.basename(file_name), '_utils.py') if utils_filename not in _UTILS_FILES: @@ -606,19 +625,18 @@ for file_name in _REQUEST.file_to_generate: ) ]) - super_package = package[:package.rindex('.')] - - extras = f'from {super_package}._utils import unwrap, Sender, Stream, StreamStream' - aio_extras = f'from {super_package}._utils import unwrap, AioSender as Sender, AioStream as Stream, AioStreamStream as StreamStream' - _FILES.extend([ + CodeGeneratorResponse.File( + name=file.name.replace('.proto', '_pb2.pyi'), + content=f'{_HEADER}\n\n{pyi_imports_str}\n\n{enums}\n\n{messages}\n' + ), CodeGeneratorResponse.File( name=file_name.replace('.proto', '_grpc.py'), - content=f'{_HEADER}\n\n{imports_str}\n\n{extras}\n\n\n\n{enums}\n\n{messages}\n\n{services}\n\n{servicers}\n\n{add_servicer_methods}\n\n{overrides_str}' + content=f'{_HEADER}\n\n{grpc_imports_str}\n\n{services}\n\n{servicers}\n\n{add_servicer_methods}' ), CodeGeneratorResponse.File( name=file_name.replace('.proto', '_grpc_aio.py'), - content=f'{_HEADER}\n\n{imports_str}{enum_import_str}{message_import_str}\n\n{aio_extras}\n\n{aio_services}\n\n{aio_servicers}\n\n{aio_add_servicer_methods}' + content=f'{_HEADER}\n\n{grpc_aio_imports_str}\n\n{aio_services}\n\n{aio_servicers}\n\n{aio_add_servicer_methods}' ) ]) diff --git a/python/pyproject.toml b/python/pyproject.toml index cb5b72a..0e42956 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -3,7 +3,7 @@ name = "bt-test-interfaces" readme = "../README.md" authors = [{name = "Pandora", email = "pandora-core@google.com"}] dynamic = ["version", "description"] -dependencies = ["protobuf==3.20.1"] +dependencies = ["protobuf==4.22.0"] [tool.flit.module] name = "pandora" @@ -12,6 +12,6 @@ name = "pandora" include = ["_build"] [build-system] -requires = ["flit_core==3.7.1", "grpcio-tools>=1.41"] +requires = ["flit_core==3.7.1", "grpcio-tools==1.51.1"] build-backend = "_build.backend" backend-path = ["."] -- cgit v1.2.3 From 05d6efa62af0bc06c68826d88f32c76d510291cb Mon Sep 17 00:00:00 2001 From: uael Date: Thu, 23 Mar 2023 00:43:57 +0000 Subject: gen: do not try to guess the right package Test: None Change-Id: I1990b70a540c53173333712ac1351a120054dfa6 --- python/_build/protoc-gen-custom_grpc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/_build/protoc-gen-custom_grpc b/python/_build/protoc-gen-custom_grpc index 37b31b7..7f6a381 100755 --- a/python/_build/protoc-gen-custom_grpc +++ b/python/_build/protoc-gen-custom_grpc @@ -273,19 +273,19 @@ def generate_service_method(imports: List[str], file: FileDescriptorProto, servi if input_mode == 'stream': output_type_hint = f'StreamStream[{input_type}, {output_type}]' if sync: - add_import(imports, f'from {file.package}._utils import Sender') - add_import(imports, f'from {file.package}._utils import Stream') - add_import(imports, f'from {file.package}._utils import StreamStream') + add_import(imports, f'from ._utils import Sender') + add_import(imports, f'from ._utils import Stream') + add_import(imports, f'from ._utils import StreamStream') else: - add_import(imports, f'from {file.package}._utils import AioSender as Sender') - add_import(imports, f'from {file.package}._utils import AioStream as Stream') - add_import(imports, f'from {file.package}._utils import AioStreamStream as StreamStream') + add_import(imports, f'from ._utils import AioSender as Sender') + add_import(imports, f'from ._utils import AioStream as Stream') + add_import(imports, f'from ._utils import AioStreamStream as StreamStream') else: output_type_hint = f'Stream[{output_type}]' if sync: - add_import(imports, f'from {file.package}._utils import Stream') + add_import(imports, f'from ._utils import Stream') else: - add_import(imports, f'from {file.package}._utils import AioStream as Stream') + add_import(imports, f'from ._utils import AioStream as Stream') else: output_type_hint = output_type if sync else f'Awaitable[{output_type}]' if not sync: add_import(imports, f'from typing import Awaitable') -- cgit v1.2.3 From 5813d74571a0902a23fde00fa126e452ed29a037 Mon Sep 17 00:00:00 2001 From: Yuyang Huang Date: Thu, 29 Dec 2022 16:45:50 -0800 Subject: Add ASHA music streaming rpc Change-Id: Ie5bc5fa6b1e2df1d32e2b688ce7417a3cf77e0da --- pandora/asha.proto | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/pandora/asha.proto b/pandora/asha.proto index 4fafc5b..6b2ab64 100644 --- a/pandora/asha.proto +++ b/pandora/asha.proto @@ -14,17 +14,26 @@ syntax = "proto3"; -option java_outer_classname = "ASHAProto"; +option java_outer_classname = "AshaProto"; -package pandora; +package pandora.asha; import "google/protobuf/empty.proto"; +import "pandora/host.proto"; // Service to trigger Audio Streaming for Hearing Aid (ASHA) procedures. // ASHA uses connection-oriented L2CAP channels (CoC) and GATT. -service ASHA { +service Asha { // Register ASHA Service. rpc Register(RegisterRequest) returns (google.protobuf.Empty); + // Capture Audio. + rpc CaptureAudio(CaptureAudioRequest) returns (stream CaptureAudioResponse); + // Start a suspended stream. + rpc Start(StartRequest) returns (StartResponse); + // Playback audio + rpc PlaybackAudio(stream PlaybackAudioRequest) returns (PlaybackAudioResponse); + // Stop a started stream. + rpc Stop(stream StopRequest) returns (StopResponse); } // Request of the `Register` method. @@ -32,3 +41,47 @@ message RegisterRequest { uint32 capability = 1; // left or right device, monaural or binaural device. repeated uint32 hisyncid = 2; // id identifying two hearing aids as one pair. } + +// Request of the `CaptureAudio` method. +message CaptureAudioRequest { + // Low Energy connection. + Connection connection = 1; +} + +// Response of the `CaptureAudio` method. +message CaptureAudioResponse { + // Audio data received on peripheral side. + // `data` is decoded by G722 decoder. + bytes data = 1; +} + +// Request of the `Start` method. +message StartRequest { + // Low Energy connection. + Connection connection = 1; +} + +// Response of the `Start` method. +message StartResponse {} + +// Request of the `PlaybackAudio` method. +message PlaybackAudioRequest { + // Low Energy connection. + Connection connection = 1; + // Audio data to playback. + // `data` should be interleaved stereo frames with 16-bit signed little-endian + // linear PCM samples at 44100Hz sample rate + bytes data = 2; +} + +// Response of the `PlaybackAudio` method. +message PlaybackAudioResponse {} + +// Request of the `Stop` method. +message StopRequest { + // Low Energy connection. + Connection connection = 1; +} + +// Response of the `Stop` method. +message StopResponse {} -- cgit v1.2.3 From 7232fcbdd4d121ffccfebda221bd0dc3b100be44 Mon Sep 17 00:00:00 2001 From: Yuyang Huang Date: Fri, 10 Mar 2023 12:59:19 -0800 Subject: [Pandora] Remove a2dp.proto a2dp.proto is no longer used under Pandora, and it has message naming conflict with asha.proto Change-Id: I33950e29e6adf4a19734dc47a5ca3cd466920605 --- pandora/a2dp.proto | 296 ----------------------------------------------------- 1 file changed, 296 deletions(-) delete mode 100644 pandora/a2dp.proto diff --git a/pandora/a2dp.proto b/pandora/a2dp.proto deleted file mode 100644 index d262d1b..0000000 --- a/pandora/a2dp.proto +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package pandora; - -import "pandora/host.proto"; -import "google/protobuf/empty.proto"; -import "google/protobuf/wrappers.proto"; - -option java_outer_classname = "A2DPProto"; - -// Service to trigger A2DP (Advanced Audio Distribution Profile) procedures. -// -// Requirements for the implementer: -// - Streams must not be automatically opened, even if discovered. -// - The `Host` service must be implemented -// -// References: -// - [A2DP] Bluetooth SIG, Specification of the Bluetooth System, -// Advanced Audio Distribution, Version 1.3 or Later -// - [AVDTP] Bluetooth SIG, Specification of the Bluetooth System, -// Audio/Video Distribution Transport Protocol, Version 1.3 or Later -service A2DP { - // Open a stream from a local **Source** endpoint to a remote **Sink** - // endpoint. - // - // The returned source should be in the AVDTP_OPEN state (see [AVDTP] 9.1). - // The rpc must block until the stream has reached this state. - // - // A cancellation of this call must result in aborting the current - // AVDTP procedure (see [AVDTP] 9.9). - rpc OpenSource(OpenSourceRequest) returns (OpenSourceResponse); - // Open a stream from a local **Sink** endpoint to a remote **Source** - // endpoint. - // - // The returned sink must be in the AVDTP_OPEN state (see [AVDTP] 9.1). - // The rpc must block until the stream has reached this state. - // - // A cancellation of this call must result in aborting the current - // AVDTP procedure (see [AVDTP] 9.9). - rpc OpenSink(OpenSinkRequest) returns (OpenSinkResponse); - // Wait for a stream from a local **Source** endpoint to - // a remote **Sink** endpoint to open. - // - // The returned source should be in the AVDTP_OPEN state (see [AVDTP] 9.1). - // The rpc must block until the stream has reached this state. - // - // If the peer has opened a source prior to this call, the server will - // return it. The server must return the same source only once. - rpc WaitSource(WaitSourceRequest) returns (WaitSourceResponse); - // Wait for a stream from a local **Sink** endpoint to - // a remote **Source** endpoint to open. - // - // The returned sink should be in the AVDTP_OPEN state (see [AVDTP] 9.1). - // The rpc must block until the stream has reached this state. - // - // If the peer has opened a sink prior to this call, the server will - // return it. The server must return the same sink only once. - rpc WaitSink(WaitSinkRequest) returns (WaitSinkResponse); - // Get if the stream is suspended - rpc IsSuspended(IsSuspendedRequest) returns (google.protobuf.BoolValue); - // Start an opened stream. - // - // The rpc must block until the stream has reached the - // AVDTP_STREAMING state (see [AVDTP] 9.1). - rpc Start(StartRequest) returns (StartResponse); - // Suspend a streaming stream. - // - // The rpc must block until the stream has reached the AVDTP_OPEN - // state (see [AVDTP] 9.1). - rpc Suspend(SuspendRequest) returns (SuspendResponse); - // Close a stream, the source or sink tokens must not be reused afterwards. - rpc Close(CloseRequest) returns (CloseResponse); - // Get the `AudioEncoding` value of a stream - rpc GetAudioEncoding(GetAudioEncodingRequest) - returns (GetAudioEncodingResponse); - // Playback audio by a `Source` - rpc PlaybackAudio(stream PlaybackAudioRequest) - returns (PlaybackAudioResponse); - // Capture audio from a `Sink` - rpc CaptureAudio(CaptureAudioRequest) returns (stream CaptureAudioResponse); -} - -// Audio encoding formats. -enum AudioEncoding { - // Interleaved stereo frames with 16-bit signed little-endian linear PCM - // samples at 44100Hz sample rate - PCM_S16_LE_44K1_STEREO = 0; - // Interleaved stereo frames with 16-bit signed little-endian linear PCM - // samples at 48000Hz sample rate - PCM_S16_LE_48K_STEREO = 1; -} - -// A Token representing a Source stream (see [A2DP] 2.2). -// It's acquired via an OpenSource on the A2DP service. -message Source { - // Opaque value filled by the GRPC server, must not - // be modified nor crafted. - bytes cookie = 1; -} - -// A Token representing a Sink stream (see [A2DP] 2.2). -// It's acquired via an OpenSink on the A2DP service. -message Sink { - // Opaque value filled by the GRPC server, must not - // be modified nor crafted. - bytes cookie = 1; -} - -// Request for the `OpenSource` method. -message OpenSourceRequest { - // The connection that will open the stream. - Connection connection = 1; -} - -// Response for the `OpenSource` method. -message OpenSourceResponse { - // Result of the `OpenSource` call. - oneof result { - // Opened stream. - Source source = 1; - // The Connection disconnected. - google.protobuf.Empty disconnected = 2; - } -} - -// Request for the `OpenSink` method. -message OpenSinkRequest { - // The connection that will open the stream. - Connection connection = 1; -} - -// Response for the `OpenSink` method. -message OpenSinkResponse { - // Result of the `OpenSink` call. - oneof result { - // Opened stream. - Sink sink = 1; - // The Connection disconnected. - google.protobuf.Empty disconnected = 2; - } -} - -// Request for the `WaitSource` method. -message WaitSourceRequest { - // The connection that is awaiting the stream. - Connection connection = 1; -} - -// Response for the `WaitSource` method. -message WaitSourceResponse { - // Result of the `WaitSource` call. - oneof result { - // Awaited stream. - Source source = 1; - // The Connection disconnected. - google.protobuf.Empty disconnected = 2; - } -} - -// Request for the `WaitSink` method. -message WaitSinkRequest { - // The connection that is awaiting the stream. - Connection connection = 1; -} - -// Response for the `WaitSink` method. -message WaitSinkResponse { - // Result of the `WaitSink` call. - oneof result { - // Awaited stream. - Sink sink = 1; - // The Connection disconnected. - google.protobuf.Empty disconnected = 2; - } -} - -// Request for the `IsSuspended` method. -message IsSuspendedRequest { - // The stream on which the function will check if it's suspended - oneof target { - Sink sink = 1; - Source source = 2; - } -} - -// Request for the `Start` method. -message StartRequest { - // Target of the start, either a Sink or a Source. - oneof target { - Sink sink = 1; - Source source = 2; - } -} - -// Response for the `Start` method. -message StartResponse { - // Result of the `Start` call. - oneof result { - // Stream successfully started. - google.protobuf.Empty started = 1; - // Stream is already in AVDTP_STREAMING state. - google.protobuf.Empty already_started = 2; - // The Connection disconnected. - google.protobuf.Empty disconnected = 3; - } -} - -// Request for the `Suspend` method. -message SuspendRequest { - // Target of the suspend, either a Sink or a Source. - oneof target { - Sink sink = 1; - Source source = 2; - } -} - -// Response for the `Suspend` method. -message SuspendResponse { - // Result of the `Suspend` call. - oneof result { - // Stream successfully suspended. - google.protobuf.Empty suspended = 1; - // Stream is already in AVDTP_OPEN state. - google.protobuf.Empty already_suspended = 2; - // The Connection disconnected. - google.protobuf.Empty disconnected = 3; - } -} - -// Request for the `Close` method. -message CloseRequest { - // Target of the close, either a Sink or a Source. - oneof target { - Sink sink = 1; - Source source = 2; - } -} - -// Response for the `Close` method. -message CloseResponse {} - -// Request for the `GetAudioEncoding` method. -message GetAudioEncodingRequest { - // The stream on which the function will read the `AudioEncoding`. - oneof target { - Sink sink = 1; - Source source = 2; - } -} - -// Response for the `GetAudioEncoding` method. -message GetAudioEncodingResponse { - // Audio encoding of the stream. - AudioEncoding encoding = 1; -} - -// Request for the `PlaybackAudio` method. -message PlaybackAudioRequest { - // Source that will playback audio. - Source source = 1; - // Audio data to playback. - // The audio data must be encoded in the specified `AudioEncoding` value - // obtained in response of a `GetAudioEncoding` method call. - bytes data = 2; -} - -// Response for the `PlaybackAudio` method. -message PlaybackAudioResponse {} - -// Request for the `CaptureAudio` method. -message CaptureAudioRequest { - // Sink that will capture audio - Sink sink = 1; -} - -// Response for the `CaptureAudio` method. -message CaptureAudioResponse { - // Captured audio data. - // The audio data is encoded in the specified `AudioEncoding` value - // obtained in response of a `GetAudioEncoding` method call. - bytes data = 1; -} -- cgit v1.2.3