diff options
Diffstat (limited to 'pw_rpc/py/pw_rpc/callback_client.py')
-rw-r--r-- | pw_rpc/py/pw_rpc/callback_client.py | 446 |
1 files changed, 373 insertions, 73 deletions
diff --git a/pw_rpc/py/pw_rpc/callback_client.py b/pw_rpc/py/pw_rpc/callback_client.py index e61490a5b..ea7fad8d5 100644 --- a/pw_rpc/py/pw_rpc/callback_client.py +++ b/pw_rpc/py/pw_rpc/callback_client.py @@ -1,4 +1,4 @@ -# Copyright 2020 The Pigweed Authors +# Copyright 2021 The Pigweed Authors # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of @@ -11,7 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations under # the License. -"""Defines a callback-based RPC ClientImpl to use with pw_rpc.client.Client. +"""Defines a callback-based RPC ClientImpl to use with pw_rpc.Client. callback_client.Impl supports invoking RPCs synchronously or asynchronously. Asynchronous invocations use a callback. @@ -40,26 +40,63 @@ When invoking a method, requests may be provided as a message object or as kwargs for the message fields (but not both). """ +import enum +import inspect import logging import queue -from typing import Any, Callable, Optional, Tuple +import textwrap +import threading +from typing import Any, Callable, Iterator, NamedTuple, Union, Optional -from pw_rpc import client -from pw_rpc.descriptors import Channel, Method, Service +from pw_protobuf_compiler.python_protos import proto_repr from pw_status import Status +from pw_rpc import client, descriptors +from pw_rpc.client import PendingRpc, PendingRpcs +from pw_rpc.descriptors import Channel, Method, Service + _LOG = logging.getLogger(__name__) -Callback = Callable[[client.PendingRpc, Optional[Status], Any], Any] + +class UseDefault(enum.Enum): + """Marker for args that should use a default value, when None is valid.""" + VALUE = 0 + + +OptionalTimeout = Union[UseDefault, float, None] + +ResponseCallback = Callable[[PendingRpc, Any], Any] +CompletionCallback = Callable[[PendingRpc, Status], Any] +ErrorCallback = Callable[[PendingRpc, Status], Any] + + +class _Callbacks(NamedTuple): + response: ResponseCallback + completion: CompletionCallback + error: ErrorCallback + + +def _default_response(rpc: PendingRpc, response: Any) -> None: + _LOG.info('%s response: %s', rpc, response) + + +def _default_completion(rpc: PendingRpc, status: Status) -> None: + _LOG.info('%s finished: %s', rpc, status) + + +def _default_error(rpc: PendingRpc, status: Status) -> None: + _LOG.error('%s error: %s', rpc, status) class _MethodClient: """A method that can be invoked for a particular channel.""" - def __init__(self, client_impl: 'Impl', rpcs: client.PendingRpcs, - channel: Channel, method: Method): + def __init__(self, client_impl: 'Impl', rpcs: PendingRpcs, + channel: Channel, method: Method, + default_timeout_s: Optional[float]): self._impl = client_impl self._rpcs = rpcs - self._rpc = client.PendingRpc(channel, method.service, method) + self._rpc = PendingRpc(channel, method.service, method) + self.default_timeout_s: Optional[float] = default_timeout_s @property def channel(self) -> Channel: @@ -73,26 +110,65 @@ class _MethodClient: def service(self) -> Service: return self._rpc.service - def invoke(self, callback: Callback, _request=None, **request_fields): - """Invokes an RPC with a callback.""" + def invoke(self, + request: Any, + response: ResponseCallback = _default_response, + completion: CompletionCallback = _default_completion, + error: ErrorCallback = _default_error, + *, + override_pending: bool = True, + keep_open: bool = False) -> '_AsyncCall': + """Invokes an RPC with callbacks.""" self._rpcs.send_request(self._rpc, - self.method.get_request( - _request, request_fields), - callback, - override_pending=False) - return _AsyncCall(self._rpcs, self._rpc) - - def reinvoke(self, callback: Callback, _request=None, **request_fields): - """Invokes an RPC with a callback, overriding any pending requests.""" - self._rpcs.send_request(self._rpc, - self.method.get_request( - _request, request_fields), - callback, - override_pending=True) + request, + _Callbacks(response, completion, error), + override_pending=override_pending, + keep_open=keep_open) return _AsyncCall(self._rpcs, self._rpc) def __repr__(self) -> str: - return repr(self.method) + return self.help() + + def __call__(self): + raise NotImplementedError('Implemented by derived classes') + + def help(self) -> str: + """Returns a help message about this RPC.""" + function_call = self.method.full_name + '(' + + docstring = inspect.getdoc(self.__call__) + assert docstring is not None + + annotation = inspect.Signature.from_callable(self).return_annotation + if isinstance(annotation, type): + annotation = annotation.__name__ + + arg_sep = f',\n{" " * len(function_call)}' + return ( + f'{function_call}' + f'{arg_sep.join(descriptors.field_help(self.method.request_type))})' + f'\n\n{textwrap.indent(docstring, " ")}\n\n' + f' Returns {annotation}.') + + +class RpcTimeout(Exception): + def __init__(self, rpc: PendingRpc, timeout: Optional[float]): + super().__init__( + f'No response received for {rpc.method} after {timeout} s') + self.rpc = rpc + self.timeout = timeout + + +class RpcError(Exception): + def __init__(self, rpc: PendingRpc, status: Status): + if status is Status.NOT_FOUND: + msg = ': the RPC server does not support this RPC' + else: + msg = '' + + super().__init__(f'{rpc.method} failed with error {status}{msg}') + self.rpc = rpc + self.status = status class _AsyncCall: @@ -100,12 +176,12 @@ class _AsyncCall: # TODO(hepler): Consider alternatives (futures) and/or expand functionality. - def __init__(self, rpcs: client.PendingRpcs, rpc: client.PendingRpc): - self.rpc = rpc + def __init__(self, rpcs: PendingRpcs, rpc: PendingRpc): + self._rpc = rpc self._rpcs = rpcs def cancel(self) -> bool: - return self._rpcs.send_cancel(self.rpc) + return self._rpcs.send_cancel(self._rpc) def __enter__(self) -> '_AsyncCall': return self @@ -114,47 +190,213 @@ class _AsyncCall: self.cancel() -class _StreamingResponses: +class StreamingResponses: """Used to iterate over a queue.SimpleQueue.""" - def __init__(self, responses: queue.SimpleQueue): + def __init__(self, method_client: _MethodClient, + responses: queue.SimpleQueue, + default_timeout_s: OptionalTimeout): + self._method_client = method_client self._queue = responses self.status: Optional[Status] = None - def get(self, block: bool = True, timeout_s: float = None): - while True: - self.status, response = self._queue.get(block, timeout_s) - if self.status is not None: - return + if default_timeout_s is UseDefault.VALUE: + self.default_timeout_s = self._method_client.default_timeout_s + else: + self.default_timeout_s = default_timeout_s - yield response + @property + def method(self) -> Method: + return self._method_client.method + + def cancel(self) -> None: + self._method_client._rpcs.send_cancel(self._method_client._rpc) # pylint: disable=protected-access + + def responses(self, + *, + block: bool = True, + timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator: + """Returns an iterator of stream responses. + + Args: + timeout_s: timeout in seconds; None blocks indefinitely + """ + if timeout_s is UseDefault.VALUE: + timeout_s = self.default_timeout_s + + try: + while True: + response = self._queue.get(block, timeout_s) + + if isinstance(response, Exception): + raise response + + if isinstance(response, Status): + self.status = response + return + + yield response + except queue.Empty: + self.cancel() + raise RpcTimeout(self._method_client._rpc, timeout_s) # pylint: disable=protected-access + except: + self.cancel() + raise def __iter__(self): - return self.get() + return self.responses() + + def __repr__(self) -> str: + return f'{type(self).__name__}({self.method})' -class UnaryMethodClient(_MethodClient): - def __call__(self, _request=None, **request_fields) -> Tuple[Status, Any]: - responses: queue.SimpleQueue = queue.SimpleQueue() - self.reinvoke( - lambda _, status, payload: responses.put((status, payload)), - _request, **request_fields) - return responses.get() +def _method_client_docstring(method: Method) -> str: + return f'''\ +Class that invokes the {method.full_name} {method.type.sentence_name()} RPC. + +Calling this directly invokes the RPC synchronously. The RPC can be invoked +asynchronously using the invoke method. +''' + + +def _function_docstring(method: Method) -> str: + return f'''\ +Invokes the {method.full_name} {method.type.sentence_name()} RPC. + +This function accepts either the request protobuf fields as keyword arguments or +a request protobuf as a positional argument. +''' + + +def _update_function_signature(method: Method, function: Callable) -> None: + """Updates the name, docstring, and parameters to match a method.""" + function.__name__ = method.full_name + function.__doc__ = _function_docstring(method) + + # In order to have good tab completion and help messages, update the + # function signature to accept only keyword arguments for the proto message + # fields. This doesn't actually change the function signature -- it just + # updates how it appears when inspected. + sig = inspect.signature(function) + + params = [next(iter(sig.parameters.values()))] # Get the "self" parameter + params += method.request_parameters() + params.append( + inspect.Parameter('pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY)) + function.__signature__ = sig.replace( # type: ignore[attr-defined] + parameters=params) + +class UnaryResponse(NamedTuple): + """Result of invoking a unary RPC: status and response.""" + status: Status + response: Any -class ServerStreamingMethodClient(_MethodClient): - def __call__(self, _request=None, **request_fields) -> _StreamingResponses: + def __repr__(self) -> str: + return f'({self.status}, {proto_repr(self.response)})' + + +class _UnaryResponseHandler: + """Tracks the state of an ongoing synchronous unary RPC call.""" + def __init__(self, rpc: PendingRpc): + self._rpc = rpc + self._response: Any = None + self._status: Optional[Status] = None + self._error: Optional[RpcError] = None + self._event = threading.Event() + + def on_response(self, _: PendingRpc, response: Any) -> None: + self._response = response + + def on_completion(self, _: PendingRpc, status: Status) -> None: + self._status = status + self._event.set() + + def on_error(self, _: PendingRpc, status: Status) -> None: + self._error = RpcError(self._rpc, status) + self._event.set() + + def wait(self, timeout_s: Optional[float]) -> UnaryResponse: + if not self._event.wait(timeout_s): + raise RpcTimeout(self._rpc, timeout_s) + + if self._error is not None: + raise self._error + + assert self._status is not None + return UnaryResponse(self._status, self._response) + + +def _unary_method_client(client_impl: 'Impl', rpcs: PendingRpcs, + channel: Channel, method: Method, + default_timeout: Optional[float]) -> _MethodClient: + """Creates an object used to call a unary method.""" + def call(self: _MethodClient, + _rpc_request_proto=None, + *, + pw_rpc_timeout_s=UseDefault.VALUE, + **request_fields) -> UnaryResponse: + + handler = _UnaryResponseHandler(self._rpc) # pylint: disable=protected-access + self.invoke( + self.method.get_request(_rpc_request_proto, request_fields), + handler.on_response, handler.on_completion, handler.on_error) + + if pw_rpc_timeout_s is UseDefault.VALUE: + pw_rpc_timeout_s = self.default_timeout_s + + return handler.wait(pw_rpc_timeout_s) + + _update_function_signature(method, call) + + # The MethodClient class is created dynamically so that the __call__ method + # can be configured differently for each method. + method_client_type = type( + f'{method.name}_UnaryMethodClient', (_MethodClient, ), + dict(__call__=call, __doc__=_method_client_docstring(method))) + return method_client_type(client_impl, rpcs, channel, method, + default_timeout) + + +def _server_streaming_method_client(client_impl: 'Impl', rpcs: PendingRpcs, + channel: Channel, method: Method, + default_timeout: Optional[float]): + """Creates an object used to call a server streaming method.""" + def call(self: _MethodClient, + _rpc_request_proto=None, + *, + pw_rpc_timeout_s=UseDefault.VALUE, + **request_fields) -> StreamingResponses: responses: queue.SimpleQueue = queue.SimpleQueue() - self.reinvoke( - lambda _, status, payload: responses.put((status, payload)), - _request, **request_fields) - return _StreamingResponses(responses) + self.invoke( + self.method.get_request(_rpc_request_proto, request_fields), + lambda _, response: responses.put(response), + lambda _, status: responses.put(status), + lambda rpc, status: responses.put(RpcError(rpc, status))) + return StreamingResponses(self, responses, pw_rpc_timeout_s) + + _update_function_signature(method, call) + + # The MethodClient class is created dynamically so that the __call__ method + # can be configured differently for each method type. + method_client_type = type( + f'{method.name}_ServerStreamingMethodClient', (_MethodClient, ), + dict(__call__=call, __doc__=_method_client_docstring(method))) + return method_client_type(client_impl, rpcs, channel, method, + default_timeout) class ClientStreamingMethodClient(_MethodClient): def __call__(self): raise NotImplementedError - def invoke(self, callback: Callback, _request=None, **request_fields): + def invoke(self, + request: Any, + response: ResponseCallback = _default_response, + completion: CompletionCallback = _default_completion, + error: ErrorCallback = _default_error, + *, + override_pending: bool = True, + keep_open: bool = False) -> _AsyncCall: raise NotImplementedError @@ -162,40 +404,65 @@ class BidirectionalStreamingMethodClient(_MethodClient): def __call__(self): raise NotImplementedError - def invoke(self, callback: Callback, _request=None, **request_fields): + def invoke(self, + request: Any, + response: ResponseCallback = _default_response, + completion: CompletionCallback = _default_completion, + error: ErrorCallback = _default_error, + *, + override_pending: bool = True, + keep_open: bool = False) -> _AsyncCall: raise NotImplementedError class Impl(client.ClientImpl): - """Callback-based client.ClientImpl.""" - def method_client(self, rpcs: client.PendingRpcs, channel: Channel, - method: Method) -> _MethodClient: + """Callback-based ClientImpl.""" + def __init__(self, + default_unary_timeout_s: Optional[float] = 1.0, + default_stream_timeout_s: Optional[float] = 1.0): + super().__init__() + self._default_unary_timeout_s = default_unary_timeout_s + self._default_stream_timeout_s = default_stream_timeout_s + + @property + def default_unary_timeout_s(self) -> Optional[float]: + return self._default_unary_timeout_s + + @property + def default_stream_timeout_s(self) -> Optional[float]: + return self._default_stream_timeout_s + + def method_client(self, channel: Channel, method: Method) -> _MethodClient: """Returns an object that invokes a method using the given chanel.""" if method.type is Method.Type.UNARY: - return UnaryMethodClient(self, rpcs, channel, method) + return _unary_method_client(self, self.rpcs, channel, method, + self.default_unary_timeout_s) if method.type is Method.Type.SERVER_STREAMING: - return ServerStreamingMethodClient(self, rpcs, channel, method) + return _server_streaming_method_client( + self, self.rpcs, channel, method, + self.default_stream_timeout_s) if method.type is Method.Type.CLIENT_STREAMING: - return ClientStreamingMethodClient(self, rpcs, channel, method) + return ClientStreamingMethodClient(self, self.rpcs, channel, + method, + self.default_unary_timeout_s) - if method.type is Method.Type.BIDI_STREAMING: - return BidirectionalStreamingMethodClient(self, rpcs, channel, - method) + if method.type is Method.Type.BIDIRECTIONAL_STREAMING: + return BidirectionalStreamingMethodClient( + self, self.rpcs, channel, method, + self.default_stream_timeout_s) raise AssertionError(f'Unknown method type {method.type}') - def process_response(self, - rpcs: client.PendingRpcs, - rpc: client.PendingRpc, - context, - status: Optional[Status], - payload, - *, - args: tuple = (), - kwargs: dict = None) -> None: + def handle_response(self, + rpc: PendingRpc, + context, + payload, + *, + args: tuple = (), + kwargs: dict = None) -> None: """Invokes the callback associated with this RPC. Any additional positional and keyword args passed through @@ -205,7 +472,40 @@ class Impl(client.ClientImpl): kwargs = {} try: - context(rpc, status, payload, *args, **kwargs) + context.response(rpc, payload, *args, **kwargs) + except: # pylint: disable=bare-except + self.rpcs.send_cancel(rpc) + _LOG.exception('Response callback %s for %s raised exception', + context.response, rpc) + + def handle_completion(self, + rpc: PendingRpc, + context, + status: Status, + *, + args: tuple = (), + kwargs: dict = None): + if kwargs is None: + kwargs = {} + + try: + context.completion(rpc, status, *args, **kwargs) + except: # pylint: disable=bare-except + _LOG.exception('Completion callback %s for %s raised exception', + context.completion, rpc) + + def handle_error(self, + rpc: PendingRpc, + context, + status: Status, + *, + args: tuple = (), + kwargs: dict = None) -> None: + if kwargs is None: + kwargs = {} + + try: + context.error(rpc, status, *args, **kwargs) except: # pylint: disable=bare-except - rpcs.send_cancel(rpc) - _LOG.exception('Callback %s for %s raised exception', context, rpc) + _LOG.exception('Error callback %s for %s raised exception', + context.error, rpc) |