aboutsummaryrefslogtreecommitdiff
path: root/absl/testing
diff options
context:
space:
mode:
Diffstat (limited to 'absl/testing')
-rwxr-xr-xabsl/testing/parameterized.py115
-rwxr-xr-xabsl/testing/tests/parameterized_test.py134
-rwxr-xr-xabsl/testing/tests/xml_reporter_test.py5
3 files changed, 112 insertions, 142 deletions
diff --git a/absl/testing/parameterized.py b/absl/testing/parameterized.py
index 9824ed3..7d39fc7 100755
--- a/absl/testing/parameterized.py
+++ b/absl/testing/parameterized.py
@@ -45,9 +45,7 @@ or dictionaries (with named parameters):
self.assertEqual(result, op1 + op2)
If a parameterized test fails, the error message will show the
-original test name (which is modified internally) and the arguments
-for the specific invocation, which are part of the string returned by
-the shortDescription() method on test cases.
+original test name and the parameters for that test.
The id method of the test, used internally by the unittest framework, is also
modified to show the arguments (but note that the name reported by `id()`
@@ -179,7 +177,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
import functools
import re
import types
@@ -277,7 +274,6 @@ class _ParameterizedTestIter(object):
def __iter__(self):
test_method = self._test_method
naming_type = self._naming_type
- extra_ids = collections.defaultdict(int)
def make_bound_param_test(testcase_params):
@functools.wraps(test_method)
@@ -331,13 +327,9 @@ class _ParameterizedTestIter(object):
# _ARGUMENT_REPR tests using an indexed suffix.
# To keep test names descriptive, only the original method name is used.
# To make sure test names are unique, we add a unique descriptive suffix
- # __x_extra_id__ for every test.
- extra_id = '(%s)' % (_format_parameter_list(testcase_params),)
- extra_ids[extra_id] += 1
- while extra_ids[extra_id] > 1:
- extra_id = '%s (%d)' % (extra_id, extra_ids[extra_id])
- extra_ids[extra_id] += 1
- bound_param_test.__x_extra_id__ = extra_id
+ # __x_params_repr__ for every test.
+ params_repr = '(%s)' % (_format_parameter_list(testcase_params),)
+ bound_param_test.__x_params_repr__ = params_repr
else:
raise RuntimeError('%s is not a valid naming type.' % (naming_type,))
@@ -354,18 +346,20 @@ class _ParameterizedTestIter(object):
def _modify_class(class_object, testcases, naming_type):
- assert not getattr(class_object, '_test_method_ids', None), (
+ assert not getattr(class_object, '_test_params_reprs', None), (
'Cannot add parameters to %s. Either it already has parameterized '
'methods, or its super class is also a parameterized class.' % (
class_object,))
- class_object._test_method_ids = test_method_ids = {}
+ # NOTE: _test_params_repr is private to parameterized.TestCase and it's
+ # metaclass; do not use it outside of those classes.
+ class_object._test_params_reprs = test_params_reprs = {}
for name, obj in six.iteritems(class_object.__dict__.copy()):
if (name.startswith(unittest.TestLoader.testMethodPrefix)
and isinstance(obj, types.FunctionType)):
delattr(class_object, name)
methods = {}
_update_class_dict_for_param_test_case(
- class_object.__name__, methods, test_method_ids, name,
+ class_object.__name__, methods, test_params_reprs, name,
_ParameterizedTestIter(obj, testcases, naming_type, name))
for meth_name, meth in six.iteritems(methods):
setattr(class_object, meth_name, meth)
@@ -455,69 +449,81 @@ def named_parameters(*testcases):
class TestGeneratorMetaclass(type):
- """Metaclass for test cases with test generators.
-
- A test generator is an iterable in a testcase that produces callables. These
- callables must be single-argument methods. These methods are injected into
- the class namespace and the original iterable is removed. If the name of the
- iterable conforms to the test pattern, the injected methods will be picked
- up as tests by the unittest framework.
-
- In general, it is supposed to be used in conjuction with the
- parameters decorator.
- """
+ """Metaclass for adding tests generated by parameterized decorators."""
def __new__(mcs, class_name, bases, dct):
- test_method_ids = dct.setdefault('_test_method_ids', {})
+ # NOTE: _test_params_repr is private to parameterized.TestCase and it's
+ # metaclass; do not use it outside of those classes.
+ test_params_reprs = dct.setdefault('_test_params_reprs', {})
for name, obj in six.iteritems(dct.copy()):
if (name.startswith(unittest.TestLoader.testMethodPrefix) and
_non_string_or_bytes_iterable(obj)):
+ # NOTE: `obj` might not be a _ParameterizedTestIter in two cases:
+ # 1. a class-level iterable named test* that isn't a test, such as
+ # a list of something. Such attributes get deleted from the class.
+ #
+ # 2. If a decorator is applied to the parameterized test, e.g.
+ # @morestuff
+ # @parameterized.parameters(...)
+ # def test_foo(...): ...
+ #
+ # This is OK so long as the underlying parameterized function state
+ # is forwarded (e.g. using functool.wraps() and **without**
+ # accessing explicitly accessing the internal attributes.
if isinstance(obj, _ParameterizedTestIter):
# Update the original test method name so it's more accurate.
# The mismatch might happen when another decorator is used inside
# the parameterized decrators, and the inner decorator doesn't
# preserve its __name__.
- # `obj` might be a generator, not _ParameterizedTestIter.
obj._original_name = name
iterator = iter(obj)
dct.pop(name)
_update_class_dict_for_param_test_case(
- class_name, dct, test_method_ids, name, iterator)
+ class_name, dct, test_params_reprs, name, iterator)
# If the base class is a subclass of parameterized.TestCase, inherit its
- # _test_method_ids too.
+ # _test_params_reprs too.
for base in bases:
- # Check if the base has _test_method_ids first, then check if it's a
+ # Check if the base has _test_params_reprs first, then check if it's a
# subclass of parameterized.TestCase. Otherwise when this is called for
# the parameterized.TestCase definition itself, this raises because
# itself is not defined yet. This works as long as absltest.TestCase does
- # not define _test_method_ids.
- if getattr(base, '_test_method_ids', None) and issubclass(base, TestCase):
- for test_method, test_method_id in six.iteritems(base._test_method_ids):
+ # not define _test_params_reprs.
+ base_test_params_reprs = getattr(base, '_test_params_reprs', None)
+ if base_test_params_reprs and issubclass(base, TestCase):
+ for test_method, test_method_id in base_test_params_reprs.items():
# test_method may both exists in base and this class.
# This class's method overrides base class's.
# That's why it should only inherit it if it does not exist.
- test_method_ids.setdefault(test_method, test_method_id)
+ test_params_reprs.setdefault(test_method, test_method_id)
return type.__new__(mcs, class_name, bases, dct)
def _update_class_dict_for_param_test_case(
- test_class_name, dct, test_method_ids, name, iterator):
+ test_class_name, dct, test_params_reprs, name, iterator):
"""Adds individual test cases to a dictionary.
Args:
test_class_name: The name of the class tests are added to.
dct: The target dictionary.
- test_method_ids: The dictionary for mapping names to test IDs.
+ test_params_reprs: The dictionary for mapping names to test IDs.
name: The original name of the test case.
iterator: The iterator generating the individual test cases.
Raises:
DuplicateTestNameError: Raised when a test name occurs multiple times.
+ RuntimeError: If non-parameterized functions are generated.
"""
for idx, func in enumerate(iterator):
assert callable(func), 'Test generators must yield callables, got %r' % (
func,)
+ if not (getattr(func, '__x_use_name__', None) or
+ getattr(func, '__x_params_repr__', None)):
+ raise RuntimeError(
+ '{}.{} generated a test function without using the parameterized '
+ 'decorators. Only tests generated using the decorators are '
+ 'supported.'.format(test_class_name, name))
+
if getattr(func, '__x_use_name__', False):
original_name = func.__name__
new_name = original_name
@@ -529,24 +535,22 @@ def _update_class_dict_for_param_test_case(
raise DuplicateTestNameError(test_class_name, new_name, original_name)
dct[new_name] = func
- test_method_id = original_name + getattr(func, '__x_extra_id__', '')
- assert test_method_id not in test_method_ids.values(), (
- 'Id of parameterized test case "%s" not unique' % (test_method_id,))
- test_method_ids[new_name] = test_method_id
+ test_params_reprs[new_name] = getattr(func, '__x_params_repr__', '')
class TestCase(six.with_metaclass(TestGeneratorMetaclass, absltest.TestCase)):
"""Base class for test cases using the parameters decorator."""
- def shortDescription(self):
- base = super(TestCase, self).shortDescription().split('\n')
- # Replace the test id line with ours that has the params info
- base[0] = str(self)
- return '\n'.join(base)
+ # visibility: private; do not call outside this class.
+ def _get_params_repr(self):
+ return self._test_params_reprs.get(self._testMethodName, '')
def __str__(self):
- return '%s (%s)' % (
- self._test_method_ids.get(self._testMethodName, self._testMethodName),
+ params_repr = self._get_params_repr()
+ if params_repr:
+ params_repr = ' ' + params_repr
+ return '{}{} ({})'.format(
+ self._testMethodName, params_repr,
unittest.util.strclass(self.__class__))
def id(self):
@@ -558,11 +562,16 @@ class TestCase(six.with_metaclass(TestGeneratorMetaclass, absltest.TestCase)):
Returns:
The test id.
"""
- return '%s.%s' % (
- unittest.util.strclass(self.__class__),
- # When a test method is NOT decorated, it doesn't exist in
- # _test_method_ids. Use the _testMethodName directly.
- self._test_method_ids.get(self._testMethodName, self._testMethodName))
+ base = super(TestCase, self).id()
+ params_repr = self._get_params_repr()
+ if params_repr:
+ # We include the params in the id so that, when reported in the
+ # test.xml file, the value is more informative than just "test_foo0".
+ # Use a space to separate them so that it's copy/paste friendly and
+ # easy to identify the actual test id.
+ return '{} {}'.format(base, params_repr)
+ else:
+ return base
# This function is kept CamelCase because it's used as a class's base class.
diff --git a/absl/testing/tests/parameterized_test.py b/absl/testing/tests/parameterized_test.py
index 85f8e9b..5f2dcc0 100755
--- a/absl/testing/tests/parameterized_test.py
+++ b/absl/testing/tests/parameterized_test.py
@@ -336,23 +336,10 @@ class ParameterizedTestsTest(absltest.TestCase):
class UniqueDescriptiveNamesTest(parameterized.TestCase):
- class JustBeingMean(object):
-
- def __repr__(self):
- return '13) (2'
-
@parameterized.parameters(13, 13)
def test_normal(self, number):
del number
- @parameterized.parameters(13, 13, JustBeingMean())
- def test_double_conflict(self, number):
- del number
-
- @parameterized.parameters(13, JustBeingMean(), 13, 13)
- def test_triple_conflict(self, number):
- del number
-
class MultiGeneratorsTestCase(parameterized.TestCase):
@parameterized.parameters((i for i in (1, 2, 3)), (i for i in (3, 2, 1)))
@@ -442,11 +429,13 @@ class ParameterizedTestsTest(absltest.TestCase):
def test_short_description(self):
ts = unittest.makeSuite(self.GoodAdditionParams)
- short_desc = list(ts)[0].shortDescription().split('\n')
- self.assertEqual(
- 'test_addition(1, 2, 3)', short_desc[1])
- self.assertTrue(
- short_desc[0].startswith('test_addition(1, 2, 3)'))
+ short_desc = list(ts)[0].shortDescription()
+
+ location = unittest.util.strclass(self.GoodAdditionParams).replace(
+ '__main__.', '')
+ expected = ('{}.test_addition0 (1, 2, 3)\n'.format(location) +
+ 'test_addition(1, 2, 3)')
+ self.assertEqual(expected, short_desc)
def test_short_description_addresses_removed(self):
ts = unittest.makeSuite(self.ArgumentsWithAddresses)
@@ -461,14 +450,22 @@ class ParameterizedTestsTest(absltest.TestCase):
ts = unittest.makeSuite(self.ArgumentsWithAddresses)
self.assertEqual(
(unittest.util.strclass(self.ArgumentsWithAddresses) +
- '.test_something(<object>)'),
+ '.test_something0 (<object>)'),
list(ts)[0].id())
ts = unittest.makeSuite(self.GoodAdditionParams)
self.assertEqual(
(unittest.util.strclass(self.GoodAdditionParams) +
- '.test_addition(1, 2, 3)'),
+ '.test_addition0 (1, 2, 3)'),
list(ts)[0].id())
+ def test_str(self):
+ ts = unittest.makeSuite(self.GoodAdditionParams)
+ test = list(ts)[0]
+
+ expected = 'test_addition0 (1, 2, 3) ({})'.format(
+ unittest.util.strclass(self.GoodAdditionParams))
+ self.assertEqual(expected, str(test))
+
def test_dict_parameters(self):
ts = unittest.makeSuite(self.DictionaryArguments)
res = unittest.TestResult()
@@ -481,24 +478,21 @@ class ParameterizedTestsTest(absltest.TestCase):
self.assertEqual(4, ts.countTestCases())
short_descs = [x.shortDescription() for x in list(ts)]
full_class_name = unittest.util.strclass(self.NoParameterizedTests)
+ full_class_name = full_class_name.replace('__main__.', '')
self.assertSameElements(
[
- 'testGenerator (%s)' % (full_class_name,),
- 'test_generator (%s)' % (full_class_name,),
- 'testNormal (%s)' % (full_class_name,),
- 'test_normal (%s)' % (full_class_name,),
+ '{}.testGenerator'.format(full_class_name),
+ '{}.test_generator'.format(full_class_name),
+ '{}.testNormal'.format(full_class_name),
+ '{}.test_normal'.format(full_class_name),
],
short_descs)
- def test_generator_tests(self):
- with self.assertRaises(AssertionError):
-
- # This fails because the generated test methods share the same test id.
- class GeneratorTests(parameterized.TestCase):
+ def test_generator_tests_disallowed(self):
+ with self.assertRaisesRegex(RuntimeError, 'generated.*without'):
+ class GeneratorTests(parameterized.TestCase): # pylint: disable=unused-variable
test_generator_method = (lambda self: None for _ in range(10))
- del GeneratorTests
-
def test_named_parameters_run(self):
ts = unittest.makeSuite(self.NamedTests)
self.assertEqual(9, ts.countTestCases())
@@ -576,47 +570,20 @@ class ParameterizedTestsTest(absltest.TestCase):
def test_named_parameters_short_description(self):
ts = sorted(unittest.makeSuite(self.NamedTests),
key=lambda t: t.id())
- short_desc = ts[0].shortDescription().split('\n')
- self.assertEqual(
- 'test_dict_single_interesting(case=0)', short_desc[1])
- self.assertTrue(
- short_desc[0].startswith('test_dict_single_interesting'))
-
- short_desc = ts[1].shortDescription().split('\n')
- self.assertEqual(
- 'test_dict_something_boring(case=1)', short_desc[1])
- self.assertTrue(
- short_desc[0].startswith('test_dict_something_boring'))
-
- short_desc = ts[2].shortDescription().split('\n')
- self.assertEqual(
- 'test_dict_something_interesting(case=0)', short_desc[1])
- self.assertTrue(
- short_desc[0].startswith('test_dict_something_interesting'))
-
- short_desc = ts[3].shortDescription().split('\n')
- self.assertEqual(
- 'test_mixed_something_boring(1)', short_desc[1])
- self.assertTrue(
- short_desc[0].startswith('test_mixed_something_boring'))
-
- short_desc = ts[4].shortDescription().split('\n')
- self.assertEqual(
- 'test_mixed_something_interesting(case=0)', short_desc[1])
- self.assertTrue(
- short_desc[0].startswith('test_mixed_something_interesting'))
-
- short_desc = ts[6].shortDescription().split('\n')
- self.assertEqual(
- 'test_something_boring(1)', short_desc[1])
- self.assertTrue(
- short_desc[0].startswith('test_something_boring'))
-
- short_desc = ts[7].shortDescription().split('\n')
- self.assertEqual(
- 'test_something_interesting(0)', short_desc[1])
- self.assertTrue(
- short_desc[0].startswith('test_something_interesting'))
+ actual = {t._testMethodName: t.shortDescription() for t in ts}
+ expected = {
+ 'test_dict_single_interesting': 'case=0',
+ 'test_dict_something_boring': 'case=1',
+ 'test_dict_something_interesting': 'case=0',
+ 'test_mixed_something_boring': '1',
+ 'test_mixed_something_interesting': 'case=0',
+ 'test_something_boring': '1',
+ 'test_something_interesting': '0',
+ }
+ for test_name, param_repr in expected.items():
+ short_desc = actual[test_name].split('\n')
+ self.assertIn(test_name, short_desc[0])
+ self.assertEqual('{}({})'.format(test_name, param_repr), short_desc[1])
def test_load_tuple_named_test(self):
loader = unittest.TestLoader()
@@ -861,19 +828,12 @@ class ParameterizedTestsTest(absltest.TestCase):
res = RecordSuccessTestsResult()
ts.run(res)
self.assertTrue(res.wasSuccessful())
- self.assertEqual(9, res.testsRun)
+ self.assertEqual(2, res.testsRun)
test_ids = [test.id() for test in res.successful_tests]
full_class_name = unittest.util.strclass(self.UniqueDescriptiveNamesTest)
expected_test_ids = [
- full_class_name + '.test_normal(13)',
- full_class_name + '.test_normal(13) (2)',
- full_class_name + '.test_double_conflict(13)',
- full_class_name + '.test_double_conflict(13) (2)',
- full_class_name + '.test_double_conflict(13) (2) (2)',
- full_class_name + '.test_triple_conflict(13)',
- full_class_name + '.test_triple_conflict(13) (2)',
- full_class_name + '.test_triple_conflict(13) (2) (2)',
- full_class_name + '.test_triple_conflict(13) (3)',
+ full_class_name + '.test_normal0 (13)',
+ full_class_name + '.test_normal1 (13)',
]
self.assertTrue(test_ids)
self.assertItemsEqual(expected_test_ids, test_ids)
@@ -892,13 +852,13 @@ class ParameterizedTestsTest(absltest.TestCase):
self.assertEqual(8, res.testsRun)
self.assertTrue(res.wasSuccessful(), msg=str(res.failures))
- def test_subclass_inherits_superclass_test_method_ids(self):
+ def test_subclass_inherits_superclass_test_params_reprs(self):
self.assertEqual(
- {'test_name0': "test_name('foo')", 'test_name1': "test_name('bar')"},
- self.SuperclassTestCase._test_method_ids)
+ {'test_name0': "('foo')", 'test_name1': "('bar')"},
+ self.SuperclassTestCase._test_params_reprs)
self.assertEqual(
- {'test_name0': "test_name('foo')", 'test_name1': "test_name('bar')"},
- self.SubclassTestCase._test_method_ids)
+ {'test_name0': "('foo')", 'test_name1': "('bar')"},
+ self.SubclassTestCase._test_params_reprs)
def _decorate_with_side_effects(func, self):
diff --git a/absl/testing/tests/xml_reporter_test.py b/absl/testing/tests/xml_reporter_test.py
index f4a6c2d..a7c3f9d 100755
--- a/absl/testing/tests/xml_reporter_test.py
+++ b/absl/testing/tests/xml_reporter_test.py
@@ -135,7 +135,8 @@ class TextAndXMLTestResultTest(absltest.TestCase):
'foo', 0, timer)
def _assert_match(self, regex, output):
- self.assertRegex(output, regex)
+ fail_msg = 'Expected regex:\n{}\nTo match:\n{}'.format(regex, output)
+ self.assertRegex(output, regex, fail_msg)
def _assert_valid_xml(self, xml_output):
try:
@@ -786,7 +787,7 @@ class TextAndXMLTestResultTest(absltest.TestCase):
'errors': 0,
'run_time': run_time,
'start_time': re.escape(self._iso_timestamp(start_time),),
- 'test_name': re.escape('test_prefix(&apos;a&#x20;(b.c)&apos;)'),
+ 'test_name': re.escape('test_prefix0&#x20;(&apos;a&#x20;(b.c)&apos;)'),
'classname': classname,
'status': 'run',
'result': 'completed',