diff options
Diffstat (limited to 'pw_transfer/py/tests/transfer_test.py')
-rw-r--r-- | pw_transfer/py/tests/transfer_test.py | 69 |
1 files changed, 63 insertions, 6 deletions
diff --git a/pw_transfer/py/tests/transfer_test.py b/pw_transfer/py/tests/transfer_test.py index 8b29293f3..5b4057d98 100644 --- a/pw_transfer/py/tests/transfer_test.py +++ b/pw_transfer/py/tests/transfer_test.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2022 The Pigweed Authors +# Copyright 2023 The Pigweed Authors # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of @@ -61,7 +61,7 @@ class TransferManagerTest(unittest.TestCase): self._service = self._client.channel(1).rpcs.pw.transfer.Transfer self._sent_chunks: List[transfer_pb2.Chunk] = [] - self._packets_to_send: List[List[bytes]] = [] + self._packets_to_send: List[List[packet_pb2.RpcPacket]] = [] def _enqueue_server_responses( self, method: _Method, responses: Iterable[Iterable[transfer_pb2.Chunk]] @@ -77,7 +77,7 @@ class TransferManagerTest(unittest.TestCase): method_id=method.value, status=Status.OK.value, payload=response.SerializeToString(), - ).SerializeToString() + ) ) self._packets_to_send.append(serialized_group) @@ -90,7 +90,7 @@ class TransferManagerTest(unittest.TestCase): service_id=_TRANSFER_SERVICE_ID, method_id=method.value, status=error.value, - ).SerializeToString() + ) ] ) @@ -106,7 +106,8 @@ class TransferManagerTest(unittest.TestCase): if self._packets_to_send: responses = self._packets_to_send.pop(0) for response in responses: - self._client.process_packet(response) + response.call_id = packet.call_id + self._client.process_packet(response.SerializeToString()) def _received_data(self) -> bytearray: data = bytearray() @@ -401,6 +402,61 @@ class TransferManagerTest(unittest.TestCase): self.assertEqual(exception.resource_id, 31) self.assertEqual(exception.status, Status.NOT_FOUND) + def test_read_transfer_reopen(self) -> None: + manager = pw_transfer.Manager( + self._service, + initial_response_timeout_s=DEFAULT_TIMEOUT_S, + default_response_timeout_s=DEFAULT_TIMEOUT_S, + ) + + # A FAILED_PRECONDITION error should attempt a stream reopen. + self._enqueue_server_error(_Method.READ, Status.FAILED_PRECONDITION) + self._enqueue_server_responses( + _Method.READ, + ( + ( + transfer_pb2.Chunk( + transfer_id=3, + offset=0, + data=b'xyz', + remaining_bytes=0, + ), + ), + ), + ) + + # The transfer should complete following reopen, with the first chunk + # being retried. + data = manager.read(3) + self.assertEqual(data, b'xyz') + self.assertEqual(len(self._sent_chunks), 3) + self.assertEqual(self._sent_chunks[0], self._sent_chunks[1]) + self.assertTrue(self._sent_chunks[-1].HasField('status')) + self.assertEqual(self._sent_chunks[-1].status, 0) + + def test_read_transfer_reopen_max_attempts(self) -> None: + manager = pw_transfer.Manager( + self._service, + initial_response_timeout_s=DEFAULT_TIMEOUT_S, + default_response_timeout_s=DEFAULT_TIMEOUT_S, + ) + + # A FAILED_PRECONDITION error should attempt a stream reopen; enqueue + # several. + self._enqueue_server_error(_Method.READ, Status.FAILED_PRECONDITION) + self._enqueue_server_error(_Method.READ, Status.FAILED_PRECONDITION) + self._enqueue_server_error(_Method.READ, Status.FAILED_PRECONDITION) + self._enqueue_server_error(_Method.READ, Status.FAILED_PRECONDITION) + self._enqueue_server_error(_Method.READ, Status.FAILED_PRECONDITION) + + with self.assertRaises(pw_transfer.Error) as context: + manager.read(81) + + exception = context.exception + self.assertEqual(len(self._sent_chunks), 4) + self.assertEqual(exception.resource_id, 81) + self.assertEqual(exception.status, Status.INTERNAL) + def test_read_transfer_server_error(self) -> None: manager = pw_transfer.Manager( self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S @@ -417,7 +473,8 @@ class TransferManagerTest(unittest.TestCase): def test_write_transfer_basic(self) -> None: manager = pw_transfer.Manager( - self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S + self._service, + default_response_timeout_s=DEFAULT_TIMEOUT_S, ) self._enqueue_server_responses( |