diff options
author | Ang Li <angli@google.com> | 2021-04-23 13:30:41 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-23 13:30:41 -0700 |
commit | 532b04a4ae1969b65cbec933c687423b43602c2d (patch) | |
tree | 8c10647a41813a1d3e5c62436453736fc3602067 | |
parent | f3335acaa0758050a8bbef17ff17eaa0bfde97e9 (diff) | |
download | mobly-532b04a4ae1969b65cbec933c687423b43602c2d.tar.gz |
Support native repeat and retry in Mobly. (#734)
* Provides decorators for users to mark test cases for repeat and retry.
* Adds new attributes to the TestResultRecord for associating retry test records.
* Refactors existing code to support the repeat/retry behavior.
-rw-r--r-- | mobly/base_test.py | 134 | ||||
-rw-r--r-- | mobly/records.py | 14 | ||||
-rw-r--r-- | mobly/runtime_test_info.py | 5 | ||||
-rwxr-xr-x | tests/mobly/base_test_test.py | 162 | ||||
-rwxr-xr-x | tests/mobly/controllers/android_device_lib/services/logcat_test.py | 4 | ||||
-rwxr-xr-x | tests/mobly/controllers/android_device_test.py | 7 | ||||
-rwxr-xr-x | tests/mobly/records_test.py | 12 |
7 files changed, 325 insertions, 13 deletions
diff --git a/mobly/base_test.py b/mobly/base_test.py index f1a28bf..ebcc2b8 100644 --- a/mobly/base_test.py +++ b/mobly/base_test.py @@ -43,11 +43,86 @@ STAGE_NAME_TEARDOWN_TEST = 'teardown_test' STAGE_NAME_TEARDOWN_CLASS = 'teardown_class' STAGE_NAME_CLEAN_UP = 'clean_up' +# Attribute names +ATTR_REPEAT_CNT = '_repeat_count' +ATTR_MAX_RETRY_CNT = '_max_count' + class Error(Exception): """Raised for exceptions that occurred in BaseTestClass.""" +def repeat(count): + """Decorator for repeating a test case multiple times. + + The BaseTestClass will execute the test cases annotated with this decorator + the specified number of time. + + This decorator only stores the information needed for the repeat. It does not + execute the repeat. + + Args: + count: int, the total number of times to execute the decorated test case. + + Returns: + The wrapped test function. + + Raises: + ValueError, if the user input is invalid. + """ + if count <= 1: + raise ValueError( + f'The `count` for `repeat` must be larger than 1, got "{count}".') + + def _outer_decorator(func): + setattr(func, ATTR_REPEAT_CNT, count) + + @functools.wraps(func) + def _wrapper(*args): + func(*args) + + return _wrapper + + return _outer_decorator + + +def retry(max_count): + """Decorator for retrying a test case until it passes. + + The BaseTestClass will keep executing the test cases annotated with this + decorator until the test passes, or the maxinum number of iterations have + been met. + + This decorator only stores the information needed for the retry. It does not + execute the retry. + + Args: + max_count: int, the maximum number of times to execute the decorated test + case. + + Returns: + The wrapped test function. + + Raises: + ValueError, if the user input is invalid. + """ + if max_count <= 1: + raise ValueError( + f'The `max_count` for `retry` must be larger than 1, got "{max_count}".' + ) + + def _outer_decorator(func): + setattr(func, ATTR_MAX_RETRY_CNT, max_count) + + @functools.wraps(func) + def _wrapper(*args): + func(*args) + + return _wrapper + + return _outer_decorator + + class BaseTestClass: """Base class for all test classes to inherit from. @@ -559,7 +634,38 @@ class BaseTestClass: content['timestamp'] = utils.get_current_epoch_time() self.summary_writer.dump(content, records.TestSummaryEntryType.USER_DATA) - def exec_one_test(self, test_name, test_method): + def _exec_one_test_with_retry(self, test_name, test_method, max_count): + """Executes one test and retry the test if needed. + + Repeatedly execute a test case until it passes or the maximum count of + iteration has been reached. + + Args: + test_name: string, Name of the test. + test_method: function, The test method to execute. + max_count: int, the maximum number of iterations to execute the test for. + """ + + def should_retry(record): + return record.result in [ + records.TestResultEnums.TEST_RESULT_FAIL, + records.TestResultEnums.TEST_RESULT_ERROR + ] + + previous_record = self.exec_one_test(test_name, test_method) + + if not should_retry(previous_record): + return previous_record + + for i in range(max_count - 1): + retry_name = f'{test_name}_retry_{i+1}' + new_record = records.TestResultRecord(retry_name, self.TAG) + new_record.retry_parent = previous_record.signature + previous_record = self.exec_one_test(retry_name, test_method, new_record) + if not should_retry(previous_record): + break + + def exec_one_test(self, test_name, test_method, record=None): """Executes one test and update test results. Executes setup_test, the test method, and teardown_test; then creates a @@ -569,8 +675,18 @@ class BaseTestClass: Args: test_name: string, Name of the test. test_method: function, The test method to execute. + record: records.TestResultRecord, optional arg for injecting a record + object to use for this test execution. If not set, a new one is created + created. This is meant for passing infomation between consecutive test + case execution for retry purposes. Do NOT abuse this for "magical" + features. + + Returns: + TestResultRecord, the test result record object of the test execution. + This object is strictly for read-only purposes. Modifying this record + will not change what is reported in the test run's summary yaml file. """ - tr_record = records.TestResultRecord(test_name, self.TAG) + tr_record = record or records.TestResultRecord(test_name, self.TAG) tr_record.uid = getattr(test_method, 'uid', None) tr_record.test_begin() self.current_test_info = runtime_test_info.RuntimeTestInfo( @@ -650,6 +766,7 @@ class BaseTestClass: self.summary_writer.dump(tr_record.to_dict(), records.TestSummaryEntryType.RECORD) self.current_test_info = None + return tr_record def _assert_function_name_in_stack(self, expected_func_name): """Asserts that the current stack contains the given function name.""" @@ -767,7 +884,12 @@ class BaseTestClass: test_method = self._generated_test_table[test_name] else: raise Error('%s does not have test method %s.' % (self.TAG, test_name)) - test_methods.append((test_name, test_method)) + repeat_count = getattr(test_method, ATTR_REPEAT_CNT, 0) + if repeat_count: + for i in range(repeat_count): + test_methods.append((f'{test_name}_{i}', test_method)) + else: + test_methods.append((test_name, test_method)) return test_methods def _skip_remaining_tests(self, exception): @@ -831,7 +953,11 @@ class BaseTestClass: return setup_class_result # Run tests in order. for test_name, test_method in tests: - self.exec_one_test(test_name, test_method) + max_count = getattr(test_method, ATTR_MAX_RETRY_CNT, 0) + if max_count: + self._exec_one_test_with_retry(test_name, test_method, max_count) + else: + self.exec_one_test(test_name, test_method) return self.results except signals.TestAbortClass as e: e.details = 'Test class aborted due to: %s' % e.details diff --git a/mobly/records.py b/mobly/records.py index 8dcc5be..9047692 100644 --- a/mobly/records.py +++ b/mobly/records.py @@ -182,6 +182,8 @@ class TestResultEnums: RECORD_EXTRA_ERRORS = 'Extra Errors' RECORD_DETAILS = 'Details' RECORD_STACKTRACE = 'Stacktrace' + RECORD_SIGNATURE = 'Signature' + RECORD_RETRY_PARENT = 'Retry Parent' RECORD_POSITION = 'Position' TEST_RESULT_PASS = 'PASS' TEST_RESULT_FAIL = 'FAIL' @@ -311,7 +313,12 @@ class TestResultRecord: test_name: string, the name of the test. begin_time: Epoch timestamp of when the test started. end_time: Epoch timestamp of when the test ended. - uid: Unique identifier of a test. + uid: User-defined unique identifier of the test. + signature: string, unique identifier of a test record, the value is + generated by Mobly. + retry_parent: string, only set for retry iterations. This is the signature + of the previous iteration of this retry. Parsers can use this field to + construct the chain of execution for each retried test. termination_signal: ExceptionRecord, the main exception of the test. extra_errors: OrderedDict, all exceptions occurred during the entire test lifecycle. The order of occurrence is preserved. @@ -324,6 +331,8 @@ class TestResultRecord: self.begin_time = None self.end_time = None self.uid = None + self.signature = None + self.retry_parent = None self.termination_signal = None self.extra_errors = collections.OrderedDict() self.result = None @@ -360,6 +369,7 @@ class TestResultRecord: Sets the begin_time of this record. """ self.begin_time = utils.get_current_epoch_time() + self.signature = '%s-%s' % (self.test_name, self.begin_time) def _test_end(self, result, e): """Marks the end of the test logic. @@ -480,6 +490,8 @@ class TestResultRecord: d[TestResultEnums.RECORD_END_TIME] = self.end_time d[TestResultEnums.RECORD_RESULT] = self.result d[TestResultEnums.RECORD_UID] = self.uid + d[TestResultEnums.RECORD_SIGNATURE] = self.signature + d[TestResultEnums.RECORD_RETRY_PARENT] = self.retry_parent d[TestResultEnums.RECORD_EXTRAS] = self.extras d[TestResultEnums.RECORD_DETAILS] = self.details d[TestResultEnums.RECORD_EXTRA_ERRORS] = { diff --git a/mobly/runtime_test_info.py b/mobly/runtime_test_info.py index 99a5c72..ed691bc 100644 --- a/mobly/runtime_test_info.py +++ b/mobly/runtime_test_info.py @@ -39,9 +39,8 @@ class RuntimeTestInfo: def __init__(self, test_name, log_path, record): self._name = test_name self._record = record - self._signature = '%s-%s' % (test_name, record.begin_time) self._output_dir_path = utils.abs_path( - os.path.join(log_path, self._signature)) + os.path.join(log_path, self._record.signature)) @property def name(self): @@ -49,7 +48,7 @@ class RuntimeTestInfo: @property def signature(self): - return self._signature + return self.record.signature @property def record(self): diff --git a/tests/mobly/base_test_test.py b/tests/mobly/base_test_test.py index 4fcc853..a720752 100755 --- a/tests/mobly/base_test_test.py +++ b/tests/mobly/base_test_test.py @@ -2237,6 +2237,134 @@ class BaseTestTest(unittest.TestCase): 'mock_controller: Some failure') self.assertEqual(record.details, expected_msg) + def test_repeat_invalid_count(self): + + with self.assertRaisesRegex( + ValueError, 'The `count` for `repeat` must be larger than 1, got "1".'): + + class MockBaseTest(base_test.BaseTestClass): + + @base_test.repeat(count=1) + def test_something(self): + pass + + def test_repeat(self): + repeat_count = 3 + + class MockBaseTest(base_test.BaseTestClass): + + @base_test.repeat(count=repeat_count) + def test_something(self): + pass + + bt_cls = MockBaseTest(self.mock_test_cls_configs) + bt_cls.run() + self.assertEqual(repeat_count, len(bt_cls.results.passed)) + for i, record in enumerate(bt_cls.results.passed): + self.assertEqual(record.test_name, f'test_something_{i}') + + def test_repeat_with_failures(self): + repeat_count = 3 + mock_action = mock.MagicMock() + mock_action.side_effect = [None, Exception('Something failed'), None] + + class MockBaseTest(base_test.BaseTestClass): + + @base_test.repeat(count=repeat_count) + def test_something(self): + mock_action() + + bt_cls = MockBaseTest(self.mock_test_cls_configs) + bt_cls.run() + self.assertEqual(repeat_count, len(bt_cls.results.executed)) + self.assertEqual(1, len(bt_cls.results.error)) + self.assertEqual(2, len(bt_cls.results.passed)) + iter_2 = bt_cls.results.error[0] + iter_1, iter_3 = bt_cls.results.passed + self.assertEqual(iter_2.test_name, 'test_something_1') + self.assertEqual(iter_1.test_name, 'test_something_0') + self.assertEqual(iter_3.test_name, 'test_something_2') + + def test_retry_invalid_count(self): + + with self.assertRaisesRegex( + ValueError, + 'The `max_count` for `retry` must be larger than 1, got "1".'): + + class MockBaseTest(base_test.BaseTestClass): + + @base_test.retry(max_count=1) + def test_something(self): + pass + + def test_retry_first_pass(self): + max_count = 3 + mock_action = mock.MagicMock() + + class MockBaseTest(base_test.BaseTestClass): + + @base_test.retry(max_count=max_count) + def test_something(self): + mock_action() + + bt_cls = MockBaseTest(self.mock_test_cls_configs) + bt_cls.run() + self.assertEqual(1, len(bt_cls.results.executed)) + self.assertEqual(1, len(bt_cls.results.passed)) + pass_record = bt_cls.results.passed[0] + self.assertEqual(pass_record.test_name, f'test_something') + self.assertEqual(0, len(bt_cls.results.error)) + + def test_retry_last_pass(self): + max_count = 3 + mock_action = mock.MagicMock() + mock_action.side_effect = [Exception('Fail 1'), Exception('Fail 2'), None] + + class MockBaseTest(base_test.BaseTestClass): + + @base_test.retry(max_count=max_count) + def test_something(self): + mock_action() + + bt_cls = MockBaseTest(self.mock_test_cls_configs) + bt_cls.run() + self.assertEqual(3, len(bt_cls.results.executed)) + self.assertEqual(1, len(bt_cls.results.passed)) + pass_record = bt_cls.results.passed[0] + self.assertEqual(pass_record.test_name, f'test_something_retry_2') + self.assertEqual(2, len(bt_cls.results.error)) + error_record_1, error_record_2 = bt_cls.results.error + self.assertEqual(error_record_1.test_name, 'test_something') + self.assertEqual(error_record_2.test_name, 'test_something_retry_1') + self.assertEqual(error_record_1.signature, error_record_2.retry_parent) + self.assertEqual(error_record_2.signature, pass_record.retry_parent) + + def test_retry_all_fail(self): + max_count = 3 + mock_action = mock.MagicMock() + mock_action.side_effect = [ + Exception('Fail 1'), + Exception('Fail 2'), + Exception('Fail 3') + ] + + class MockBaseTest(base_test.BaseTestClass): + + @base_test.retry(max_count=max_count) + def test_something(self): + mock_action() + + bt_cls = MockBaseTest(self.mock_test_cls_configs) + bt_cls.run() + self.assertEqual(3, len(bt_cls.results.executed)) + self.assertEqual(3, len(bt_cls.results.error)) + error_record_1, error_record_2, error_record_3 = bt_cls.results.error + self.assertEqual(error_record_1.test_name, 'test_something') + self.assertEqual(error_record_2.test_name, 'test_something_retry_1') + self.assertEqual(error_record_3.test_name, 'test_something_retry_2') + self.assertEqual(error_record_1.signature, error_record_2.retry_parent) + self.assertEqual(error_record_2.signature, error_record_3.retry_parent) + def test_uid(self): class MockBaseTest(base_test.BaseTestClass): @@ -2271,6 +2399,40 @@ class BaseTestTest(unittest.TestCase): def not_a_test(self): pass + def test_repeat_with_uid(self): + repeat_count = 3 + + class MockBaseTest(base_test.BaseTestClass): + + @base_test.repeat(count=repeat_count) + @records.uid('some-uid') + def test_something(self): + pass + + bt_cls = MockBaseTest(self.mock_test_cls_configs) + bt_cls.run() + self.assertEqual(repeat_count, len(bt_cls.results.passed)) + for i, record in enumerate(bt_cls.results.passed): + self.assertEqual(record.test_name, f'test_something_{i}') + self.assertEqual(record.uid, 'some-uid') + + def test_uid_with_repeat(self): + repeat_count = 3 + + class MockBaseTest(base_test.BaseTestClass): + + @records.uid('some-uid') + @base_test.repeat(count=repeat_count) + def test_something(self): + pass + + bt_cls = MockBaseTest(self.mock_test_cls_configs) + bt_cls.run() + self.assertEqual(repeat_count, len(bt_cls.results.passed)) + for i, record in enumerate(bt_cls.results.passed): + self.assertEqual(record.test_name, f'test_something_{i}') + self.assertEqual(record.uid, 'some-uid') + def test_log_stage_always_logs_end_statement(self): instance = base_test.BaseTestClass(self.mock_test_cls_configs) instance.current_test_info = mock.Mock() diff --git a/tests/mobly/controllers/android_device_lib/services/logcat_test.py b/tests/mobly/controllers/android_device_lib/services/logcat_test.py index ba18698..795bf0b 100755 --- a/tests/mobly/controllers/android_device_lib/services/logcat_test.py +++ b/tests/mobly/controllers/android_device_lib/services/logcat_test.py @@ -20,6 +20,7 @@ import shutil import tempfile import unittest +from mobly import records from mobly import utils from mobly import runtime_test_info from mobly.controllers import android_device @@ -239,8 +240,9 @@ class LogcatTest(unittest.TestCase): with open(logcat_service.adb_logcat_file_path, 'a') as f: f.write(logcat_file_content) test_output_dir = os.path.join(self.tmp_dir, test_name) - mock_record = mock.MagicMock() + mock_record = records.TestResultRecord(test_name) mock_record.begin_time = test_begin_time + mock_record.signature = f'{test_name}-{test_begin_time}' test_run_info = runtime_test_info.RuntimeTestInfo(test_name, test_output_dir, mock_record) diff --git a/tests/mobly/controllers/android_device_test.py b/tests/mobly/controllers/android_device_test.py index d3b566f..c64c826 100755 --- a/tests/mobly/controllers/android_device_test.py +++ b/tests/mobly/controllers/android_device_test.py @@ -579,8 +579,11 @@ class AndroidDeviceTest(unittest.TestCase): mock_serial = 1 ad = android_device.AndroidDevice(serial=mock_serial) get_log_file_timestamp_mock.return_value = '07-22-2019_17-53-34-450' - mock_record = mock.MagicMock(begin_time='1234567') - mock_test_info = runtime_test_info.RuntimeTestInfo('test_xyz', '/tmp/blah/', + mock_record = mock.MagicMock(test_name='test_xyz', + begin_time='1234567', + signature='test_xyz-1234567') + mock_test_info = runtime_test_info.RuntimeTestInfo(mock_record.test_name, + '/tmp/blah/', mock_record) filename = ad.generate_filename('MagicLog', time_identifier=mock_test_info) self.assertEqual(filename, 'MagicLog,1,fakemodel,test_xyz-1234567') diff --git a/tests/mobly/records_test.py b/tests/mobly/records_test.py index a3c0689..765b7bb 100755 --- a/tests/mobly/records_test.py +++ b/tests/mobly/records_test.py @@ -73,7 +73,10 @@ class RecordsTest(unittest.TestCase): d[records.TestResultEnums.RECORD_EXTRAS] = extras d[records.TestResultEnums.RECORD_BEGIN_TIME] = record.begin_time d[records.TestResultEnums.RECORD_END_TIME] = record.end_time + d[records.TestResultEnums. + RECORD_SIGNATURE] = f'{self.tn}-{record.begin_time}' d[records.TestResultEnums.RECORD_UID] = None + d[records.TestResultEnums.RECORD_RETRY_PARENT] = None d[records.TestResultEnums.RECORD_CLASS] = None d[records.TestResultEnums.RECORD_EXTRA_ERRORS] = {} d[records.TestResultEnums.RECORD_STACKTRACE] = stacktrace @@ -89,8 +92,6 @@ class RecordsTest(unittest.TestCase): self.assertTrue(str(record), 'str of the record should not be empty.') self.assertTrue(repr(record), "the record's repr shouldn't be empty.") - """ Begin of Tests """ - def test_result_record_pass_none(self): record = records.TestResultRecord(self.tn) record.test_begin() @@ -375,6 +376,13 @@ class RecordsTest(unittest.TestCase): self.assertEqual(content[records.TestResultEnums.RECORD_EXTRAS], unicode_extras) + @mock.patch('mobly.utils.get_current_epoch_time') + def test_signature(self, mock_time_src): + mock_time_src.return_value = 12345 + record = records.TestResultRecord(self.tn) + record.test_begin() + self.assertEqual(record.signature, 'test_name-12345') + def test_summary_user_data(self): user_data1 = {'a': 1} user_data2 = {'b': 1} |