diff options
Diffstat (limited to 'pw_hdlc/py/pw_hdlc/rpc.py')
-rw-r--r-- | pw_hdlc/py/pw_hdlc/rpc.py | 94 |
1 files changed, 40 insertions, 54 deletions
diff --git a/pw_hdlc/py/pw_hdlc/rpc.py b/pw_hdlc/py/pw_hdlc/rpc.py index 831861b46..b41275331 100644 --- a/pw_hdlc/py/pw_hdlc/rpc.py +++ b/pw_hdlc/py/pw_hdlc/rpc.py @@ -139,7 +139,8 @@ class HdlcRpcClient: output: Callable[[bytes], Any] = write_to_file, client_impl: pw_rpc.client.ClientImpl = None, *, - _incoming_packet_filter_for_testing: '_PacketFilter' = None): + _incoming_packet_filter_for_testing: pw_rpc. + ChannelManipulator = None): """Creates an RPC client configured to communicate using HDLC. Args: @@ -159,10 +160,13 @@ class HdlcRpcClient: self.client = pw_rpc.Client.from_modules(client_impl, channels, self.protos.modules()) - self._test_filter = _incoming_packet_filter_for_testing + rpc_output: Callable[[bytes], Any] = self._handle_rpc_packet + if _incoming_packet_filter_for_testing is not None: + _incoming_packet_filter_for_testing.send_packet = rpc_output + rpc_output = _incoming_packet_filter_for_testing frame_handlers: FrameHandlers = { - DEFAULT_ADDRESS: self._handle_rpc_packet, + DEFAULT_ADDRESS: lambda frame: rpc_output(frame.data), STDOUT_ADDRESS: lambda frame: output(frame.data), } @@ -184,15 +188,12 @@ class HdlcRpcClient: return self.client.channel(channel_id).rpcs - def _handle_rpc_packet(self, frame: Frame) -> None: - if self._test_filter and not self._test_filter.keep_packet(frame.data): - return + def _handle_rpc_packet(self, packet: bytes) -> None: + if not self.client.process_packet(packet): + _LOG.error('Packet not handled by RPC client: %s', packet) - if not self.client.process_packet(frame.data): - _LOG.error('Packet not handled by RPC client: %s', frame.data) - -def _try_connect(sock: socket.socket, port: int, attempts: int = 10) -> None: +def _try_connect(port: int, attempts: int = 10) -> socket.socket: """Tries to connect to the specified port up to the given number of times. This is helpful when connecting to a process that was started by this @@ -205,9 +206,11 @@ def _try_connect(sock: socket.socket, port: int, attempts: int = 10) -> None: time.sleep(0.001) try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', port)) - return + return sock except ConnectionRefusedError: + sock.close() if attempts <= 0: raise @@ -218,21 +221,13 @@ class SocketSubprocess: self._server_process = subprocess.Popen(command, stdin=subprocess.PIPE) self.stdin = self._server_process.stdin - sock = None - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - _try_connect(sock, port) + self.socket: socket.socket = _try_connect(port) # 🧦 except: - if sock: - sock.close() - self._server_process.terminate() self._server_process.communicate() raise - self.socket: socket.socket = sock # 🧦 - def close(self) -> None: try: self.socket.close() @@ -247,16 +242,21 @@ class SocketSubprocess: self.close() -class _PacketFilter: +class PacketFilter(pw_rpc.ChannelManipulator): """Determines if a packet should be kept or dropped for testing purposes.""" _Action = Callable[[int], Tuple[bool, bool]] _KEEP = lambda _: (True, False) _DROP = lambda _: (False, False) def __init__(self, name: str) -> None: + super().__init__() self.name = name self.packet_count = 0 - self._actions: Deque[_PacketFilter._Action] = collections.deque() + self._actions: Deque[PacketFilter._Action] = collections.deque() + + def process_and_send(self, packet: bytes): + if self.keep_packet(packet): + self.send_packet(packet) def reset(self) -> None: self.packet_count = 0 @@ -264,11 +264,11 @@ class _PacketFilter: def keep(self, count: int) -> None: """Keeps the next count packets.""" - self._actions.extend(_PacketFilter._KEEP for _ in range(count)) + self._actions.extend(PacketFilter._KEEP for _ in range(count)) def drop(self, count: int) -> None: """Drops the next count packets.""" - self._actions.extend(_PacketFilter._DROP for _ in range(count)) + self._actions.extend(PacketFilter._DROP for _ in range(count)) def drop_every(self, every: int) -> None: """Drops every Nth packet forever.""" @@ -296,33 +296,21 @@ class _PacketFilter: return keep -class _TestChannelOutput: - def __init__(self, send: Callable[[bytes], Any]) -> None: - self._send = send - self.packets = _PacketFilter('outgoing RPC') - - def __call__(self, data: bytes) -> None: - if self.packets.keep_packet(data): - self._send(data) - - class HdlcRpcLocalServerAndClient: """Runs an RPC server in a subprocess and connects to it over a socket. This can be used to run a local RPC server in an integration test. """ - def __init__(self, - server_command: Sequence, - port: int, - protos: PathsModulesOrProtoLibrary, - *, - for_testing: bool = False) -> None: - """Creates a new HdlcRpcLocalServerAndClient. - - If for_testing=True, the HdlcRpcLocalServerAndClient will have - outgoing_packets and incoming_packets _PacketFilter members that can be - used to program packet loss for testing purposes. - """ + def __init__( + self, + server_command: Sequence, + port: int, + protos: PathsModulesOrProtoLibrary, + *, + incoming_processor: Optional[pw_rpc.ChannelManipulator] = None, + outgoing_processor: Optional[pw_rpc.ChannelManipulator] = None + ) -> None: + """Creates a new HdlcRpcLocalServerAndClient.""" self.server = SocketSubprocess(server_command, port) @@ -333,20 +321,18 @@ class HdlcRpcLocalServerAndClient: self.output = io.BytesIO() self.channel_output: Any = self.server.socket.sendall - if for_testing: - self.channel_output = _TestChannelOutput(self.channel_output) - self.outgoing_packets = self.channel_output.packets - self.incoming_packets = _PacketFilter('incoming RPC') - incoming_filter: Optional[_PacketFilter] = self.incoming_packets - else: - incoming_filter = None + + self._incoming_processor = incoming_processor + if outgoing_processor is not None: + outgoing_processor.send_packet = self.channel_output + self.channel_output = outgoing_processor self.client = HdlcRpcClient( self._bytes_queue.get, protos, default_channels(self.channel_output), self.output.write, - _incoming_packet_filter_for_testing=incoming_filter).client + _incoming_packet_filter_for_testing=incoming_processor).client def _read_from_socket(self): while True: |