aboutsummaryrefslogtreecommitdiff
path: root/pw_hdlc/py/pw_hdlc/rpc.py
diff options
context:
space:
mode:
Diffstat (limited to 'pw_hdlc/py/pw_hdlc/rpc.py')
-rw-r--r--pw_hdlc/py/pw_hdlc/rpc.py94
1 files changed, 40 insertions, 54 deletions
diff --git a/pw_hdlc/py/pw_hdlc/rpc.py b/pw_hdlc/py/pw_hdlc/rpc.py
index 831861b46..b41275331 100644
--- a/pw_hdlc/py/pw_hdlc/rpc.py
+++ b/pw_hdlc/py/pw_hdlc/rpc.py
@@ -139,7 +139,8 @@ class HdlcRpcClient:
output: Callable[[bytes], Any] = write_to_file,
client_impl: pw_rpc.client.ClientImpl = None,
*,
- _incoming_packet_filter_for_testing: '_PacketFilter' = None):
+ _incoming_packet_filter_for_testing: pw_rpc.
+ ChannelManipulator = None):
"""Creates an RPC client configured to communicate using HDLC.
Args:
@@ -159,10 +160,13 @@ class HdlcRpcClient:
self.client = pw_rpc.Client.from_modules(client_impl, channels,
self.protos.modules())
- self._test_filter = _incoming_packet_filter_for_testing
+ rpc_output: Callable[[bytes], Any] = self._handle_rpc_packet
+ if _incoming_packet_filter_for_testing is not None:
+ _incoming_packet_filter_for_testing.send_packet = rpc_output
+ rpc_output = _incoming_packet_filter_for_testing
frame_handlers: FrameHandlers = {
- DEFAULT_ADDRESS: self._handle_rpc_packet,
+ DEFAULT_ADDRESS: lambda frame: rpc_output(frame.data),
STDOUT_ADDRESS: lambda frame: output(frame.data),
}
@@ -184,15 +188,12 @@ class HdlcRpcClient:
return self.client.channel(channel_id).rpcs
- def _handle_rpc_packet(self, frame: Frame) -> None:
- if self._test_filter and not self._test_filter.keep_packet(frame.data):
- return
+ def _handle_rpc_packet(self, packet: bytes) -> None:
+ if not self.client.process_packet(packet):
+ _LOG.error('Packet not handled by RPC client: %s', packet)
- if not self.client.process_packet(frame.data):
- _LOG.error('Packet not handled by RPC client: %s', frame.data)
-
-def _try_connect(sock: socket.socket, port: int, attempts: int = 10) -> None:
+def _try_connect(port: int, attempts: int = 10) -> socket.socket:
"""Tries to connect to the specified port up to the given number of times.
This is helpful when connecting to a process that was started by this
@@ -205,9 +206,11 @@ def _try_connect(sock: socket.socket, port: int, attempts: int = 10) -> None:
time.sleep(0.001)
try:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', port))
- return
+ return sock
except ConnectionRefusedError:
+ sock.close()
if attempts <= 0:
raise
@@ -218,21 +221,13 @@ class SocketSubprocess:
self._server_process = subprocess.Popen(command, stdin=subprocess.PIPE)
self.stdin = self._server_process.stdin
- sock = None
-
try:
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- _try_connect(sock, port)
+ self.socket: socket.socket = _try_connect(port) # 🧦
except:
- if sock:
- sock.close()
-
self._server_process.terminate()
self._server_process.communicate()
raise
- self.socket: socket.socket = sock # 🧦
-
def close(self) -> None:
try:
self.socket.close()
@@ -247,16 +242,21 @@ class SocketSubprocess:
self.close()
-class _PacketFilter:
+class PacketFilter(pw_rpc.ChannelManipulator):
"""Determines if a packet should be kept or dropped for testing purposes."""
_Action = Callable[[int], Tuple[bool, bool]]
_KEEP = lambda _: (True, False)
_DROP = lambda _: (False, False)
def __init__(self, name: str) -> None:
+ super().__init__()
self.name = name
self.packet_count = 0
- self._actions: Deque[_PacketFilter._Action] = collections.deque()
+ self._actions: Deque[PacketFilter._Action] = collections.deque()
+
+ def process_and_send(self, packet: bytes):
+ if self.keep_packet(packet):
+ self.send_packet(packet)
def reset(self) -> None:
self.packet_count = 0
@@ -264,11 +264,11 @@ class _PacketFilter:
def keep(self, count: int) -> None:
"""Keeps the next count packets."""
- self._actions.extend(_PacketFilter._KEEP for _ in range(count))
+ self._actions.extend(PacketFilter._KEEP for _ in range(count))
def drop(self, count: int) -> None:
"""Drops the next count packets."""
- self._actions.extend(_PacketFilter._DROP for _ in range(count))
+ self._actions.extend(PacketFilter._DROP for _ in range(count))
def drop_every(self, every: int) -> None:
"""Drops every Nth packet forever."""
@@ -296,33 +296,21 @@ class _PacketFilter:
return keep
-class _TestChannelOutput:
- def __init__(self, send: Callable[[bytes], Any]) -> None:
- self._send = send
- self.packets = _PacketFilter('outgoing RPC')
-
- def __call__(self, data: bytes) -> None:
- if self.packets.keep_packet(data):
- self._send(data)
-
-
class HdlcRpcLocalServerAndClient:
"""Runs an RPC server in a subprocess and connects to it over a socket.
This can be used to run a local RPC server in an integration test.
"""
- def __init__(self,
- server_command: Sequence,
- port: int,
- protos: PathsModulesOrProtoLibrary,
- *,
- for_testing: bool = False) -> None:
- """Creates a new HdlcRpcLocalServerAndClient.
-
- If for_testing=True, the HdlcRpcLocalServerAndClient will have
- outgoing_packets and incoming_packets _PacketFilter members that can be
- used to program packet loss for testing purposes.
- """
+ def __init__(
+ self,
+ server_command: Sequence,
+ port: int,
+ protos: PathsModulesOrProtoLibrary,
+ *,
+ incoming_processor: Optional[pw_rpc.ChannelManipulator] = None,
+ outgoing_processor: Optional[pw_rpc.ChannelManipulator] = None
+ ) -> None:
+ """Creates a new HdlcRpcLocalServerAndClient."""
self.server = SocketSubprocess(server_command, port)
@@ -333,20 +321,18 @@ class HdlcRpcLocalServerAndClient:
self.output = io.BytesIO()
self.channel_output: Any = self.server.socket.sendall
- if for_testing:
- self.channel_output = _TestChannelOutput(self.channel_output)
- self.outgoing_packets = self.channel_output.packets
- self.incoming_packets = _PacketFilter('incoming RPC')
- incoming_filter: Optional[_PacketFilter] = self.incoming_packets
- else:
- incoming_filter = None
+
+ self._incoming_processor = incoming_processor
+ if outgoing_processor is not None:
+ outgoing_processor.send_packet = self.channel_output
+ self.channel_output = outgoing_processor
self.client = HdlcRpcClient(
self._bytes_queue.get,
protos,
default_channels(self.channel_output),
self.output.write,
- _incoming_packet_filter_for_testing=incoming_filter).client
+ _incoming_packet_filter_for_testing=incoming_processor).client
def _read_from_socket(self):
while True: