aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAng Li <angli@google.com>2021-04-23 13:30:41 -0700
committerGitHub <noreply@github.com>2021-04-23 13:30:41 -0700
commit532b04a4ae1969b65cbec933c687423b43602c2d (patch)
tree8c10647a41813a1d3e5c62436453736fc3602067
parentf3335acaa0758050a8bbef17ff17eaa0bfde97e9 (diff)
downloadmobly-532b04a4ae1969b65cbec933c687423b43602c2d.tar.gz
Support native repeat and retry in Mobly. (#734)
* Provides decorators for users to mark test cases for repeat and retry. * Adds new attributes to the TestResultRecord for associating retry test records. * Refactors existing code to support the repeat/retry behavior.
-rw-r--r--mobly/base_test.py134
-rw-r--r--mobly/records.py14
-rw-r--r--mobly/runtime_test_info.py5
-rwxr-xr-xtests/mobly/base_test_test.py162
-rwxr-xr-xtests/mobly/controllers/android_device_lib/services/logcat_test.py4
-rwxr-xr-xtests/mobly/controllers/android_device_test.py7
-rwxr-xr-xtests/mobly/records_test.py12
7 files changed, 325 insertions, 13 deletions
diff --git a/mobly/base_test.py b/mobly/base_test.py
index f1a28bf..ebcc2b8 100644
--- a/mobly/base_test.py
+++ b/mobly/base_test.py
@@ -43,11 +43,86 @@ STAGE_NAME_TEARDOWN_TEST = 'teardown_test'
STAGE_NAME_TEARDOWN_CLASS = 'teardown_class'
STAGE_NAME_CLEAN_UP = 'clean_up'
+# Attribute names
+ATTR_REPEAT_CNT = '_repeat_count'
+ATTR_MAX_RETRY_CNT = '_max_count'
+
class Error(Exception):
"""Raised for exceptions that occurred in BaseTestClass."""
+def repeat(count):
+ """Decorator for repeating a test case multiple times.
+
+ The BaseTestClass will execute the test cases annotated with this decorator
+ the specified number of time.
+
+ This decorator only stores the information needed for the repeat. It does not
+ execute the repeat.
+
+ Args:
+ count: int, the total number of times to execute the decorated test case.
+
+ Returns:
+ The wrapped test function.
+
+ Raises:
+ ValueError, if the user input is invalid.
+ """
+ if count <= 1:
+ raise ValueError(
+ f'The `count` for `repeat` must be larger than 1, got "{count}".')
+
+ def _outer_decorator(func):
+ setattr(func, ATTR_REPEAT_CNT, count)
+
+ @functools.wraps(func)
+ def _wrapper(*args):
+ func(*args)
+
+ return _wrapper
+
+ return _outer_decorator
+
+
+def retry(max_count):
+ """Decorator for retrying a test case until it passes.
+
+ The BaseTestClass will keep executing the test cases annotated with this
+ decorator until the test passes, or the maxinum number of iterations have
+ been met.
+
+ This decorator only stores the information needed for the retry. It does not
+ execute the retry.
+
+ Args:
+ max_count: int, the maximum number of times to execute the decorated test
+ case.
+
+ Returns:
+ The wrapped test function.
+
+ Raises:
+ ValueError, if the user input is invalid.
+ """
+ if max_count <= 1:
+ raise ValueError(
+ f'The `max_count` for `retry` must be larger than 1, got "{max_count}".'
+ )
+
+ def _outer_decorator(func):
+ setattr(func, ATTR_MAX_RETRY_CNT, max_count)
+
+ @functools.wraps(func)
+ def _wrapper(*args):
+ func(*args)
+
+ return _wrapper
+
+ return _outer_decorator
+
+
class BaseTestClass:
"""Base class for all test classes to inherit from.
@@ -559,7 +634,38 @@ class BaseTestClass:
content['timestamp'] = utils.get_current_epoch_time()
self.summary_writer.dump(content, records.TestSummaryEntryType.USER_DATA)
- def exec_one_test(self, test_name, test_method):
+ def _exec_one_test_with_retry(self, test_name, test_method, max_count):
+ """Executes one test and retry the test if needed.
+
+ Repeatedly execute a test case until it passes or the maximum count of
+ iteration has been reached.
+
+ Args:
+ test_name: string, Name of the test.
+ test_method: function, The test method to execute.
+ max_count: int, the maximum number of iterations to execute the test for.
+ """
+
+ def should_retry(record):
+ return record.result in [
+ records.TestResultEnums.TEST_RESULT_FAIL,
+ records.TestResultEnums.TEST_RESULT_ERROR
+ ]
+
+ previous_record = self.exec_one_test(test_name, test_method)
+
+ if not should_retry(previous_record):
+ return previous_record
+
+ for i in range(max_count - 1):
+ retry_name = f'{test_name}_retry_{i+1}'
+ new_record = records.TestResultRecord(retry_name, self.TAG)
+ new_record.retry_parent = previous_record.signature
+ previous_record = self.exec_one_test(retry_name, test_method, new_record)
+ if not should_retry(previous_record):
+ break
+
+ def exec_one_test(self, test_name, test_method, record=None):
"""Executes one test and update test results.
Executes setup_test, the test method, and teardown_test; then creates a
@@ -569,8 +675,18 @@ class BaseTestClass:
Args:
test_name: string, Name of the test.
test_method: function, The test method to execute.
+ record: records.TestResultRecord, optional arg for injecting a record
+ object to use for this test execution. If not set, a new one is created
+ created. This is meant for passing infomation between consecutive test
+ case execution for retry purposes. Do NOT abuse this for "magical"
+ features.
+
+ Returns:
+ TestResultRecord, the test result record object of the test execution.
+ This object is strictly for read-only purposes. Modifying this record
+ will not change what is reported in the test run's summary yaml file.
"""
- tr_record = records.TestResultRecord(test_name, self.TAG)
+ tr_record = record or records.TestResultRecord(test_name, self.TAG)
tr_record.uid = getattr(test_method, 'uid', None)
tr_record.test_begin()
self.current_test_info = runtime_test_info.RuntimeTestInfo(
@@ -650,6 +766,7 @@ class BaseTestClass:
self.summary_writer.dump(tr_record.to_dict(),
records.TestSummaryEntryType.RECORD)
self.current_test_info = None
+ return tr_record
def _assert_function_name_in_stack(self, expected_func_name):
"""Asserts that the current stack contains the given function name."""
@@ -767,7 +884,12 @@ class BaseTestClass:
test_method = self._generated_test_table[test_name]
else:
raise Error('%s does not have test method %s.' % (self.TAG, test_name))
- test_methods.append((test_name, test_method))
+ repeat_count = getattr(test_method, ATTR_REPEAT_CNT, 0)
+ if repeat_count:
+ for i in range(repeat_count):
+ test_methods.append((f'{test_name}_{i}', test_method))
+ else:
+ test_methods.append((test_name, test_method))
return test_methods
def _skip_remaining_tests(self, exception):
@@ -831,7 +953,11 @@ class BaseTestClass:
return setup_class_result
# Run tests in order.
for test_name, test_method in tests:
- self.exec_one_test(test_name, test_method)
+ max_count = getattr(test_method, ATTR_MAX_RETRY_CNT, 0)
+ if max_count:
+ self._exec_one_test_with_retry(test_name, test_method, max_count)
+ else:
+ self.exec_one_test(test_name, test_method)
return self.results
except signals.TestAbortClass as e:
e.details = 'Test class aborted due to: %s' % e.details
diff --git a/mobly/records.py b/mobly/records.py
index 8dcc5be..9047692 100644
--- a/mobly/records.py
+++ b/mobly/records.py
@@ -182,6 +182,8 @@ class TestResultEnums:
RECORD_EXTRA_ERRORS = 'Extra Errors'
RECORD_DETAILS = 'Details'
RECORD_STACKTRACE = 'Stacktrace'
+ RECORD_SIGNATURE = 'Signature'
+ RECORD_RETRY_PARENT = 'Retry Parent'
RECORD_POSITION = 'Position'
TEST_RESULT_PASS = 'PASS'
TEST_RESULT_FAIL = 'FAIL'
@@ -311,7 +313,12 @@ class TestResultRecord:
test_name: string, the name of the test.
begin_time: Epoch timestamp of when the test started.
end_time: Epoch timestamp of when the test ended.
- uid: Unique identifier of a test.
+ uid: User-defined unique identifier of the test.
+ signature: string, unique identifier of a test record, the value is
+ generated by Mobly.
+ retry_parent: string, only set for retry iterations. This is the signature
+ of the previous iteration of this retry. Parsers can use this field to
+ construct the chain of execution for each retried test.
termination_signal: ExceptionRecord, the main exception of the test.
extra_errors: OrderedDict, all exceptions occurred during the entire
test lifecycle. The order of occurrence is preserved.
@@ -324,6 +331,8 @@ class TestResultRecord:
self.begin_time = None
self.end_time = None
self.uid = None
+ self.signature = None
+ self.retry_parent = None
self.termination_signal = None
self.extra_errors = collections.OrderedDict()
self.result = None
@@ -360,6 +369,7 @@ class TestResultRecord:
Sets the begin_time of this record.
"""
self.begin_time = utils.get_current_epoch_time()
+ self.signature = '%s-%s' % (self.test_name, self.begin_time)
def _test_end(self, result, e):
"""Marks the end of the test logic.
@@ -480,6 +490,8 @@ class TestResultRecord:
d[TestResultEnums.RECORD_END_TIME] = self.end_time
d[TestResultEnums.RECORD_RESULT] = self.result
d[TestResultEnums.RECORD_UID] = self.uid
+ d[TestResultEnums.RECORD_SIGNATURE] = self.signature
+ d[TestResultEnums.RECORD_RETRY_PARENT] = self.retry_parent
d[TestResultEnums.RECORD_EXTRAS] = self.extras
d[TestResultEnums.RECORD_DETAILS] = self.details
d[TestResultEnums.RECORD_EXTRA_ERRORS] = {
diff --git a/mobly/runtime_test_info.py b/mobly/runtime_test_info.py
index 99a5c72..ed691bc 100644
--- a/mobly/runtime_test_info.py
+++ b/mobly/runtime_test_info.py
@@ -39,9 +39,8 @@ class RuntimeTestInfo:
def __init__(self, test_name, log_path, record):
self._name = test_name
self._record = record
- self._signature = '%s-%s' % (test_name, record.begin_time)
self._output_dir_path = utils.abs_path(
- os.path.join(log_path, self._signature))
+ os.path.join(log_path, self._record.signature))
@property
def name(self):
@@ -49,7 +48,7 @@ class RuntimeTestInfo:
@property
def signature(self):
- return self._signature
+ return self.record.signature
@property
def record(self):
diff --git a/tests/mobly/base_test_test.py b/tests/mobly/base_test_test.py
index 4fcc853..a720752 100755
--- a/tests/mobly/base_test_test.py
+++ b/tests/mobly/base_test_test.py
@@ -2237,6 +2237,134 @@ class BaseTestTest(unittest.TestCase):
'mock_controller: Some failure')
self.assertEqual(record.details, expected_msg)
+ def test_repeat_invalid_count(self):
+
+ with self.assertRaisesRegex(
+ ValueError, 'The `count` for `repeat` must be larger than 1, got "1".'):
+
+ class MockBaseTest(base_test.BaseTestClass):
+
+ @base_test.repeat(count=1)
+ def test_something(self):
+ pass
+
+ def test_repeat(self):
+ repeat_count = 3
+
+ class MockBaseTest(base_test.BaseTestClass):
+
+ @base_test.repeat(count=repeat_count)
+ def test_something(self):
+ pass
+
+ bt_cls = MockBaseTest(self.mock_test_cls_configs)
+ bt_cls.run()
+ self.assertEqual(repeat_count, len(bt_cls.results.passed))
+ for i, record in enumerate(bt_cls.results.passed):
+ self.assertEqual(record.test_name, f'test_something_{i}')
+
+ def test_repeat_with_failures(self):
+ repeat_count = 3
+ mock_action = mock.MagicMock()
+ mock_action.side_effect = [None, Exception('Something failed'), None]
+
+ class MockBaseTest(base_test.BaseTestClass):
+
+ @base_test.repeat(count=repeat_count)
+ def test_something(self):
+ mock_action()
+
+ bt_cls = MockBaseTest(self.mock_test_cls_configs)
+ bt_cls.run()
+ self.assertEqual(repeat_count, len(bt_cls.results.executed))
+ self.assertEqual(1, len(bt_cls.results.error))
+ self.assertEqual(2, len(bt_cls.results.passed))
+ iter_2 = bt_cls.results.error[0]
+ iter_1, iter_3 = bt_cls.results.passed
+ self.assertEqual(iter_2.test_name, 'test_something_1')
+ self.assertEqual(iter_1.test_name, 'test_something_0')
+ self.assertEqual(iter_3.test_name, 'test_something_2')
+
+ def test_retry_invalid_count(self):
+
+ with self.assertRaisesRegex(
+ ValueError,
+ 'The `max_count` for `retry` must be larger than 1, got "1".'):
+
+ class MockBaseTest(base_test.BaseTestClass):
+
+ @base_test.retry(max_count=1)
+ def test_something(self):
+ pass
+
+ def test_retry_first_pass(self):
+ max_count = 3
+ mock_action = mock.MagicMock()
+
+ class MockBaseTest(base_test.BaseTestClass):
+
+ @base_test.retry(max_count=max_count)
+ def test_something(self):
+ mock_action()
+
+ bt_cls = MockBaseTest(self.mock_test_cls_configs)
+ bt_cls.run()
+ self.assertEqual(1, len(bt_cls.results.executed))
+ self.assertEqual(1, len(bt_cls.results.passed))
+ pass_record = bt_cls.results.passed[0]
+ self.assertEqual(pass_record.test_name, f'test_something')
+ self.assertEqual(0, len(bt_cls.results.error))
+
+ def test_retry_last_pass(self):
+ max_count = 3
+ mock_action = mock.MagicMock()
+ mock_action.side_effect = [Exception('Fail 1'), Exception('Fail 2'), None]
+
+ class MockBaseTest(base_test.BaseTestClass):
+
+ @base_test.retry(max_count=max_count)
+ def test_something(self):
+ mock_action()
+
+ bt_cls = MockBaseTest(self.mock_test_cls_configs)
+ bt_cls.run()
+ self.assertEqual(3, len(bt_cls.results.executed))
+ self.assertEqual(1, len(bt_cls.results.passed))
+ pass_record = bt_cls.results.passed[0]
+ self.assertEqual(pass_record.test_name, f'test_something_retry_2')
+ self.assertEqual(2, len(bt_cls.results.error))
+ error_record_1, error_record_2 = bt_cls.results.error
+ self.assertEqual(error_record_1.test_name, 'test_something')
+ self.assertEqual(error_record_2.test_name, 'test_something_retry_1')
+ self.assertEqual(error_record_1.signature, error_record_2.retry_parent)
+ self.assertEqual(error_record_2.signature, pass_record.retry_parent)
+
+ def test_retry_all_fail(self):
+ max_count = 3
+ mock_action = mock.MagicMock()
+ mock_action.side_effect = [
+ Exception('Fail 1'),
+ Exception('Fail 2'),
+ Exception('Fail 3')
+ ]
+
+ class MockBaseTest(base_test.BaseTestClass):
+
+ @base_test.retry(max_count=max_count)
+ def test_something(self):
+ mock_action()
+
+ bt_cls = MockBaseTest(self.mock_test_cls_configs)
+ bt_cls.run()
+ self.assertEqual(3, len(bt_cls.results.executed))
+ self.assertEqual(3, len(bt_cls.results.error))
+ error_record_1, error_record_2, error_record_3 = bt_cls.results.error
+ self.assertEqual(error_record_1.test_name, 'test_something')
+ self.assertEqual(error_record_2.test_name, 'test_something_retry_1')
+ self.assertEqual(error_record_3.test_name, 'test_something_retry_2')
+ self.assertEqual(error_record_1.signature, error_record_2.retry_parent)
+ self.assertEqual(error_record_2.signature, error_record_3.retry_parent)
+
def test_uid(self):
class MockBaseTest(base_test.BaseTestClass):
@@ -2271,6 +2399,40 @@ class BaseTestTest(unittest.TestCase):
def not_a_test(self):
pass
+ def test_repeat_with_uid(self):
+ repeat_count = 3
+
+ class MockBaseTest(base_test.BaseTestClass):
+
+ @base_test.repeat(count=repeat_count)
+ @records.uid('some-uid')
+ def test_something(self):
+ pass
+
+ bt_cls = MockBaseTest(self.mock_test_cls_configs)
+ bt_cls.run()
+ self.assertEqual(repeat_count, len(bt_cls.results.passed))
+ for i, record in enumerate(bt_cls.results.passed):
+ self.assertEqual(record.test_name, f'test_something_{i}')
+ self.assertEqual(record.uid, 'some-uid')
+
+ def test_uid_with_repeat(self):
+ repeat_count = 3
+
+ class MockBaseTest(base_test.BaseTestClass):
+
+ @records.uid('some-uid')
+ @base_test.repeat(count=repeat_count)
+ def test_something(self):
+ pass
+
+ bt_cls = MockBaseTest(self.mock_test_cls_configs)
+ bt_cls.run()
+ self.assertEqual(repeat_count, len(bt_cls.results.passed))
+ for i, record in enumerate(bt_cls.results.passed):
+ self.assertEqual(record.test_name, f'test_something_{i}')
+ self.assertEqual(record.uid, 'some-uid')
+
def test_log_stage_always_logs_end_statement(self):
instance = base_test.BaseTestClass(self.mock_test_cls_configs)
instance.current_test_info = mock.Mock()
diff --git a/tests/mobly/controllers/android_device_lib/services/logcat_test.py b/tests/mobly/controllers/android_device_lib/services/logcat_test.py
index ba18698..795bf0b 100755
--- a/tests/mobly/controllers/android_device_lib/services/logcat_test.py
+++ b/tests/mobly/controllers/android_device_lib/services/logcat_test.py
@@ -20,6 +20,7 @@ import shutil
import tempfile
import unittest
+from mobly import records
from mobly import utils
from mobly import runtime_test_info
from mobly.controllers import android_device
@@ -239,8 +240,9 @@ class LogcatTest(unittest.TestCase):
with open(logcat_service.adb_logcat_file_path, 'a') as f:
f.write(logcat_file_content)
test_output_dir = os.path.join(self.tmp_dir, test_name)
- mock_record = mock.MagicMock()
+ mock_record = records.TestResultRecord(test_name)
mock_record.begin_time = test_begin_time
+ mock_record.signature = f'{test_name}-{test_begin_time}'
test_run_info = runtime_test_info.RuntimeTestInfo(test_name,
test_output_dir,
mock_record)
diff --git a/tests/mobly/controllers/android_device_test.py b/tests/mobly/controllers/android_device_test.py
index d3b566f..c64c826 100755
--- a/tests/mobly/controllers/android_device_test.py
+++ b/tests/mobly/controllers/android_device_test.py
@@ -579,8 +579,11 @@ class AndroidDeviceTest(unittest.TestCase):
mock_serial = 1
ad = android_device.AndroidDevice(serial=mock_serial)
get_log_file_timestamp_mock.return_value = '07-22-2019_17-53-34-450'
- mock_record = mock.MagicMock(begin_time='1234567')
- mock_test_info = runtime_test_info.RuntimeTestInfo('test_xyz', '/tmp/blah/',
+ mock_record = mock.MagicMock(test_name='test_xyz',
+ begin_time='1234567',
+ signature='test_xyz-1234567')
+ mock_test_info = runtime_test_info.RuntimeTestInfo(mock_record.test_name,
+ '/tmp/blah/',
mock_record)
filename = ad.generate_filename('MagicLog', time_identifier=mock_test_info)
self.assertEqual(filename, 'MagicLog,1,fakemodel,test_xyz-1234567')
diff --git a/tests/mobly/records_test.py b/tests/mobly/records_test.py
index a3c0689..765b7bb 100755
--- a/tests/mobly/records_test.py
+++ b/tests/mobly/records_test.py
@@ -73,7 +73,10 @@ class RecordsTest(unittest.TestCase):
d[records.TestResultEnums.RECORD_EXTRAS] = extras
d[records.TestResultEnums.RECORD_BEGIN_TIME] = record.begin_time
d[records.TestResultEnums.RECORD_END_TIME] = record.end_time
+ d[records.TestResultEnums.
+ RECORD_SIGNATURE] = f'{self.tn}-{record.begin_time}'
d[records.TestResultEnums.RECORD_UID] = None
+ d[records.TestResultEnums.RECORD_RETRY_PARENT] = None
d[records.TestResultEnums.RECORD_CLASS] = None
d[records.TestResultEnums.RECORD_EXTRA_ERRORS] = {}
d[records.TestResultEnums.RECORD_STACKTRACE] = stacktrace
@@ -89,8 +92,6 @@ class RecordsTest(unittest.TestCase):
self.assertTrue(str(record), 'str of the record should not be empty.')
self.assertTrue(repr(record), "the record's repr shouldn't be empty.")
- """ Begin of Tests """
-
def test_result_record_pass_none(self):
record = records.TestResultRecord(self.tn)
record.test_begin()
@@ -375,6 +376,13 @@ class RecordsTest(unittest.TestCase):
self.assertEqual(content[records.TestResultEnums.RECORD_EXTRAS],
unicode_extras)
+ @mock.patch('mobly.utils.get_current_epoch_time')
+ def test_signature(self, mock_time_src):
+ mock_time_src.return_value = 12345
+ record = records.TestResultRecord(self.tn)
+ record.test_begin()
+ self.assertEqual(record.signature, 'test_name-12345')
+
def test_summary_user_data(self):
user_data1 = {'a': 1}
user_data2 = {'b': 1}