aboutsummaryrefslogtreecommitdiff
path: root/absl/testing/parameterized.py
diff options
context:
space:
mode:
Diffstat (limited to 'absl/testing/parameterized.py')
-rw-r--r--absl/testing/parameterized.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/absl/testing/parameterized.py b/absl/testing/parameterized.py
index c96d842..4d02c4e 100644
--- a/absl/testing/parameterized.py
+++ b/absl/testing/parameterized.py
@@ -209,6 +209,7 @@ from __future__ import division
from __future__ import print_function
import functools
+import inspect
import itertools
import re
import types
@@ -218,10 +219,6 @@ from absl._collections_abc import abc
from absl.testing import absltest
import six
-try:
- from absl.testing import _parameterized_async
-except (ImportError, SyntaxError):
- _parameterized_async = None
_ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>')
_NAMED = object()
@@ -265,6 +262,13 @@ def _format_parameter_list(testcase_params):
return _format_parameter_list((testcase_params,))
+def _async_wrapped(func):
+ @functools.wraps(func)
+ async def wrapper(*args, **kwargs):
+ return await func(*args, **kwargs)
+ return wrapper
+
+
class _ParameterizedTestIter(object):
"""Callable and iterable class for producing new test cases."""
@@ -369,9 +373,8 @@ class _ParameterizedTestIter(object):
bound_param_test.__name__, _format_parameter_list(testcase_params))
if test_method.__doc__:
bound_param_test.__doc__ += '\n%s' % (test_method.__doc__,)
- if (_parameterized_async and
- _parameterized_async.iscoroutinefunction(test_method)):
- return _parameterized_async.async_wrapped(bound_param_test)
+ if inspect.iscoroutinefunction(test_method):
+ return _async_wrapped(bound_param_test)
return bound_param_test
return (make_bound_param_test(c) for c in self.testcases)