diff options
author | Minghao Li <lmh463896910@gmail.com> | 2022-03-22 17:50:20 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-22 02:50:20 -0700 |
commit | 7d9bb55a51084bed22b511eda488843158c5309e (patch) | |
tree | 893279c473c36aeee6eef447551ef9f72d30f76d | |
parent | 59464bdabe0b6f0ea33ec52ceaf8c7c4a7f73bdb (diff) | |
download | mobly-7d9bb55a51084bed22b511eda488843158c5309e.tar.gz |
Add a new base class for snippet client (#795)
This is the first PR to kick off the standardization of Mobly snippet client, which would enable us to scale the snippet mechanism to a broader set of platforms.
Once the new snippet client code is proven, we will enable it by default and remove the old snippet client under `android_device_lib`.
-rw-r--r-- | mobly/snippet/__init__.py | 13 | ||||
-rw-r--r-- | mobly/snippet/client_base.py | 466 | ||||
-rw-r--r-- | mobly/snippet/errors.py | 52 | ||||
-rw-r--r-- | tests/mobly/snippet/__init__.py | 13 | ||||
-rwxr-xr-x | tests/mobly/snippet/client_base_test.py | 442 |
5 files changed, 986 insertions, 0 deletions
diff --git a/mobly/snippet/__init__.py b/mobly/snippet/__init__.py new file mode 100644 index 0000000..ac3f9e6 --- /dev/null +++ b/mobly/snippet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mobly/snippet/client_base.py b/mobly/snippet/client_base.py new file mode 100644 index 0000000..8fdd6c9 --- /dev/null +++ b/mobly/snippet/client_base.py @@ -0,0 +1,466 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The JSON RPC client base for communicating with snippet servers. + +The JSON RPC protocol expected by this module is: + +.. code-block:: json + + Request: + { + 'id': <Required. Monotonically increasing integer containing the ID of this + request.>, + 'method': <Required. String containing the name of the method to execute.>, + 'params': <Required. JSON array containing the arguments to the method, + `null` if no positional arguments for the RPC method.>, + 'kwargs': <Optional. JSON dict containing the keyword arguments for the + method, `null` if no positional arguments for the RPC method.>, + } + + Response: + { + 'error': <Required. String containing the error thrown by executing the + method, `null` if no error occurred.>, + 'id': <Required. Int id of request that this response maps to.>, + 'result': <Required. Arbitrary JSON object containing the result of + executing the method, `null` if the method could not be executed + or returned void.>, + 'callback': <Required. String that represents a callback ID used to + identify events associated with a particular CallbackHandler + object, `null` if this is not a async RPC.>, + } +""" + +import abc +import contextlib +import enum +import json +import threading +import time + +from mobly.snippet import errors + +# Maximum logging length of RPC response in DEBUG level when verbose logging is +# off. +_MAX_RPC_RESP_LOGGING_LENGTH = 1024 + +# The required field names of RPC response. +RPC_RESPONSE_REQUIRED_FIELDS = ('id', 'error', 'result', 'callback') + + +class StartServerStages(enum.Enum): + """The stages for the starting server process.""" + BEFORE_STARTING_SERVER = 1 + DO_START_SERVER = 2 + BUILD_CONNECTION = 3 + AFTER_STARTING_SERVER = 4 + + +class ClientBase(abc.ABC): + """Base class for JSON RPC clients that connect to snippet servers. + + Connects to a remote device running a JSON RPC compatible server. Users call + the function `start_server` to start the server on the remote device before + sending any RPC. After sending all RPCs, users call the function `stop_server` + to stop all the running instances. + + Attributes: + package: str, the user-visible name of the snippet library being + communicated with. + host_port: int, the host port of this RPC client. + device_port: int, the device port of this RPC client. + log: Logger, the logger of the corresponding device controller. + verbose_logging: bool, if True, prints more detailed log + information. Default is True. + """ + + def __init__(self, package, device): + """Initializes the instance of ClientBase. + + Args: + package: str, the user-visible name of the snippet library being + communicated with. + device: DeviceController, the device object associated with a client. + """ + + self.package = package + self.host_port = None + self.device_port = None + self.log = device.log + self.verbose_logging = True + self._device = device + self._counter = None + self._lock = threading.Lock() + self._event_client = None + + def __del__(self): + self.close_connection() + + def start_server(self): + """Starts the server on the remote device and connects to it. + + This process contains four stages: + - before_starting_server: prepares for starting the server. + - do_start_server: starts the server on the remote device. + - build_connection: builds a connection with the server. + - after_starting_server: does the things after the server is available. + + After this, the self.host_port and self.device_port attributes must be + set. + + Raises: + errors.ProtocolError: something went wrong when exchanging data with the + server. + errors.ServerStartPreCheckError: when prechecks for starting the server + failed. + errors.ServerStartError: when failed to start the snippet server. + """ + + @contextlib.contextmanager + def _execute_one_stage(stage): + """Context manager for executing one stage. + + Args: + stage: StartServerStages, the stage which is running under this + context manager. + + Yields: + None. + """ + self.log.debug('[START_SERVER] Running the stage %s.', stage.name) + yield + self.log.debug('[START_SERVER] Finished the stage %s.', stage.name) + + self.log.debug('Starting the server.') + start_time = time.perf_counter() + + with _execute_one_stage(StartServerStages.BEFORE_STARTING_SERVER): + self.before_starting_server() + + try: + with _execute_one_stage(StartServerStages.DO_START_SERVER): + self.do_start_server() + + with _execute_one_stage(StartServerStages.BUILD_CONNECTION): + self._build_connection() + + with _execute_one_stage(StartServerStages.AFTER_STARTING_SERVER): + self.after_starting_server() + + except Exception: + self.log.error('[START SERVER] Error occurs when starting the server.') + try: + self.stop_server() + except Exception: # pylint: disable=broad-except + # Only prints this exception and re-raises the original exception + self.log.exception('[START_SERVER] Failed to stop server because of ' + 'new exception.') + + raise + + self.log.debug('Snippet %s started after %.1fs on host port %d.', + self.package, + time.perf_counter() - start_time, self.host_port) + + @abc.abstractmethod + def before_starting_server(self): + """Prepares for starting the server. + + For example, subclass can check or modify the device settings at this + stage. + + Raises: + errors.ServerStartPreCheckError: when prechecks for starting the server + failed. + """ + + @abc.abstractmethod + def do_start_server(self): + """Starts the server on the remote device. + + The client has completed the preparations, so the client calls this + function to start the server. + """ + + def _build_connection(self): + """Proxy function of build_connection. + + This function resets the RPC id counter before calling `build_connection`. + """ + self._counter = self._id_counter() + self.build_connection() + + @abc.abstractmethod + def build_connection(self): + """Builds a connection with the server on the remote device. + + The command to start the server has been already sent before calling this + function. So the client builds a connection to it and sends a handshake + to ensure the server is available for upcoming RPCs. + + This function uses self.host_port for communicating with the server. If + self.host_port is 0 or None, this function finds an available host port to + build connection and set self.host_port to the found port. + + Raises: + errors.ProtocolError: something went wrong when exchanging data with the + server. + """ + + @abc.abstractmethod + def after_starting_server(self): + """Does the things after the server is available. + + For example, subclass can get device information from the server. + """ + + def __getattr__(self, name): + """Wrapper for python magic to turn method calls into RPCs.""" + + def rpc_call(*args, **kwargs): + return self._rpc(name, *args, **kwargs) + + return rpc_call + + def _id_counter(self): + """Returns an id generator.""" + i = 0 + while True: + yield i + i += 1 + + def set_snippet_client_verbose_logging(self, verbose): + """Switches verbose logging. True for logging full RPC responses. + + By default it will write full messages returned from RPCs. Turning off the + verbose logging will result in writing no more than + _MAX_RPC_RESP_LOGGING_LENGTH characters per RPC returned string. + + _MAX_RPC_RESP_LOGGING_LENGTH will be set to 1024 by default. The length + contains the full RPC response in JSON format, not just the RPC result + field. + + Args: + verbose: bool, if True, turns on verbose logging, otherwise turns off. + """ + self.log.info('Sets verbose logging to %s.', verbose) + self.verbose_logging = verbose + + @abc.abstractmethod + def restore_server_connection(self, port=None): + """Reconnects to the server after the device was disconnected. + + Instead of creating a new instance of the client: + - Uses the given port (or finds a new available host_port if 0 or None is + given). + - Tries to connect to the remote server with the selected port. + + Args: + port: int, if given, this is the host port from which to connect to the + remote device port. Otherwise, finds a new available port as host + port. + + Raises: + errors.ServerRestoreConnectionError: when failed to restore the connection + with the snippet server. + """ + + def _rpc(self, rpc_func_name, *args, **kwargs): + """Sends a RPC to the server. + + Args: + rpc_func_name: str, the name of the snippet function to execute on the + server. + *args: any, the positional arguments of the RPC request. + **kwargs: any, the keyword arguments of the RPC request. + + Returns: + The result of the RPC. + + Raises: + errors.ProtocolError: something went wrong when exchanging data with the + server. + errors.ApiError: the RPC went through, however executed with errors. + """ + try: + self.check_server_proc_running() + except Exception: + self.log.error( + 'Server process running check failed, skip sending RPC method(%s).', + rpc_func_name) + raise + + with self._lock: + rpc_id = next(self._counter) + request = self._gen_rpc_request(rpc_id, rpc_func_name, *args, **kwargs) + + self.log.debug('Sending RPC request %s.', request) + response = self.send_rpc_request(request) + self.log.debug('RPC request sent.') + + if self.verbose_logging or _MAX_RPC_RESP_LOGGING_LENGTH >= len(response): + self.log.debug('Snippet received: %s', response) + else: + self.log.debug('Snippet received: %s... %d chars are truncated', + response[:_MAX_RPC_RESP_LOGGING_LENGTH], + len(response) - _MAX_RPC_RESP_LOGGING_LENGTH) + + response_decoded = self._decode_response_string_and_validate_format( + rpc_id, response) + return self._handle_rpc_response(rpc_func_name, response_decoded) + + @abc.abstractmethod + def check_server_proc_running(self): + """Checks whether the server is still running. + + If the server is not running, it throws an error. As this function is called + each time the client tries to send a RPC, this should be a quick check + without affecting performance. Otherwise it is fine to not check anything. + + Raises: + errors.ServerDiedError: if the server died. + """ + + def _gen_rpc_request(self, rpc_id, rpc_func_name, *args, **kwargs): + """Generates the JSON RPC request. + + In the generated JSON string, the fields are sorted by keys in ascending + order. + + Args: + rpc_id: int, the id of this RPC. + rpc_func_name: str, the name of the snippet function to execute + on the server. + *args: any, the positional arguments of the RPC. + **kwargs: any, the keyword arguments of the RPC. + + Returns: + A string of the JSON RPC request. + """ + data = {'id': rpc_id, 'method': rpc_func_name, 'params': args} + if kwargs: + data['kwargs'] = kwargs + return json.dumps(data, sort_keys=True) + + @abc.abstractmethod + def send_rpc_request(self, request): + """Sends the JSON RPC request to the server and gets a response. + + Note that the request and response are both in string format. So if the + connection with server provides interfaces in bytes format, please + transform them to string in the implementation of this function. + + Args: + request: str, a string of the RPC request. + + Returns: + A string of the RPC response. + + Raises: + errors.ProtocolError: something went wrong when exchanging data with the + server. + """ + + def _decode_response_string_and_validate_format(self, rpc_id, response): + """Decodes response JSON string to python dict and validates its format. + + Args: + rpc_id: int, the actual id of this RPC. It should be the same with the id + in the response, otherwise throws an error. + response: str, the JSON string of the RPC response. + + Returns: + A dict decoded from the response JSON string. + + Raises: + errors.ProtocolError: if the response format is invalid. + """ + if not response: + raise errors.ProtocolError(self._device, + errors.ProtocolError.NO_RESPONSE_FROM_SERVER) + + result = json.loads(response) + for field_name in RPC_RESPONSE_REQUIRED_FIELDS: + if field_name not in result: + raise errors.ProtocolError( + self._device, + errors.ProtocolError.RESPONSE_MISSING_FIELD % field_name) + + if result['id'] != rpc_id: + raise errors.ProtocolError(self._device, + errors.ProtocolError.MISMATCHED_API_ID) + + return result + + def _handle_rpc_response(self, rpc_func_name, response): + """Handles the content of RPC response. + + If the RPC response contains error information, it throws an error. If the + RPC is asynchronous, it creates and returns a callback handler + object. Otherwise, it returns the result field of the response. + + Args: + rpc_func_name: str, the name of the snippet function that this RPC + triggered on the snippet server. + response: dict, the object decoded from the response JSON string. + + Returns: + The result of the RPC. If synchronous RPC, it is the result field of the + response. If asynchronous RPC, it is the callback handler object. + + Raises: + errors.ApiError: if the snippet function executed with errors. + """ + + if response['error']: + raise errors.ApiError(self._device, response['error']) + if response['callback'] is not None: + return self.handle_callback(response['callback'], response['result'], + rpc_func_name) + return response['result'] + + @abc.abstractmethod + def handle_callback(self, callback_id, ret_value, rpc_func_name): + """Creates a callback handler for the asynchronous RPC. + + Args: + callback_id: str, the callback ID for creating a callback handler object. + ret_value: any, the result field of the RPC response. + rpc_func_name: str, the name of the snippet function executed on the + server. + + Returns: + The callback handler object. + """ + + def stop_server(self): + """Proxy function of do_stop_server.""" + self.log.debug('Stopping snippet %s.', self.package) + self.do_stop_server() + self.log.debug('Snippet %s stopped.', self.package) + + @abc.abstractmethod + def do_stop_server(self): + """Kills any running instance of the server.""" + + @abc.abstractmethod + def close_connection(self): + """Closes the connection to the snippet server on the device. + + This is a unilateral closing from the client side, without tearing down + the snippet server running on the device. + + The connection to the snippet server can be re-established by calling + `restore_server_connection`. + """ diff --git a/mobly/snippet/errors.py b/mobly/snippet/errors.py new file mode 100644 index 0000000..d22e9fc --- /dev/null +++ b/mobly/snippet/errors.py @@ -0,0 +1,52 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for errors thrown from snippet client objects.""" +# TODO(mhaoli): Package `mobly.snippet` should not import errors from +# android_device_lib. However, android_device_lib.DeviceError is the base error +# for the errors thrown from Android snippet clients and device controllers. +# We should resolve this legacy problem. +from mobly.controllers.android_device_lib import errors + + +class Error(errors.DeviceError): + """Root error type for snippet clients.""" + + +class ServerRestoreConnectionError(Error): + """Raised when failed to restore the connection with the snippet server.""" + + +class ServerStartError(Error): + """Raised when failed to start the snippet server.""" + + +class ServerStartPreCheckError(Error): + """Raised when prechecks for starting the snippet server failed.""" + + +class ApiError(Error): + """Raised when remote API reported an error.""" + + +class ProtocolError(Error): + """Raised when there was an error in exchanging data with server.""" + NO_RESPONSE_FROM_HANDSHAKE = 'No response from handshake.' + NO_RESPONSE_FROM_SERVER = ('No response from server. ' + 'Check the device logcat for crashes.') + MISMATCHED_API_ID = 'RPC request-response ID mismatch.' + RESPONSE_MISSING_FIELD = 'Missing required field in the RPC response: %s.' + + +class ServerDiedError(Error): + """Raised if the snippet server died before all tests finish.""" diff --git a/tests/mobly/snippet/__init__.py b/tests/mobly/snippet/__init__.py new file mode 100644 index 0000000..ac3f9e6 --- /dev/null +++ b/tests/mobly/snippet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/mobly/snippet/client_base_test.py b/tests/mobly/snippet/client_base_test.py new file mode 100755 index 0000000..4da74d6 --- /dev/null +++ b/tests/mobly/snippet/client_base_test.py @@ -0,0 +1,442 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for mobly.snippet.client_base.""" + +import logging +import random +import string +import unittest +from unittest import mock + +from mobly.snippet import client_base +from mobly.snippet import errors + + +def _generate_fix_length_rpc_response( + response_length, + template='{"id": 0, "result": "%s", "error": null, "callback": null}'): + """Generates a RPC response string with specified length. + + This function generates a random string and formats the template with the + generated random string to get the response string. This function formats + the template with printf style string formatting. + + Args: + response_length: int, the length of the response string to generate. + template: str, the template used for generating the response string. + + Returns: + The generated response string. + + Raises: + ValueError: if the specified length is too small to generate a response. + """ + # We need to -2 here because the string formatting will delete the substring + # '%s' in the template, of which the length is 2. + result_length = response_length - (len(template) - 2) + if result_length < 0: + raise ValueError(f'The response_length should be no smaller than ' + f'template_length + 2. Got response_length ' + f'{response_length}, template_length {len(template)}.') + chars = string.ascii_letters + string.digits + return template % ''.join(random.choice(chars) for _ in range(result_length)) + + +class FakeClient(client_base.ClientBase): + """Fake client class for unit tests.""" + + def __init__(self): + """Initializes the instance by mocking a device controller.""" + mock_device = mock.Mock() + mock_device.log = logging + super().__init__(package='FakeClient', device=mock_device) + + # Override abstract methods to enable initialization + def before_starting_server(self): + pass + + def do_start_server(self): + pass + + def build_connection(self): + pass + + def after_starting_server(self): + pass + + def restore_server_connection(self, port=None): + pass + + def check_server_proc_running(self): + pass + + def send_rpc_request(self, request): + pass + + def handle_callback(self, callback_id, ret_value, rpc_func_name): + pass + + def do_stop_server(self): + pass + + def close_connection(self): + pass + + +class ClientBaseTest(unittest.TestCase): + """Unit tests for mobly.snippet.client_base.ClientBase.""" + + def setUp(self): + super().setUp() + self.client = FakeClient() + self.client.host_port = 12345 + + @mock.patch.object(FakeClient, 'before_starting_server') + @mock.patch.object(FakeClient, 'do_start_server') + @mock.patch.object(FakeClient, '_build_connection') + @mock.patch.object(FakeClient, 'after_starting_server') + def test_start_server_stage_order(self, mock_after_func, mock_build_conn_func, + mock_do_start_func, mock_before_func): + """Test that starting server runs its stages in expected order.""" + order_manager = mock.Mock() + order_manager.attach_mock(mock_before_func, 'mock_before_func') + order_manager.attach_mock(mock_do_start_func, 'mock_do_start_func') + order_manager.attach_mock(mock_build_conn_func, 'mock_build_conn_func') + order_manager.attach_mock(mock_after_func, 'mock_after_func') + + self.client.start_server() + + expected_call_order = [ + mock.call.mock_before_func(), + mock.call.mock_do_start_func(), + mock.call.mock_build_conn_func(), + mock.call.mock_after_func(), + ] + self.assertListEqual(order_manager.mock_calls, expected_call_order) + + @mock.patch.object(FakeClient, 'stop_server') + @mock.patch.object(FakeClient, 'before_starting_server') + def test_start_server_before_starting_server_fail(self, mock_before_func, + mock_stop_server): + """Test starting server's stage before_starting_server fails.""" + mock_before_func.side_effect = Exception('ha') + + with self.assertRaisesRegex(Exception, 'ha'): + self.client.start_server() + mock_stop_server.assert_not_called() + + @mock.patch.object(FakeClient, 'stop_server') + @mock.patch.object(FakeClient, 'do_start_server') + def test_start_server_do_start_server_fail(self, mock_do_start_func, + mock_stop_server): + """Test starting server's stage do_start_server fails.""" + mock_do_start_func.side_effect = Exception('ha') + + with self.assertRaisesRegex(Exception, 'ha'): + self.client.start_server() + mock_stop_server.assert_called() + + @mock.patch.object(FakeClient, 'stop_server') + @mock.patch.object(FakeClient, '_build_connection') + def test_start_server_build_connection_fail(self, mock_build_conn_func, + mock_stop_server): + """Test starting server's stage _build_connection fails.""" + mock_build_conn_func.side_effect = Exception('ha') + + with self.assertRaisesRegex(Exception, 'ha'): + self.client.start_server() + mock_stop_server.assert_called() + + @mock.patch.object(FakeClient, 'stop_server') + @mock.patch.object(FakeClient, 'after_starting_server') + def test_start_server_after_starting_server_fail(self, mock_after_func, + mock_stop_server): + """Test starting server's stage after_starting_server fails.""" + mock_after_func.side_effect = Exception('ha') + + with self.assertRaisesRegex(Exception, 'ha'): + self.client.start_server() + mock_stop_server.assert_called() + + @mock.patch.object(FakeClient, 'check_server_proc_running') + @mock.patch.object(FakeClient, '_gen_rpc_request') + @mock.patch.object(FakeClient, 'send_rpc_request') + @mock.patch.object(FakeClient, '_decode_response_string_and_validate_format') + @mock.patch.object(FakeClient, '_handle_rpc_response') + def test_rpc_stage_dependencies(self, mock_handle_resp, mock_decode_resp_str, + mock_send_request, mock_gen_request, + mock_precheck): + """Test the internal dependencies when sending a RPC. + + When sending a RPC, it calls multiple functions in specific order, and + each function uses the output of the previously called function. This test + case checks above dependencies. + + Args: + mock_handle_resp: the mock function of FakeClient._handle_rpc_response. + mock_decode_resp_str: the mock function of + FakeClient._decode_response_string_and_validate_format. + mock_send_request: the mock function of FakeClient.send_rpc_request. + mock_gen_request: the mock function of FakeClient._gen_rpc_request. + mock_precheck: the mock function of FakeClient.check_server_proc_running. + """ + self.client.start_server() + + expected_response_str = ('{"id": 0, "result": 123, "error": null, ' + '"callback": null}') + expected_response_dict = { + 'id': 0, + 'result': 123, + 'error': None, + 'callback': None, + } + expected_request = ('{"id": 10, "method": "some_rpc", "params": [1, 2],' + '"kwargs": {"test_key": 3}') + expected_result = 123 + + mock_gen_request.return_value = expected_request + mock_send_request.return_value = expected_response_str + mock_decode_resp_str.return_value = expected_response_dict + mock_handle_resp.return_value = expected_result + rpc_result = self.client.some_rpc(1, 2, test_key=3) + + mock_precheck.assert_called() + mock_gen_request.assert_called_with(0, 'some_rpc', 1, 2, test_key=3) + mock_send_request.assert_called_with(expected_request) + mock_decode_resp_str.assert_called_with(0, expected_response_str) + mock_handle_resp.assert_called_with('some_rpc', expected_response_dict) + self.assertEqual(rpc_result, expected_result) + + @mock.patch.object(FakeClient, 'check_server_proc_running') + @mock.patch.object(FakeClient, '_gen_rpc_request') + @mock.patch.object(FakeClient, 'send_rpc_request') + @mock.patch.object(FakeClient, '_decode_response_string_and_validate_format') + @mock.patch.object(FakeClient, '_handle_rpc_response') + def test_rpc_precheck_fail(self, mock_handle_resp, mock_decode_resp_str, + mock_send_request, mock_gen_request, + mock_precheck): + """Test when RPC precheck fails it will skip sending RPC.""" + self.client.start_server() + mock_precheck.side_effect = Exception('server_died') + + with self.assertRaisesRegex(Exception, 'server_died'): + self.client.some_rpc(1, 2) + + mock_gen_request.assert_not_called() + mock_send_request.assert_not_called() + mock_handle_resp.assert_not_called() + mock_decode_resp_str.assert_not_called() + + def test_gen_request(self): + """Test generating a RPC request. + + Test that _gen_rpc_request returns a string represents a JSON dict + with all required fields. + """ + request = self.client._gen_rpc_request(0, 'test_rpc', 1, 2, test_key=3) + expected_result = ('{"id": 0, "kwargs": {"test_key": 3}, ' + '"method": "test_rpc", "params": [1, 2]}') + self.assertEqual(request, expected_result) + + def test_gen_request_without_kwargs(self): + """Test no keyword arguments. + + Test that _gen_rpc_request ignores the kwargs field when no + keyword arguments. + """ + request = self.client._gen_rpc_request(0, 'test_rpc', 1, 2) + expected_result = '{"id": 0, "method": "test_rpc", "params": [1, 2]}' + self.assertEqual(request, expected_result) + + def test_rpc_no_response(self): + """Test parsing an empty RPC response.""" + with self.assertRaisesRegex(errors.ProtocolError, + errors.ProtocolError.NO_RESPONSE_FROM_SERVER): + self.client._decode_response_string_and_validate_format(0, '') + + with self.assertRaisesRegex(errors.ProtocolError, + errors.ProtocolError.NO_RESPONSE_FROM_SERVER): + self.client._decode_response_string_and_validate_format(0, None) + + def test_rpc_response_missing_fields(self): + """Test parsing a RPC response that misses some required fields.""" + mock_resp_without_id = '{"result": 123, "error": null, "callback": null}' + with self.assertRaisesRegex( + errors.ProtocolError, + errors.ProtocolError.RESPONSE_MISSING_FIELD % 'id'): + self.client._decode_response_string_and_validate_format( + 10, mock_resp_without_id) + + mock_resp_without_result = '{"id": 10, "error": null, "callback": null}' + with self.assertRaisesRegex( + errors.ProtocolError, + errors.ProtocolError.RESPONSE_MISSING_FIELD % 'result'): + self.client._decode_response_string_and_validate_format( + 10, mock_resp_without_result) + + mock_resp_without_error = '{"id": 10, "result": 123, "callback": null}' + with self.assertRaisesRegex( + errors.ProtocolError, + errors.ProtocolError.RESPONSE_MISSING_FIELD % 'error'): + self.client._decode_response_string_and_validate_format( + 10, mock_resp_without_error) + + mock_resp_without_callback = '{"id": 10, "result": 123, "error": null}' + with self.assertRaisesRegex( + errors.ProtocolError, + errors.ProtocolError.RESPONSE_MISSING_FIELD % 'callback'): + self.client._decode_response_string_and_validate_format( + 10, mock_resp_without_callback) + + def test_rpc_response_error(self): + """Test parsing a RPC response with a non-empty error field.""" + mock_resp_with_error = { + 'id': 10, + 'result': 123, + 'error': 'some_error', + 'callback': None, + } + with self.assertRaisesRegex(errors.ApiError, 'some_error'): + self.client._handle_rpc_response('some_rpc', mock_resp_with_error) + + def test_rpc_response_callback(self): + """Test parsing response function handles the callback field well.""" + # Call handle_callback function if the "callback" field is not None + mock_resp_with_callback = { + 'id': 10, + 'result': 123, + 'error': None, + 'callback': '1-0' + } + with mock.patch.object(self.client, + 'handle_callback') as mock_handle_callback: + expected_callback = mock.Mock() + mock_handle_callback.return_value = expected_callback + + rpc_result = self.client._handle_rpc_response('some_rpc', + mock_resp_with_callback) + mock_handle_callback.assert_called_with('1-0', 123, 'some_rpc') + # Ensure the RPC function returns what handle_callback returned + self.assertIs(expected_callback, rpc_result) + + # Do not call handle_callback function if the "callback" field is None + mock_resp_without_callback = { + 'id': 10, + 'result': 123, + 'error': None, + 'callback': None + } + with mock.patch.object(self.client, + 'handle_callback') as mock_handle_callback: + self.client._handle_rpc_response('some_rpc', mock_resp_without_callback) + mock_handle_callback.assert_not_called() + + def test_rpc_response_id_mismatch(self): + """Test parsing a RPC response with wrong id.""" + right_id = 5 + wrong_id = 20 + resp = f'{{"id": {right_id}, "result": 1, "error": null, "callback": null}}' + + with self.assertRaisesRegex(errors.ProtocolError, + errors.ProtocolError.MISMATCHED_API_ID): + self.client._decode_response_string_and_validate_format(wrong_id, resp) + + @mock.patch.object(FakeClient, 'send_rpc_request') + def test_rpc_verbose_logging_with_long_string(self, mock_send_request): + """Test RPC response isn't truncated when verbose logging is on.""" + mock_log = mock.Mock() + self.client.log = mock_log + self.client.set_snippet_client_verbose_logging(True) + self.client.start_server() + + resp = _generate_fix_length_rpc_response( + client_base._MAX_RPC_RESP_LOGGING_LENGTH * 2) + mock_send_request.return_value = resp + self.client.some_rpc(1, 2) + mock_log.debug.assert_called_with('Snippet received: %s', resp) + + @mock.patch.object(FakeClient, 'send_rpc_request') + def test_rpc_truncated_logging_short_response(self, mock_send_request): + """Test RPC response isn't truncated with small length.""" + mock_log = mock.Mock() + self.client.log = mock_log + self.client.set_snippet_client_verbose_logging(False) + self.client.start_server() + + resp = _generate_fix_length_rpc_response( + int(client_base._MAX_RPC_RESP_LOGGING_LENGTH // 2)) + mock_send_request.return_value = resp + self.client.some_rpc(1, 2) + mock_log.debug.assert_called_with('Snippet received: %s', resp) + + @mock.patch.object(FakeClient, 'send_rpc_request') + def test_rpc_truncated_logging_fit_size_response(self, mock_send_request): + """Test RPC response isn't truncated with length equal to the threshold.""" + mock_log = mock.Mock() + self.client.log = mock_log + self.client.set_snippet_client_verbose_logging(False) + self.client.start_server() + + resp = _generate_fix_length_rpc_response( + client_base._MAX_RPC_RESP_LOGGING_LENGTH) + mock_send_request.return_value = resp + self.client.some_rpc(1, 2) + mock_log.debug.assert_called_with('Snippet received: %s', resp) + + @mock.patch.object(FakeClient, 'send_rpc_request') + def test_rpc_truncated_logging_long_response(self, mock_send_request): + """Test RPC response is truncated with length larger than the threshold.""" + mock_log = mock.Mock() + self.client.log = mock_log + self.client.set_snippet_client_verbose_logging(False) + self.client.start_server() + + max_len = client_base._MAX_RPC_RESP_LOGGING_LENGTH + resp = _generate_fix_length_rpc_response(max_len * 40) + mock_send_request.return_value = resp + self.client.some_rpc(1, 2) + mock_log.debug.assert_called_with( + 'Snippet received: %s... %d chars are truncated', + resp[:client_base._MAX_RPC_RESP_LOGGING_LENGTH], + len(resp) - max_len) + + @mock.patch.object(FakeClient, 'send_rpc_request') + def test_rpc_call_increment_counter(self, mock_send_request): + """Test that with each RPC call the counter is incremented by 1.""" + self.client.start_server() + resp = '{"id": %d, "result": 123, "error": null, "callback": null}' + mock_send_request.side_effect = (resp % (i,) for i in range(10)) + + for _ in range(0, 10): + self.client.some_rpc() + + self.assertEqual(next(self.client._counter), 10) + + @mock.patch.object(FakeClient, 'send_rpc_request') + def test_build_connection_reset_counter(self, mock_send_request): + """Test that _build_connection resets the counter to zero.""" + self.client.start_server() + resp = '{"id": %d, "result": 123, "error": null, "callback": null}' + mock_send_request.side_effect = (resp % (i,) for i in range(10)) + + for _ in range(0, 10): + self.client.some_rpc() + + self.assertEqual(next(self.client._counter), 10) + self.client._build_connection() + self.assertEqual(next(self.client._counter), 0) + + +if __name__ == '__main__': + unittest.main() |