diff options
Diffstat (limited to 'absl/testing/parameterized.py')
-rw-r--r-- | absl/testing/parameterized.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/absl/testing/parameterized.py b/absl/testing/parameterized.py index c96d842..4d02c4e 100644 --- a/absl/testing/parameterized.py +++ b/absl/testing/parameterized.py @@ -209,6 +209,7 @@ from __future__ import division from __future__ import print_function import functools +import inspect import itertools import re import types @@ -218,10 +219,6 @@ from absl._collections_abc import abc from absl.testing import absltest import six -try: - from absl.testing import _parameterized_async -except (ImportError, SyntaxError): - _parameterized_async = None _ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>') _NAMED = object() @@ -265,6 +262,13 @@ def _format_parameter_list(testcase_params): return _format_parameter_list((testcase_params,)) +def _async_wrapped(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await func(*args, **kwargs) + return wrapper + + class _ParameterizedTestIter(object): """Callable and iterable class for producing new test cases.""" @@ -369,9 +373,8 @@ class _ParameterizedTestIter(object): bound_param_test.__name__, _format_parameter_list(testcase_params)) if test_method.__doc__: bound_param_test.__doc__ += '\n%s' % (test_method.__doc__,) - if (_parameterized_async and - _parameterized_async.iscoroutinefunction(test_method)): - return _parameterized_async.async_wrapped(bound_param_test) + if inspect.iscoroutinefunction(test_method): + return _async_wrapped(bound_param_test) return bound_param_test return (make_bound_param_test(c) for c in self.testcases) |