aboutsummaryrefslogtreecommitdiff
path: root/absl/testing
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2021-02-09 11:53:52 -0800
committerCopybara-Service <copybara-worker@google.com>2021-02-09 11:54:14 -0800
commit2444faee6d820c024c26380f60aa5e12fdf4b3a4 (patch)
tree38fded8ffb8a2c2fd93571be7f155f4f32ec84cd /absl/testing
parenta8d57c3317d375bedae2281e40e8d57bb944f393 (diff)
downloadabsl-py-2444faee6d820c024c26380f60aa5e12fdf4b3a4.tar.gz
More general parameterized.product, supporting products of dicts as well as products of single parameters.
PiperOrigin-RevId: 356556227 Change-Id: Ie476d2e971e75e6f5cce88d359e9512228003af7
Diffstat (limited to 'absl/testing')
-rw-r--r--absl/testing/parameterized.py66
-rw-r--r--absl/testing/tests/parameterized_test.py71
2 files changed, 127 insertions, 10 deletions
diff --git a/absl/testing/parameterized.py b/absl/testing/parameterized.py
index 894cf6c..570dbb6 100644
--- a/absl/testing/parameterized.py
+++ b/absl/testing/parameterized.py
@@ -170,6 +170,21 @@ test combinations:
def testModuloResult(self, num, modulo, expected):
self.assertEqual(expected, num % modulo)
+This results in 6 test cases being created - one for each combination of the
+parameters. It is also possible to supply sequences of keyword argument dicts
+as elements of the cartesian product:
+
+ @parameterized.product(
+ (dict(num=5, modulo=3, expected=2),
+ dict(num=7, modulo=4, expected=3)),
+ dtype=(int, float)
+ )
+ def testModuloResult(self, num, modulo, expected):
+ self.assertEqual(expected, dtype(num) % modulo)
+
+This results in 4 test cases being created - for each of the two sets of test
+data (supplied as kwarg dicts) and for each of the two data types (supplied as
+a named parameter). Multiple keyword argument dicts may be supplied if required.
Async Support
===============================
@@ -464,15 +479,19 @@ def named_parameters(*testcases):
return _parameter_decorator(_NAMED, testcases)
-def product(**testgrid):
+def product(*kwargs_seqs, **testgrid):
"""A decorator for running tests over cartesian product of parameters values.
See the module docstring for a usage example. The test will be run for every
possible combination of the parameters.
Args:
- **testgrid: A mapping of parameter names and their possible values.
- Possible values should given as either a list or a tuple.
+ *kwargs_seqs: Each positional parameter is a sequence of keyword arg dicts;
+ every test case generated will include exactly one kwargs dict from each
+ positional parameter; these will then be merged to form an overall list
+ of arguments for the test case.
+ **testgrid: A mapping of parameter names and their possible values. Possible
+ values should given as either a list or a tuple.
Raises:
NoTestsError: Raised when the decorator generates no tests.
@@ -482,24 +501,53 @@ def product(**testgrid):
"""
for name, values in testgrid.items():
- assert isinstance(values, list) or isinstance(values, tuple), (
+ assert isinstance(values, (list, tuple)), (
'Values of {} must be given as list or tuple, found {}'.format(
name, type(values)))
+ prior_arg_names = set()
+ for kwargs_seq in kwargs_seqs:
+ assert ((isinstance(kwargs_seq, (list, tuple))) and
+ all(isinstance(kwargs, dict) for kwargs in kwargs_seq)), (
+ 'Positional parameters must be a sequence of keyword arg'
+ 'dicts, found {}'
+ .format(kwargs_seq))
+ if kwargs_seq:
+ arg_names = set(kwargs_seq[0])
+ assert all(set(kwargs) == arg_names for kwargs in kwargs_seq), (
+ 'Keyword argument dicts within a single parameter must all have the '
+ 'same keys, found {}'.format(kwargs_seq))
+ assert not (arg_names & prior_arg_names), (
+ 'Keyword argument dict sequences must all have distinct argument '
+ 'names, found duplicate(s) {}'
+ .format(sorted(arg_names & prior_arg_names)))
+ prior_arg_names |= arg_names
+
+ assert not (prior_arg_names & set(testgrid)), (
+ 'Arguments supplied in kwargs dicts in positional parameters must not '
+ 'overlap with arguments supplied as named parameters; found duplicate '
+ 'argument(s) {}'.format(sorted(prior_arg_names & set(testgrid))))
+
+ # Convert testgrid into a sequence of sequences of kwargs dicts and combine
+ # with the positional parameters.
+ # So foo=[1,2], bar=[3,4] --> [[{foo: 1}, {foo: 2}], [{bar: 3, bar: 4}]]
+ testgrid = (tuple({k: v} for v in vs) for k, vs in testgrid.items())
+ testgrid = tuple(kwargs_seqs) + tuple(testgrid)
+
# Create all possible combinations of parameters as a cartesian product
# of parameter values.
testcases = [
- dict(zip(testgrid.keys(), product))
- for product in itertools.product(*testgrid.values())
+ dict(itertools.chain.from_iterable(case.items()
+ for case in cases))
+ for cases in itertools.product(*testgrid)
]
-
return _parameter_decorator(_ARGUMENT_REPR, testcases)
class TestGeneratorMetaclass(type):
"""Metaclass for adding tests generated by parameterized decorators."""
- def __new__(mcs, class_name, bases, dct):
+ def __new__(cls, class_name, bases, dct):
# NOTE: _test_params_repr is private to parameterized.TestCase and it's
# metaclass; do not use it outside of those classes.
test_params_reprs = dct.setdefault('_test_params_reprs', {})
@@ -544,7 +592,7 @@ class TestGeneratorMetaclass(type):
# That's why it should only inherit it if it does not exist.
test_params_reprs.setdefault(test_method, test_method_id)
- return type.__new__(mcs, class_name, bases, dct)
+ return type.__new__(cls, class_name, bases, dct)
def _update_class_dict_for_param_test_case(
diff --git a/absl/testing/tests/parameterized_test.py b/absl/testing/tests/parameterized_test.py
index b099cd3..fb618ea 100644
--- a/absl/testing/tests/parameterized_test.py
+++ b/absl/testing/tests/parameterized_test.py
@@ -494,7 +494,7 @@ class ParameterizedTestsTest(absltest.TestCase):
],
short_descs)
- def test_successful_product_test(self):
+ def test_successful_product_test_testgrid(self):
class GoodProductTestCase(parameterized.TestCase):
@@ -513,6 +513,75 @@ class ParameterizedTestsTest(absltest.TestCase):
self.assertEqual(res.testsRun, 6)
self.assertTrue(res.wasSuccessful())
+ def test_successful_product_test_kwarg_seqs(self):
+
+ class GoodProductTestCase(parameterized.TestCase):
+
+ @parameterized.product((dict(num=0), dict(num=20), dict(num=0)),
+ (dict(modulo=2), dict(modulo=4)),
+ (dict(expected=0),))
+ def testModuloResult(self, num, modulo, expected):
+ self.assertEqual(expected, num % modulo)
+
+ ts = unittest.makeSuite(GoodProductTestCase)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(ts.countTestCases(), 6)
+ self.assertEqual(res.testsRun, 6)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_successful_product_test_kwarg_seq_and_testgrid(self):
+
+ class GoodProductTestCase(parameterized.TestCase):
+
+ @parameterized.product((dict(
+ num=5, modulo=3, expected=2), dict(num=7, modulo=4, expected=3)),
+ dtype=(int, float))
+ def testModuloResult(self, num, dtype, modulo, expected):
+ self.assertEqual(expected, dtype(num) % modulo)
+
+ ts = unittest.makeSuite(GoodProductTestCase)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(ts.countTestCases(), 4)
+ self.assertEqual(res.testsRun, 4)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_inconsistent_arg_names_in_kwargs_seq(self):
+ with self.assertRaisesRegex(AssertionError, 'must all have the same keys'):
+
+ class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable
+
+ @parameterized.product((dict(num=5, modulo=3), dict(num=7, modula=2)),
+ dtype=(int, float))
+ def test_something(self):
+ pass # not called because argnames are not the same
+
+ def test_duplicate_arg_names_in_kwargs_seqs(self):
+ with self.assertRaisesRegex(AssertionError, 'must all have distinct'):
+
+ class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable
+
+ @parameterized.product((dict(num=5, modulo=3), dict(num=7, modulo=4)),
+ (dict(foo='bar', num=5), dict(foo='baz', num=7)),
+ dtype=(int, float))
+ def test_something(self):
+ pass # not called because `num` is specified twice
+
+ def test_duplicate_arg_names_in_kwargs_seq_and_testgrid(self):
+ with self.assertRaisesRegex(AssertionError, 'duplicate argument'):
+
+ class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable
+
+ @parameterized.product(
+ (dict(num=5, modulo=3), dict(num=7, modulo=4)),
+ (dict(foo='bar'), dict(foo='baz')),
+ dtype=(int, float),
+ foo=('a', 'b'),
+ )
+ def test_something(self):
+ pass # not called because `foo` is specified twice
+
def test_product_recorded_failures(self):
class MixedProductTestCase(parameterized.TestCase):