diff options
author | Abseil Team <absl-team@google.com> | 2021-02-09 11:53:52 -0800 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2021-02-09 11:54:14 -0800 |
commit | 2444faee6d820c024c26380f60aa5e12fdf4b3a4 (patch) | |
tree | 38fded8ffb8a2c2fd93571be7f155f4f32ec84cd /absl/testing | |
parent | a8d57c3317d375bedae2281e40e8d57bb944f393 (diff) | |
download | absl-py-2444faee6d820c024c26380f60aa5e12fdf4b3a4.tar.gz |
More general parameterized.product, supporting products of dicts as well as products of single parameters.
PiperOrigin-RevId: 356556227
Change-Id: Ie476d2e971e75e6f5cce88d359e9512228003af7
Diffstat (limited to 'absl/testing')
-rw-r--r-- | absl/testing/parameterized.py | 66 | ||||
-rw-r--r-- | absl/testing/tests/parameterized_test.py | 71 |
2 files changed, 127 insertions, 10 deletions
diff --git a/absl/testing/parameterized.py b/absl/testing/parameterized.py index 894cf6c..570dbb6 100644 --- a/absl/testing/parameterized.py +++ b/absl/testing/parameterized.py @@ -170,6 +170,21 @@ test combinations: def testModuloResult(self, num, modulo, expected): self.assertEqual(expected, num % modulo) +This results in 6 test cases being created - one for each combination of the +parameters. It is also possible to supply sequences of keyword argument dicts +as elements of the cartesian product: + + @parameterized.product( + (dict(num=5, modulo=3, expected=2), + dict(num=7, modulo=4, expected=3)), + dtype=(int, float) + ) + def testModuloResult(self, num, modulo, expected): + self.assertEqual(expected, dtype(num) % modulo) + +This results in 4 test cases being created - for each of the two sets of test +data (supplied as kwarg dicts) and for each of the two data types (supplied as +a named parameter). Multiple keyword argument dicts may be supplied if required. Async Support =============================== @@ -464,15 +479,19 @@ def named_parameters(*testcases): return _parameter_decorator(_NAMED, testcases) -def product(**testgrid): +def product(*kwargs_seqs, **testgrid): """A decorator for running tests over cartesian product of parameters values. See the module docstring for a usage example. The test will be run for every possible combination of the parameters. Args: - **testgrid: A mapping of parameter names and their possible values. - Possible values should given as either a list or a tuple. + *kwargs_seqs: Each positional parameter is a sequence of keyword arg dicts; + every test case generated will include exactly one kwargs dict from each + positional parameter; these will then be merged to form an overall list + of arguments for the test case. + **testgrid: A mapping of parameter names and their possible values. Possible + values should given as either a list or a tuple. Raises: NoTestsError: Raised when the decorator generates no tests. @@ -482,24 +501,53 @@ def product(**testgrid): """ for name, values in testgrid.items(): - assert isinstance(values, list) or isinstance(values, tuple), ( + assert isinstance(values, (list, tuple)), ( 'Values of {} must be given as list or tuple, found {}'.format( name, type(values))) + prior_arg_names = set() + for kwargs_seq in kwargs_seqs: + assert ((isinstance(kwargs_seq, (list, tuple))) and + all(isinstance(kwargs, dict) for kwargs in kwargs_seq)), ( + 'Positional parameters must be a sequence of keyword arg' + 'dicts, found {}' + .format(kwargs_seq)) + if kwargs_seq: + arg_names = set(kwargs_seq[0]) + assert all(set(kwargs) == arg_names for kwargs in kwargs_seq), ( + 'Keyword argument dicts within a single parameter must all have the ' + 'same keys, found {}'.format(kwargs_seq)) + assert not (arg_names & prior_arg_names), ( + 'Keyword argument dict sequences must all have distinct argument ' + 'names, found duplicate(s) {}' + .format(sorted(arg_names & prior_arg_names))) + prior_arg_names |= arg_names + + assert not (prior_arg_names & set(testgrid)), ( + 'Arguments supplied in kwargs dicts in positional parameters must not ' + 'overlap with arguments supplied as named parameters; found duplicate ' + 'argument(s) {}'.format(sorted(prior_arg_names & set(testgrid)))) + + # Convert testgrid into a sequence of sequences of kwargs dicts and combine + # with the positional parameters. + # So foo=[1,2], bar=[3,4] --> [[{foo: 1}, {foo: 2}], [{bar: 3, bar: 4}]] + testgrid = (tuple({k: v} for v in vs) for k, vs in testgrid.items()) + testgrid = tuple(kwargs_seqs) + tuple(testgrid) + # Create all possible combinations of parameters as a cartesian product # of parameter values. testcases = [ - dict(zip(testgrid.keys(), product)) - for product in itertools.product(*testgrid.values()) + dict(itertools.chain.from_iterable(case.items() + for case in cases)) + for cases in itertools.product(*testgrid) ] - return _parameter_decorator(_ARGUMENT_REPR, testcases) class TestGeneratorMetaclass(type): """Metaclass for adding tests generated by parameterized decorators.""" - def __new__(mcs, class_name, bases, dct): + def __new__(cls, class_name, bases, dct): # NOTE: _test_params_repr is private to parameterized.TestCase and it's # metaclass; do not use it outside of those classes. test_params_reprs = dct.setdefault('_test_params_reprs', {}) @@ -544,7 +592,7 @@ class TestGeneratorMetaclass(type): # That's why it should only inherit it if it does not exist. test_params_reprs.setdefault(test_method, test_method_id) - return type.__new__(mcs, class_name, bases, dct) + return type.__new__(cls, class_name, bases, dct) def _update_class_dict_for_param_test_case( diff --git a/absl/testing/tests/parameterized_test.py b/absl/testing/tests/parameterized_test.py index b099cd3..fb618ea 100644 --- a/absl/testing/tests/parameterized_test.py +++ b/absl/testing/tests/parameterized_test.py @@ -494,7 +494,7 @@ class ParameterizedTestsTest(absltest.TestCase): ], short_descs) - def test_successful_product_test(self): + def test_successful_product_test_testgrid(self): class GoodProductTestCase(parameterized.TestCase): @@ -513,6 +513,75 @@ class ParameterizedTestsTest(absltest.TestCase): self.assertEqual(res.testsRun, 6) self.assertTrue(res.wasSuccessful()) + def test_successful_product_test_kwarg_seqs(self): + + class GoodProductTestCase(parameterized.TestCase): + + @parameterized.product((dict(num=0), dict(num=20), dict(num=0)), + (dict(modulo=2), dict(modulo=4)), + (dict(expected=0),)) + def testModuloResult(self, num, modulo, expected): + self.assertEqual(expected, num % modulo) + + ts = unittest.makeSuite(GoodProductTestCase) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(ts.countTestCases(), 6) + self.assertEqual(res.testsRun, 6) + self.assertTrue(res.wasSuccessful()) + + def test_successful_product_test_kwarg_seq_and_testgrid(self): + + class GoodProductTestCase(parameterized.TestCase): + + @parameterized.product((dict( + num=5, modulo=3, expected=2), dict(num=7, modulo=4, expected=3)), + dtype=(int, float)) + def testModuloResult(self, num, dtype, modulo, expected): + self.assertEqual(expected, dtype(num) % modulo) + + ts = unittest.makeSuite(GoodProductTestCase) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(ts.countTestCases(), 4) + self.assertEqual(res.testsRun, 4) + self.assertTrue(res.wasSuccessful()) + + def test_inconsistent_arg_names_in_kwargs_seq(self): + with self.assertRaisesRegex(AssertionError, 'must all have the same keys'): + + class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable + + @parameterized.product((dict(num=5, modulo=3), dict(num=7, modula=2)), + dtype=(int, float)) + def test_something(self): + pass # not called because argnames are not the same + + def test_duplicate_arg_names_in_kwargs_seqs(self): + with self.assertRaisesRegex(AssertionError, 'must all have distinct'): + + class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable + + @parameterized.product((dict(num=5, modulo=3), dict(num=7, modulo=4)), + (dict(foo='bar', num=5), dict(foo='baz', num=7)), + dtype=(int, float)) + def test_something(self): + pass # not called because `num` is specified twice + + def test_duplicate_arg_names_in_kwargs_seq_and_testgrid(self): + with self.assertRaisesRegex(AssertionError, 'duplicate argument'): + + class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable + + @parameterized.product( + (dict(num=5, modulo=3), dict(num=7, modulo=4)), + (dict(foo='bar'), dict(foo='baz')), + dtype=(int, float), + foo=('a', 'b'), + ) + def test_something(self): + pass # not called because `foo` is specified twice + def test_product_recorded_failures(self): class MixedProductTestCase(parameterized.TestCase): |