aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py394
-rwxr-xr-xtests/mobly/controllers/android_device_test.py44
-rw-r--r--tests/mobly/snippet/__init__.py13
-rwxr-xr-xtests/mobly/snippet/client_base_test.py424
4 files changed, 857 insertions, 18 deletions
diff --git a/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py b/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py
new file mode 100644
index 0000000..b97774b
--- /dev/null
+++ b/tests/mobly/controllers/android_device_lib/snippet_client_v2_test.py
@@ -0,0 +1,394 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for mobly.controllers.android_device_lib.snippet_client_v2."""
+
+import unittest
+from unittest import mock
+
+from mobly.controllers.android_device_lib import adb
+from mobly.controllers.android_device_lib import errors as android_device_lib_errors
+from mobly.controllers.android_device_lib import snippet_client_v2
+from mobly.snippet import errors
+from tests.lib import mock_android_device
+
+MOCK_PACKAGE_NAME = 'some.package.name'
+MOCK_SERVER_PATH = f'{MOCK_PACKAGE_NAME}/{snippet_client_v2._INSTRUMENTATION_RUNNER_PACKAGE}'
+MOCK_USER_ID = 0
+
+
+class SnippetClientV2Test(unittest.TestCase):
+ """Unit tests for SnippetClientV2."""
+
+ def _make_client(self, adb_proxy=None, mock_properties=None):
+ adb_proxy = adb_proxy or mock_android_device.MockAdbProxy(
+ instrumented_packages=[
+ (MOCK_PACKAGE_NAME,
+ snippet_client_v2._INSTRUMENTATION_RUNNER_PACKAGE,
+ MOCK_PACKAGE_NAME)
+ ],
+ mock_properties=mock_properties)
+
+ device = mock.Mock()
+ device.adb = adb_proxy
+ device.adb.current_user_id = MOCK_USER_ID
+ device.build_info = {
+ 'build_version_codename':
+ adb_proxy.getprop('ro.build.version.codename'),
+ 'build_version_sdk':
+ adb_proxy.getprop('ro.build.version.sdk'),
+ }
+
+ self.client = snippet_client_v2.SnippetClientV2(MOCK_PACKAGE_NAME, device)
+
+ def _make_client_with_extra_adb_properties(self, extra_properties):
+ mock_properties = mock_android_device.DEFAULT_MOCK_PROPERTIES.copy()
+ mock_properties.update(extra_properties)
+ self._make_client(mock_properties=mock_properties)
+
+ def _mock_server_process_starting_response(self, mock_start_subprocess,
+ resp_lines):
+ mock_proc = mock_start_subprocess.return_value
+ mock_proc.stdout.readline.side_effect = resp_lines
+
+ def test_check_app_installed_normally(self):
+ """Tests that app checker runs normally when app installed correctly."""
+ self._make_client()
+ self.client._validate_snippet_app_on_device()
+
+ def test_check_app_installed_fail_app_not_installed(self):
+ """Tests that app checker fails without installing app."""
+ self._make_client(mock_android_device.MockAdbProxy())
+ expected_msg = f'.* {MOCK_PACKAGE_NAME} is not installed.'
+ with self.assertRaisesRegex(errors.ServerStartPreCheckError, expected_msg):
+ self.client._validate_snippet_app_on_device()
+
+ def test_check_app_installed_fail_not_instrumented(self):
+ """Tests that app checker fails without instrumenting app."""
+ self._make_client(
+ mock_android_device.MockAdbProxy(
+ installed_packages=[MOCK_PACKAGE_NAME]))
+ expected_msg = (
+ f'.* {MOCK_PACKAGE_NAME} is installed, but it is not instrumented.')
+ with self.assertRaisesRegex(errors.ServerStartPreCheckError, expected_msg):
+ self.client._validate_snippet_app_on_device()
+
+ def test_check_app_installed_fail_instrumentation_not_installed(self):
+ """Tests that app checker fails without installing instrumentation."""
+ self._make_client(
+ mock_android_device.MockAdbProxy(instrumented_packages=[(
+ MOCK_PACKAGE_NAME,
+ snippet_client_v2._INSTRUMENTATION_RUNNER_PACKAGE,
+ 'not.installed')]))
+ expected_msg = ('.* Instrumentation target not.installed is not installed.')
+ with self.assertRaisesRegex(errors.ServerStartPreCheckError, expected_msg):
+ self.client._validate_snippet_app_on_device()
+
+ @mock.patch.object(mock_android_device.MockAdbProxy, 'shell')
+ def test_disable_hidden_api_normally(self, mock_shell_func):
+ """Tests the disabling hidden api process works normally."""
+ self._make_client_with_extra_adb_properties({
+ 'ro.build.version.codename': 'S',
+ 'ro.build.version.sdk': '31',
+ })
+ self.client._device.is_rootable = True
+ self.client._disable_hidden_api_blocklist()
+ mock_shell_func.assert_called_with(
+ 'settings put global hidden_api_blacklist_exemptions "*"')
+
+ @mock.patch.object(mock_android_device.MockAdbProxy, 'shell')
+ def test_disable_hidden_api_low_sdk(self, mock_shell_func):
+ """Tests it doesn't disable hidden api with low SDK."""
+ self._make_client_with_extra_adb_properties({
+ 'ro.build.version.codename': 'O',
+ 'ro.build.version.sdk': '26',
+ })
+ self.client._device.is_rootable = True
+ self.client._disable_hidden_api_blocklist()
+ mock_shell_func.assert_not_called()
+
+ @mock.patch.object(mock_android_device.MockAdbProxy, 'shell')
+ def test_disable_hidden_api_non_rootable(self, mock_shell_func):
+ """Tests it doesn't disable hidden api with non-rootable device."""
+ self._make_client_with_extra_adb_properties({
+ 'ro.build.version.codename': 'S',
+ 'ro.build.version.sdk': '31',
+ })
+ self.client._device.is_rootable = False
+ self.client._disable_hidden_api_blocklist()
+ mock_shell_func.assert_not_called()
+
+ @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
+ 'utils.start_standing_subprocess')
+ @mock.patch.object(mock_android_device.MockAdbProxy,
+ 'shell',
+ return_value=b'setsid')
+ def test_start_server_with_user_id(self, mock_adb, mock_start_subprocess):
+ """Tests that `--user` is added to starting command with SDK >= 24."""
+ self._make_client_with_extra_adb_properties({'ro.build.version.sdk': '30'})
+ self._mock_server_process_starting_response(
+ mock_start_subprocess,
+ resp_lines=[
+ b'SNIPPET START, PROTOCOL 1 234', b'SNIPPET SERVING, PORT 1234'
+ ])
+
+ self.client.start_server()
+ start_cmd_list = [
+ 'adb', 'shell',
+ (f'setsid am instrument --user {MOCK_USER_ID} -w -e action start '
+ f'{MOCK_SERVER_PATH}')
+ ]
+ self.assertListEqual(mock_start_subprocess.call_args_list,
+ [mock.call(start_cmd_list, shell=False)])
+ self.assertEqual(self.client.device_port, 1234)
+ mock_adb.assert_called_with(['which', 'setsid'])
+
+ @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
+ 'utils.start_standing_subprocess')
+ @mock.patch.object(mock_android_device.MockAdbProxy,
+ 'shell',
+ return_value=b'setsid')
+ def test_start_server_without_user_id(self, mock_adb, mock_start_subprocess):
+ """Tests that `--user` is not added to starting command on SDK < 24."""
+ self._make_client_with_extra_adb_properties({'ro.build.version.sdk': '21'})
+ self._mock_server_process_starting_response(
+ mock_start_subprocess,
+ resp_lines=[
+ b'SNIPPET START, PROTOCOL 1 234', b'SNIPPET SERVING, PORT 1234'
+ ])
+
+ self.client.start_server()
+ start_cmd_list = [
+ 'adb', 'shell',
+ f'setsid am instrument -w -e action start {MOCK_SERVER_PATH}'
+ ]
+ self.assertListEqual(mock_start_subprocess.call_args_list,
+ [mock.call(start_cmd_list, shell=False)])
+ mock_adb.assert_called_with(['which', 'setsid'])
+ self.assertEqual(self.client.device_port, 1234)
+
+ @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
+ 'utils.start_standing_subprocess')
+ @mock.patch.object(mock_android_device.MockAdbProxy,
+ 'shell',
+ side_effect=adb.AdbError('cmd', 'stdout', 'stderr',
+ 'ret_code'))
+ def test_start_server_without_persisting_commands(self, mock_adb,
+ mock_start_subprocess):
+ """Checks the starting server command without persisting commands."""
+ self._make_client()
+ self._mock_server_process_starting_response(
+ mock_start_subprocess,
+ resp_lines=[
+ b'SNIPPET START, PROTOCOL 1 234', b'SNIPPET SERVING, PORT 1234'
+ ])
+
+ self.client.start_server()
+ start_cmd_list = [
+ 'adb', 'shell',
+ (f' am instrument --user {MOCK_USER_ID} -w -e action start '
+ f'{MOCK_SERVER_PATH}')
+ ]
+ self.assertListEqual(mock_start_subprocess.call_args_list,
+ [mock.call(start_cmd_list, shell=False)])
+ mock_adb.assert_has_calls(
+ [mock.call(['which', 'setsid']),
+ mock.call(['which', 'nohup'])])
+ self.assertEqual(self.client.device_port, 1234)
+
+ @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
+ 'utils.start_standing_subprocess')
+ def test_start_server_with_nohup(self, mock_start_subprocess):
+ """Checks the starting server command with nohup."""
+ self._make_client()
+ self._mock_server_process_starting_response(
+ mock_start_subprocess,
+ resp_lines=[
+ b'SNIPPET START, PROTOCOL 1 234', b'SNIPPET SERVING, PORT 1234'
+ ])
+
+ def _mocked_shell(arg):
+ if 'nohup' in arg:
+ return b'nohup'
+ raise adb.AdbError('cmd', 'stdout', 'stderr', 'ret_code')
+
+ self.client._adb.shell = _mocked_shell
+
+ self.client.start_server()
+ start_cmd_list = [
+ 'adb', 'shell',
+ (f'nohup am instrument --user {MOCK_USER_ID} -w -e action start '
+ f'{MOCK_SERVER_PATH}')
+ ]
+ self.assertListEqual(mock_start_subprocess.call_args_list,
+ [mock.call(start_cmd_list, shell=False)])
+ self.assertEqual(self.client.device_port, 1234)
+
+ @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
+ 'utils.start_standing_subprocess')
+ def test_start_server_with_setsid(self, mock_start_subprocess):
+ """Checks the starting server command with setsid."""
+ self._make_client()
+ self._mock_server_process_starting_response(
+ mock_start_subprocess,
+ resp_lines=[
+ b'SNIPPET START, PROTOCOL 1 234', b'SNIPPET SERVING, PORT 1234'
+ ])
+
+ def _mocked_shell(arg):
+ if 'setsid' in arg:
+ return b'setsid'
+ raise adb.AdbError('cmd', 'stdout', 'stderr', 'ret_code')
+
+ self.client._adb.shell = _mocked_shell
+ self.client.start_server()
+ start_cmd_list = [
+ 'adb', 'shell',
+ (f'setsid am instrument --user {MOCK_USER_ID} -w -e action start '
+ f'{MOCK_SERVER_PATH}')
+ ]
+ self.assertListEqual(mock_start_subprocess.call_args_list,
+ [mock.call(start_cmd_list, shell=False)])
+ self.assertEqual(self.client.device_port, 1234)
+
+ @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
+ 'utils.start_standing_subprocess')
+ def test_start_server_server_crash(self, mock_start_standing_subprocess):
+ """Tests that starting server process crashes."""
+ self._make_client()
+ self._mock_server_process_starting_response(
+ mock_start_standing_subprocess,
+ resp_lines=[b'INSTRUMENTATION_RESULT: shortMsg=Process crashed.\n'])
+ with self.assertRaisesRegex(
+ errors.ServerStartProtocolError,
+ 'INSTRUMENTATION_RESULT: shortMsg=Process crashed.'):
+ self.client.start_server()
+
+ @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
+ 'utils.start_standing_subprocess')
+ def test_start_server_unknown_protocol_version(
+ self, mock_start_standing_subprocess):
+ """Tests that starting server process reports unknown protocol version."""
+ self._make_client()
+ self._mock_server_process_starting_response(
+ mock_start_standing_subprocess,
+ resp_lines=[b'SNIPPET START, PROTOCOL 99 0\n'])
+ with self.assertRaisesRegex(errors.ServerStartProtocolError,
+ 'SNIPPET START, PROTOCOL 99 0'):
+ self.client.start_server()
+
+ @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
+ 'utils.start_standing_subprocess')
+ def test_start_server_invalid_device_port(self,
+ mock_start_standing_subprocess):
+ """Tests that starting server process reports invalid device port."""
+ self._make_client()
+ self._mock_server_process_starting_response(
+ mock_start_standing_subprocess,
+ resp_lines=[
+ b'SNIPPET START, PROTOCOL 1 0\n', b'SNIPPET SERVING, PORT ABC\n'
+ ])
+ with self.assertRaisesRegex(errors.ServerStartProtocolError,
+ 'SNIPPET SERVING, PORT ABC'):
+ self.client.start_server()
+
+ @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
+ 'utils.start_standing_subprocess')
+ def test_start_server_with_junk(self, mock_start_standing_subprocess):
+ """Tests that starting server process reports known protocol with junk."""
+ self._make_client()
+ self._mock_server_process_starting_response(
+ mock_start_standing_subprocess,
+ resp_lines=[
+ b'This is some header junk\n',
+ b'Some phones print arbitrary output\n',
+ b'SNIPPET START, PROTOCOL 1 0\n',
+ b'Maybe in the middle too\n',
+ b'SNIPPET SERVING, PORT 123\n',
+ ])
+ self.client.start_server()
+ self.assertEqual(123, self.client.device_port)
+
+ @mock.patch('mobly.controllers.android_device_lib.snippet_client_v2.'
+ 'utils.start_standing_subprocess')
+ def test_start_server_no_valid_line(self, mock_start_standing_subprocess):
+ """Tests that starting server process reports unknown protocol message."""
+ self._make_client()
+ self._mock_server_process_starting_response(
+ mock_start_standing_subprocess,
+ resp_lines=[
+ b'This is some header junk\n',
+ b'Some phones print arbitrary output\n',
+ b'', # readline uses '' to mark EOF
+ ])
+ with self.assertRaisesRegex(
+ errors.ServerStartError,
+ 'Unexpected EOF when waiting for server to start.'):
+ self.client.start_server()
+
+ @mock.patch('mobly.utils.stop_standing_subprocess')
+ @mock.patch.object(mock_android_device.MockAdbProxy,
+ 'shell',
+ return_value=b'OK (0 tests)')
+ def test_stop_server_normally(self, mock_android_device_shell,
+ mock_stop_standing_subprocess):
+ """Tests that stopping server process works normally."""
+ self._make_client()
+ mock_proc = mock.Mock()
+ self.client._proc = mock_proc
+ self.client.stop()
+ self.assertIs(self.client._proc, None)
+ mock_android_device_shell.assert_called_once_with(
+ f'am instrument --user {MOCK_USER_ID} -w -e action stop '
+ f'{MOCK_SERVER_PATH}')
+ mock_stop_standing_subprocess.assert_called_once_with(mock_proc)
+
+ @mock.patch('mobly.utils.stop_standing_subprocess')
+ @mock.patch.object(mock_android_device.MockAdbProxy,
+ 'shell',
+ return_value=b'OK (0 tests)')
+ def test_stop_server_server_already_cleaned(self, mock_android_device_shell,
+ mock_stop_standing_subprocess):
+ """Tests stopping server process when subprocess is already cleaned."""
+ self._make_client()
+ self.client._proc = None
+ self.client.stop()
+ self.assertIs(self.client._proc, None)
+ mock_stop_standing_subprocess.assert_not_called()
+ mock_android_device_shell.assert_called_once_with(
+ f'am instrument --user {MOCK_USER_ID} -w -e action stop '
+ f'{MOCK_SERVER_PATH}')
+
+ @mock.patch('mobly.utils.stop_standing_subprocess')
+ @mock.patch.object(mock_android_device.MockAdbProxy,
+ 'shell',
+ return_value=b'Closed with error.')
+ def test_stop_server_stop_with_error(self, mock_android_device_shell,
+ mock_stop_standing_subprocess):
+ """Tests all resources are cleaned even if stopping server has error."""
+ self._make_client()
+ mock_proc = mock.Mock()
+ self.client._proc = mock_proc
+ with self.assertRaisesRegex(android_device_lib_errors.DeviceError,
+ 'Closed with error'):
+ self.client.stop()
+
+ self.assertIs(self.client._proc, None)
+ mock_stop_standing_subprocess.assert_called_once_with(mock_proc)
+ mock_android_device_shell.assert_called_once_with(
+ f'am instrument --user {MOCK_USER_ID} -w -e action stop '
+ f'{MOCK_SERVER_PATH}')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/mobly/controllers/android_device_test.py b/tests/mobly/controllers/android_device_test.py
index a98ecf3..6adb8f5 100755
--- a/tests/mobly/controllers/android_device_test.py
+++ b/tests/mobly/controllers/android_device_test.py
@@ -132,20 +132,26 @@ class AndroidDeviceTest(unittest.TestCase):
with self.assertRaisesRegex(android_device.Error, expected_msg):
android_device.create([1])
+ @mock.patch('mobly.controllers.android_device.list_fastboot_devices')
@mock.patch('mobly.controllers.android_device.list_adb_devices')
@mock.patch('mobly.controllers.android_device.list_adb_devices_by_usb_id')
@mock.patch('mobly.controllers.android_device.AndroidDevice')
- def test_get_instances(self, mock_ad_class, mock_list_adb_usb, mock_list_adb):
+ def test_get_instances(self, mock_ad_class, mock_list_adb_usb, mock_list_adb,
+ mock_list_fastboot):
+ mock_list_fastboot.return_value = ['0']
mock_list_adb.return_value = ['1']
mock_list_adb_usb.return_value = []
- android_device.get_instances(['1'])
- mock_ad_class.assert_called_with('1')
+ android_device.get_instances(['0', '1'])
+ mock_ad_class.assert_any_call('0')
+ mock_ad_class.assert_any_call('1')
+ @mock.patch('mobly.controllers.android_device.list_fastboot_devices')
@mock.patch('mobly.controllers.android_device.list_adb_devices')
@mock.patch('mobly.controllers.android_device.list_adb_devices_by_usb_id')
@mock.patch('mobly.controllers.android_device.AndroidDevice')
def test_get_instances_do_not_exist(self, mock_ad_class, mock_list_adb_usb,
- mock_list_adb):
+ mock_list_adb, mock_list_fastboot):
+ mock_list_fastboot.return_value = []
mock_list_adb.return_value = []
mock_list_adb_usb.return_value = []
with self.assertRaisesRegex(
@@ -154,12 +160,14 @@ class AndroidDeviceTest(unittest.TestCase):
):
android_device.get_instances(['1'])
+ @mock.patch('mobly.controllers.android_device.list_fastboot_devices')
@mock.patch('mobly.controllers.android_device.list_adb_devices')
@mock.patch('mobly.controllers.android_device.list_adb_devices_by_usb_id')
@mock.patch('mobly.controllers.android_device.AndroidDevice')
def test_get_instances_with_configs(self, mock_ad_class, mock_list_adb_usb,
- mock_list_adb):
- mock_list_adb.return_value = ['1', '2']
+ mock_list_adb, mock_list_fastboot):
+ mock_list_fastboot.return_value = ['1']
+ mock_list_adb.return_value = ['2']
mock_list_adb_usb.return_value = []
configs = [{'serial': '1'}, {'serial': '2'}]
android_device.get_instances_with_configs(configs)
@@ -173,12 +181,15 @@ class AndroidDeviceTest(unittest.TestCase):
f'Required value "serial" is missing in AndroidDevice config {config}'):
android_device.get_instances_with_configs([config])
+ @mock.patch('mobly.controllers.android_device.list_fastboot_devices')
@mock.patch('mobly.controllers.android_device.list_adb_devices')
@mock.patch('mobly.controllers.android_device.list_adb_devices_by_usb_id')
@mock.patch('mobly.controllers.android_device.AndroidDevice')
def test_get_instances_with_configsdo_not_exist(self, mock_ad_class,
mock_list_adb_usb,
- mock_list_adb):
+ mock_list_adb,
+ mock_list_fastboot):
+ mock_list_fastboot.return_value = []
mock_list_adb.return_value = []
mock_list_adb_usb.return_value = []
config = {'serial': '1'}
@@ -859,8 +870,8 @@ class AndroidDeviceTest(unittest.TestCase):
@mock.patch('mobly.utils.create_dir')
@mock.patch('mobly.logger.get_log_file_timestamp')
def test_AndroidDevice_take_screenshot_with_prefix(
- self, get_log_file_timestamp_mock, create_dir_mock,
- FastbootProxy, MockAdbProxy):
+ self, get_log_file_timestamp_mock, create_dir_mock, FastbootProxy,
+ MockAdbProxy):
get_log_file_timestamp_mock.return_value = '07-22-2019_17-53-34-450'
mock_serial = '1'
ad = android_device.AndroidDevice(serial=mock_serial)
@@ -1141,22 +1152,19 @@ class AndroidDeviceTest(unittest.TestCase):
mock_serial = '1'
ad = android_device.AndroidDevice(serial=mock_serial)
self.assertEqual(ad.debug_tag, '1')
- with self.assertRaisesRegex(
- android_device.DeviceError,
- r'<AndroidDevice\|1> Something'):
+ with self.assertRaisesRegex(android_device.DeviceError,
+ r'<AndroidDevice\|1> Something'):
raise android_device.DeviceError(ad, 'Something')
# Verify that debug tag's setter updates the debug prefix correctly.
ad.debug_tag = 'Mememe'
- with self.assertRaisesRegex(
- android_device.DeviceError,
- r'<AndroidDevice\|Mememe> Something'):
+ with self.assertRaisesRegex(android_device.DeviceError,
+ r'<AndroidDevice\|Mememe> Something'):
raise android_device.DeviceError(ad, 'Something')
# Verify that repr is changed correctly.
- with self.assertRaisesRegex(
- Exception,
- r'(<AndroidDevice\|Mememe>, \'Something\')'):
+ with self.assertRaisesRegex(Exception,
+ r'(<AndroidDevice\|Mememe>, \'Something\')'):
raise Exception(ad, 'Something')
@mock.patch('mobly.controllers.android_device_lib.adb.AdbProxy',
diff --git a/tests/mobly/snippet/__init__.py b/tests/mobly/snippet/__init__.py
new file mode 100644
index 0000000..ac3f9e6
--- /dev/null
+++ b/tests/mobly/snippet/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/mobly/snippet/client_base_test.py b/tests/mobly/snippet/client_base_test.py
new file mode 100755
index 0000000..d9d99bd
--- /dev/null
+++ b/tests/mobly/snippet/client_base_test.py
@@ -0,0 +1,424 @@
+# Copyright 2022 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for mobly.snippet.client_base."""
+
+import logging
+import random
+import string
+import unittest
+from unittest import mock
+
+from mobly.snippet import client_base
+from mobly.snippet import errors
+
+
+def _generate_fix_length_rpc_response(
+ response_length,
+ template='{"id": 0, "result": "%s", "error": null, "callback": null}'):
+ """Generates an RPC response string with specified length.
+
+ This function generates a random string and formats the template with the
+ generated random string to get the response string. This function formats
+ the template with printf style string formatting.
+
+ Args:
+ response_length: int, the length of the response string to generate.
+ template: str, the template used for generating the response string.
+
+ Returns:
+ The generated response string.
+
+ Raises:
+ ValueError: if the specified length is too small to generate a response.
+ """
+ # We need to -2 here because the string formatting will delete the substring
+ # '%s' in the template, of which the length is 2.
+ result_length = response_length - (len(template) - 2)
+ if result_length < 0:
+ raise ValueError(f'The response_length should be no smaller than '
+ f'template_length + 2. Got response_length '
+ f'{response_length}, template_length {len(template)}.')
+ chars = string.ascii_letters + string.digits
+ return template % ''.join(random.choice(chars) for _ in range(result_length))
+
+
+class FakeClient(client_base.ClientBase):
+ """Fake client class for unit tests."""
+
+ def __init__(self):
+ """Initializes the instance by mocking a device controller."""
+ mock_device = mock.Mock()
+ mock_device.log = logging
+ super().__init__(package='FakeClient', device=mock_device)
+
+ # Override abstract methods to enable initialization
+ def before_starting_server(self):
+ pass
+
+ def start_server(self):
+ pass
+
+ def make_connection(self):
+ pass
+
+ def restore_server_connection(self, port=None):
+ pass
+
+ def check_server_proc_running(self):
+ pass
+
+ def send_rpc_request(self, request):
+ pass
+
+ def handle_callback(self, callback_id, ret_value, rpc_func_name):
+ pass
+
+ def stop(self):
+ pass
+
+ def close_connection(self):
+ pass
+
+
+class ClientBaseTest(unittest.TestCase):
+ """Unit tests for mobly.snippet.client_base.ClientBase."""
+
+ def setUp(self):
+ super().setUp()
+ self.client = FakeClient()
+ self.client.host_port = 12345
+
+ @mock.patch.object(FakeClient, 'before_starting_server')
+ @mock.patch.object(FakeClient, 'start_server')
+ @mock.patch.object(FakeClient, '_make_connection')
+ def test_init_server_stage_order(self, mock_make_conn_func, mock_start_func,
+ mock_before_func):
+ """Test that initialization runs its stages in expected order."""
+ order_manager = mock.Mock()
+ order_manager.attach_mock(mock_before_func, 'mock_before_func')
+ order_manager.attach_mock(mock_start_func, 'mock_start_func')
+ order_manager.attach_mock(mock_make_conn_func, 'mock_make_conn_func')
+
+ self.client.initialize()
+
+ expected_call_order = [
+ mock.call.mock_before_func(),
+ mock.call.mock_start_func(),
+ mock.call.mock_make_conn_func(),
+ ]
+ self.assertListEqual(order_manager.mock_calls, expected_call_order)
+
+ @mock.patch.object(FakeClient, 'stop')
+ @mock.patch.object(FakeClient, 'before_starting_server')
+ def test_init_server_before_starting_server_fail(self, mock_before_func,
+ mock_stop_func):
+ """Test before_starting_server stage of initialization fails."""
+ mock_before_func.side_effect = Exception('ha')
+
+ with self.assertRaisesRegex(Exception, 'ha'):
+ self.client.initialize()
+ mock_stop_func.assert_not_called()
+
+ @mock.patch.object(FakeClient, 'stop')
+ @mock.patch.object(FakeClient, 'start_server')
+ def test_init_server_start_server_fail(self, mock_start_func, mock_stop_func):
+ """Test start_server stage of initialization fails."""
+ mock_start_func.side_effect = Exception('ha')
+
+ with self.assertRaisesRegex(Exception, 'ha'):
+ self.client.initialize()
+ mock_stop_func.assert_called()
+
+ @mock.patch.object(FakeClient, 'stop')
+ @mock.patch.object(FakeClient, '_make_connection')
+ def test_init_server_make_connection_fail(self, mock_make_conn_func,
+ mock_stop_func):
+ """Test _make_connection stage of initialization fails."""
+ mock_make_conn_func.side_effect = Exception('ha')
+
+ with self.assertRaisesRegex(Exception, 'ha'):
+ self.client.initialize()
+ mock_stop_func.assert_called()
+
+ @mock.patch.object(FakeClient, 'check_server_proc_running')
+ @mock.patch.object(FakeClient, '_gen_rpc_request')
+ @mock.patch.object(FakeClient, 'send_rpc_request')
+ @mock.patch.object(FakeClient, '_decode_response_string_and_validate_format')
+ @mock.patch.object(FakeClient, '_handle_rpc_response')
+ def test_rpc_stage_dependencies(self, mock_handle_resp, mock_decode_resp_str,
+ mock_send_request, mock_gen_request,
+ mock_precheck):
+ """Test the internal dependencies when sending an RPC.
+
+ When sending an RPC, it calls multiple functions in specific order, and
+ each function uses the output of the previously called function. This test
+ case checks above dependencies.
+
+ Args:
+ mock_handle_resp: the mock function of FakeClient._handle_rpc_response.
+ mock_decode_resp_str: the mock function of
+ FakeClient._decode_response_string_and_validate_format.
+ mock_send_request: the mock function of FakeClient.send_rpc_request.
+ mock_gen_request: the mock function of FakeClient._gen_rpc_request.
+ mock_precheck: the mock function of FakeClient.check_server_proc_running.
+ """
+ self.client.initialize()
+
+ expected_response_str = ('{"id": 0, "result": 123, "error": null, '
+ '"callback": null}')
+ expected_response_dict = {
+ 'id': 0,
+ 'result': 123,
+ 'error': None,
+ 'callback': None,
+ }
+ expected_request = ('{"id": 10, "method": "some_rpc", "params": [1, 2],'
+ '"kwargs": {"test_key": 3}')
+ expected_result = 123
+
+ mock_gen_request.return_value = expected_request
+ mock_send_request.return_value = expected_response_str
+ mock_decode_resp_str.return_value = expected_response_dict
+ mock_handle_resp.return_value = expected_result
+ rpc_result = self.client.some_rpc(1, 2, test_key=3)
+
+ mock_precheck.assert_called()
+ mock_gen_request.assert_called_with(0, 'some_rpc', 1, 2, test_key=3)
+ mock_send_request.assert_called_with(expected_request)
+ mock_decode_resp_str.assert_called_with(0, expected_response_str)
+ mock_handle_resp.assert_called_with('some_rpc', expected_response_dict)
+ self.assertEqual(rpc_result, expected_result)
+
+ @mock.patch.object(FakeClient, 'check_server_proc_running')
+ @mock.patch.object(FakeClient, '_gen_rpc_request')
+ @mock.patch.object(FakeClient, 'send_rpc_request')
+ @mock.patch.object(FakeClient, '_decode_response_string_and_validate_format')
+ @mock.patch.object(FakeClient, '_handle_rpc_response')
+ def test_rpc_precheck_fail(self, mock_handle_resp, mock_decode_resp_str,
+ mock_send_request, mock_gen_request,
+ mock_precheck):
+ """Test when RPC precheck fails it will skip sending the RPC."""
+ self.client.initialize()
+ mock_precheck.side_effect = Exception('server_died')
+
+ with self.assertRaisesRegex(Exception, 'server_died'):
+ self.client.some_rpc(1, 2)
+
+ mock_gen_request.assert_not_called()
+ mock_send_request.assert_not_called()
+ mock_handle_resp.assert_not_called()
+ mock_decode_resp_str.assert_not_called()
+
+ def test_gen_request(self):
+ """Test generating an RPC request.
+
+ Test that _gen_rpc_request returns a string represents a JSON dict
+ with all required fields.
+ """
+ request = self.client._gen_rpc_request(0, 'test_rpc', 1, 2, test_key=3)
+ expected_result = ('{"id": 0, "kwargs": {"test_key": 3}, '
+ '"method": "test_rpc", "params": [1, 2]}')
+ self.assertEqual(request, expected_result)
+
+ def test_gen_request_without_kwargs(self):
+ """Test no keyword arguments.
+
+ Test that _gen_rpc_request ignores the kwargs field when no
+ keyword arguments.
+ """
+ request = self.client._gen_rpc_request(0, 'test_rpc', 1, 2)
+ expected_result = '{"id": 0, "method": "test_rpc", "params": [1, 2]}'
+ self.assertEqual(request, expected_result)
+
+ def test_rpc_no_response(self):
+ """Test parsing an empty RPC response."""
+ with self.assertRaisesRegex(errors.ProtocolError,
+ errors.ProtocolError.NO_RESPONSE_FROM_SERVER):
+ self.client._decode_response_string_and_validate_format(0, '')
+
+ with self.assertRaisesRegex(errors.ProtocolError,
+ errors.ProtocolError.NO_RESPONSE_FROM_SERVER):
+ self.client._decode_response_string_and_validate_format(0, None)
+
+ def test_rpc_response_missing_fields(self):
+ """Test parsing an RPC response that misses some required fields."""
+ mock_resp_without_id = '{"result": 123, "error": null, "callback": null}'
+ with self.assertRaisesRegex(
+ errors.ProtocolError,
+ errors.ProtocolError.RESPONSE_MISSING_FIELD % 'id'):
+ self.client._decode_response_string_and_validate_format(
+ 10, mock_resp_without_id)
+
+ mock_resp_without_result = '{"id": 10, "error": null, "callback": null}'
+ with self.assertRaisesRegex(
+ errors.ProtocolError,
+ errors.ProtocolError.RESPONSE_MISSING_FIELD % 'result'):
+ self.client._decode_response_string_and_validate_format(
+ 10, mock_resp_without_result)
+
+ mock_resp_without_error = '{"id": 10, "result": 123, "callback": null}'
+ with self.assertRaisesRegex(
+ errors.ProtocolError,
+ errors.ProtocolError.RESPONSE_MISSING_FIELD % 'error'):
+ self.client._decode_response_string_and_validate_format(
+ 10, mock_resp_without_error)
+
+ mock_resp_without_callback = '{"id": 10, "result": 123, "error": null}'
+ with self.assertRaisesRegex(
+ errors.ProtocolError,
+ errors.ProtocolError.RESPONSE_MISSING_FIELD % 'callback'):
+ self.client._decode_response_string_and_validate_format(
+ 10, mock_resp_without_callback)
+
+ def test_rpc_response_error(self):
+ """Test parsing an RPC response with a non-empty error field."""
+ mock_resp_with_error = {
+ 'id': 10,
+ 'result': 123,
+ 'error': 'some_error',
+ 'callback': None,
+ }
+ with self.assertRaisesRegex(errors.ApiError, 'some_error'):
+ self.client._handle_rpc_response('some_rpc', mock_resp_with_error)
+
+ def test_rpc_response_callback(self):
+ """Test parsing response function handles the callback field well."""
+ # Call handle_callback function if the "callback" field is not None
+ mock_resp_with_callback = {
+ 'id': 10,
+ 'result': 123,
+ 'error': None,
+ 'callback': '1-0'
+ }
+ with mock.patch.object(self.client,
+ 'handle_callback') as mock_handle_callback:
+ expected_callback = mock.Mock()
+ mock_handle_callback.return_value = expected_callback
+
+ rpc_result = self.client._handle_rpc_response('some_rpc',
+ mock_resp_with_callback)
+ mock_handle_callback.assert_called_with('1-0', 123, 'some_rpc')
+ # Ensure the RPC function returns what handle_callback returned
+ self.assertIs(expected_callback, rpc_result)
+
+ # Do not call handle_callback function if the "callback" field is None
+ mock_resp_without_callback = {
+ 'id': 10,
+ 'result': 123,
+ 'error': None,
+ 'callback': None
+ }
+ with mock.patch.object(self.client,
+ 'handle_callback') as mock_handle_callback:
+ self.client._handle_rpc_response('some_rpc', mock_resp_without_callback)
+ mock_handle_callback.assert_not_called()
+
+ def test_rpc_response_id_mismatch(self):
+ """Test parsing an RPC response with a wrong id."""
+ right_id = 5
+ wrong_id = 20
+ resp = f'{{"id": {right_id}, "result": 1, "error": null, "callback": null}}'
+
+ with self.assertRaisesRegex(errors.ProtocolError,
+ errors.ProtocolError.MISMATCHED_API_ID):
+ self.client._decode_response_string_and_validate_format(wrong_id, resp)
+
+ @mock.patch.object(FakeClient, 'send_rpc_request')
+ def test_rpc_verbose_logging_with_long_string(self, mock_send_request):
+ """Test RPC response isn't truncated when verbose logging is on."""
+ mock_log = mock.Mock()
+ self.client.log = mock_log
+ self.client.set_snippet_client_verbose_logging(True)
+ self.client.initialize()
+
+ resp = _generate_fix_length_rpc_response(
+ client_base._MAX_RPC_RESP_LOGGING_LENGTH * 2)
+ mock_send_request.return_value = resp
+ self.client.some_rpc(1, 2)
+ mock_log.debug.assert_called_with('Snippet received: %s', resp)
+
+ @mock.patch.object(FakeClient, 'send_rpc_request')
+ def test_rpc_truncated_logging_short_response(self, mock_send_request):
+ """Test RPC response isn't truncated with small length."""
+ mock_log = mock.Mock()
+ self.client.log = mock_log
+ self.client.set_snippet_client_verbose_logging(False)
+ self.client.initialize()
+
+ resp = _generate_fix_length_rpc_response(
+ int(client_base._MAX_RPC_RESP_LOGGING_LENGTH // 2))
+ mock_send_request.return_value = resp
+ self.client.some_rpc(1, 2)
+ mock_log.debug.assert_called_with('Snippet received: %s', resp)
+
+ @mock.patch.object(FakeClient, 'send_rpc_request')
+ def test_rpc_truncated_logging_fit_size_response(self, mock_send_request):
+ """Test RPC response isn't truncated with length equal to the threshold."""
+ mock_log = mock.Mock()
+ self.client.log = mock_log
+ self.client.set_snippet_client_verbose_logging(False)
+ self.client.initialize()
+
+ resp = _generate_fix_length_rpc_response(
+ client_base._MAX_RPC_RESP_LOGGING_LENGTH)
+ mock_send_request.return_value = resp
+ self.client.some_rpc(1, 2)
+ mock_log.debug.assert_called_with('Snippet received: %s', resp)
+
+ @mock.patch.object(FakeClient, 'send_rpc_request')
+ def test_rpc_truncated_logging_long_response(self, mock_send_request):
+ """Test RPC response is truncated with length larger than the threshold."""
+ mock_log = mock.Mock()
+ self.client.log = mock_log
+ self.client.set_snippet_client_verbose_logging(False)
+ self.client.initialize()
+
+ max_len = client_base._MAX_RPC_RESP_LOGGING_LENGTH
+ resp = _generate_fix_length_rpc_response(max_len * 40)
+ mock_send_request.return_value = resp
+ self.client.some_rpc(1, 2)
+ mock_log.debug.assert_called_with(
+ 'Snippet received: %s... %d chars are truncated',
+ resp[:client_base._MAX_RPC_RESP_LOGGING_LENGTH],
+ len(resp) - max_len)
+
+ @mock.patch.object(FakeClient, 'send_rpc_request')
+ def test_rpc_call_increment_counter(self, mock_send_request):
+ """Test that with each RPC call the counter is incremented by 1."""
+ self.client.initialize()
+ resp = '{"id": %d, "result": 123, "error": null, "callback": null}'
+ mock_send_request.side_effect = (resp % (i,) for i in range(10))
+
+ for _ in range(0, 10):
+ self.client.some_rpc()
+
+ self.assertEqual(next(self.client._counter), 10)
+
+ @mock.patch.object(FakeClient, 'send_rpc_request')
+ def test_init_connection_reset_counter(self, mock_send_request):
+ """Test that _make_connection resets the counter to zero."""
+ self.client.initialize()
+ resp = '{"id": %d, "result": 123, "error": null, "callback": null}'
+ mock_send_request.side_effect = (resp % (i,) for i in range(10))
+
+ for _ in range(0, 10):
+ self.client.some_rpc()
+
+ self.assertEqual(next(self.client._counter), 10)
+ self.client._make_connection()
+ self.assertEqual(next(self.client._counter), 0)
+
+
+if __name__ == '__main__':
+ unittest.main()