aboutsummaryrefslogtreecommitdiff
path: root/absl
diff options
context:
space:
mode:
authorStephen Thorne <sthorne@google.com>2021-04-12 03:07:44 -0700
committerCopybara-Service <copybara-worker@google.com>2021-04-12 03:08:07 -0700
commitec7892ca08da797885db0982e79b9cceffa7684f (patch)
treef0690524abebaeefabe832e52dff2546e730fd6c /absl
parent1c62d88d4fb1ad903f298190dfb52f364d546ec2 (diff)
downloadabsl-py-ec7892ca08da797885db0982e79b9cceffa7684f.tar.gz
Add `@absltest.skipThisClass` to skip specific classes during testing.
This decorator marks a test in a way that it will be skipped, but none of its subclasses are. Suggested usage is for where you want to share functionality between tests, by having an 'abstract' base class: ``` @absltest.skipThisClass class _BaseTestCase(absltest.TestCase): def test_foo(self): self.assertEqual(self.object_under_test.method() class FooTest(_BaseTestCase): def setUp(self): self.object_under_test = Foo() class BarTest(_BaseTestCase): def setUp(self): self.object_under_test = Bar() ``` There are alternatives, but they have drawbacks: * Having `_BaseTestCase` subclass object, and `FooTest` multiple-inherit from both `absltest.TestCase` and `_BaseTestCase`. However, this ends up being problematic for type checking. * Repeating the same logic in `absltest.skipThisClass` within `setUpClass` for every class to skip. However, that is repetitive logic that is best put into a utility function. While `skipThisClass` is similar to `@unittest.skip`, it has an important distinction: regular `skip` will skip the decorated class and all subclasses; `skipThisClass` only skips the decorated class, allowing base classes to be correctly skipped while sub-classes are run as tests. PiperOrigin-RevId: 367965048 Change-Id: Ie050c43c7f2e5dbc5af731171259c27084d1ba12
Diffstat (limited to 'absl')
-rw-r--r--absl/CHANGELOG.md3
-rw-r--r--absl/testing/absltest.py75
-rw-r--r--absl/testing/tests/absltest_test.py196
3 files changed, 274 insertions, 0 deletions
diff --git a/absl/CHANGELOG.md b/absl/CHANGELOG.md
index 7a1de2f..b5cfd28 100644
--- a/absl/CHANGELOG.md
+++ b/absl/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
### Added
* (app) Type annotations for public `app` interfaces.
+* (testing) Added new decorator `@absltest.skipThisClass` to indicate a class
+ contains shared functionality to be used as a base class for other
+ TestCases, and therefore should be skipped.
### Changed
diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py
index fbce512..3bf77d9 100644
--- a/absl/testing/absltest.py
+++ b/absl/testing/absltest.py
@@ -2142,6 +2142,81 @@ def _is_suspicious_attribute(testCaseClass, name):
return False
+def skipThisClass(reason):
+ # type: (Text) -> Callable[[_T], _T]
+ """Skip tests in the decorated TestCase, but not any of its subclasses.
+
+ This decorator indicates that this class should skip all its tests, but not
+ any of its subclasses. Useful for if you want to share testMethod or setUp
+ implementations between a number of concrete testcase classes.
+
+ Example usage, showing how you can share some common test methods between
+ subclasses. In this example, only 'BaseTest' will be marked as skipped, and
+ not RealTest or SecondRealTest:
+
+ @absltest.skipThisClass("Shared functionality")
+ class BaseTest(absltest.TestCase):
+ def test_simple_functionality(self):
+ self.assertEqual(self.system_under_test.method(), 1)
+
+ class RealTest(BaseTest):
+ def setUp(self):
+ super().setUp()
+ self.system_under_test = MakeSystem(argument)
+
+ def test_specific_behavior(self):
+ ...
+
+ class SecondRealTest(BaseTest):
+ def setUp(self):
+ super().setUp()
+ self.system_under_test = MakeSystem(other_arguments)
+
+ def test_other_behavior(self):
+ ...
+
+ Args:
+ reason: The reason we have a skip in place. For instance: 'shared test
+ methods' or 'shared assertion methods'.
+
+ Returns:
+ Decorator function that will cause a class to be skipped.
+ """
+ if isinstance(reason, type):
+ raise TypeError('Got {!r}, expected reason as string'.format(reason))
+
+ def _skip_class(test_case_class):
+ if not issubclass(test_case_class, unittest.TestCase):
+ raise TypeError(
+ 'Decorating {!r}, expected TestCase subclass'.format(test_case_class))
+
+ # Only shadow the setUpClass method if it is directly defined. If it is
+ # in the parent class we invoke it via a super() call instead of holding
+ # a reference to it.
+ shadowed_setupclass = test_case_class.__dict__.get('setUpClass', None)
+
+ @classmethod
+ def replacement_setupclass(cls, *args, **kwargs):
+ # Skip this class if it is the one that was decorated with @skipThisClass
+ if cls is test_case_class:
+ raise SkipTest(reason)
+ if shadowed_setupclass:
+ # Pass along `cls` so the MRO chain doesn't break.
+ # The original method is a `classmethod` descriptor, which can't
+ # be directly called, but `__func__` has the underlying function.
+ return shadowed_setupclass.__func__(cls, *args, **kwargs)
+ else:
+ # Because there's no setUpClass() defined directly on test_case_class,
+ # we call super() ourselves to continue execution of the inheritance
+ # chain.
+ return super(test_case_class, cls).setUpClass(*args, **kwargs)
+
+ test_case_class.setUpClass = replacement_setupclass
+ return test_case_class
+
+ return _skip_class
+
+
class TestLoader(unittest.TestLoader):
"""A test loader which supports common test features.
diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py
index f7abd39..0ac3009 100644
--- a/absl/testing/tests/absltest_test.py
+++ b/absl/testing/tests/absltest_test.py
@@ -28,6 +28,7 @@ import string
import subprocess
import sys
import tempfile
+import unittest
from absl.testing import _bazelize_command
from absl.testing import absltest
@@ -2191,6 +2192,201 @@ class TempFileTest(absltest.TestCase, HelperMixin):
self.run_tempfile_helper('OFF', expected)
+class SkipClassTest(absltest.TestCase):
+
+ def test_incorrect_decorator_call(self):
+ with self.assertRaises(TypeError):
+
+ @absltest.skipThisClass # pylint: disable=unused-variable
+ class Test(absltest.TestCase):
+ pass
+
+ def test_incorrect_decorator_subclass(self):
+ with self.assertRaises(TypeError):
+
+ @absltest.skipThisClass('reason')
+ def test_method(): # pylint: disable=unused-variable
+ pass
+
+ def test_correct_decorator_class(self):
+
+ @absltest.skipThisClass('reason')
+ class Test(absltest.TestCase):
+ pass
+
+ with self.assertRaises(absltest.SkipTest):
+ Test.setUpClass()
+
+ def test_correct_decorator_subclass(self):
+
+ @absltest.skipThisClass('reason')
+ class Test(absltest.TestCase):
+ pass
+
+ class Subclass(Test):
+ pass
+
+ with self.subTest('Base class should be skipped'):
+ with self.assertRaises(absltest.SkipTest):
+ Test.setUpClass()
+
+ with self.subTest('Subclass should not be skipped'):
+ Subclass.setUpClass() # should not raise.
+
+ def test_setup(self):
+
+ @absltest.skipThisClass('reason')
+ class Test(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(Test, cls).setUpClass()
+ cls.foo = 1
+
+ class Subclass(Test):
+ pass
+
+ Subclass.setUpClass()
+ self.assertEqual(Subclass.foo, 1)
+
+ def test_setup_chain(self):
+
+ @absltest.skipThisClass('reason')
+ class BaseTest(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(BaseTest, cls).setUpClass()
+ cls.foo = 1
+
+ @absltest.skipThisClass('reason')
+ class SecondBaseTest(BaseTest):
+
+ @classmethod
+ def setUpClass(cls):
+ super(SecondBaseTest, cls).setUpClass()
+ cls.bar = 2
+
+ class Subclass(SecondBaseTest):
+ pass
+
+ Subclass.setUpClass()
+ self.assertEqual(Subclass.foo, 1)
+ self.assertEqual(Subclass.bar, 2)
+
+ def test_setup_args(self):
+
+ @absltest.skipThisClass('reason')
+ class Test(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls, foo, bar=None):
+ super(Test, cls).setUpClass()
+ cls.foo = foo
+ cls.bar = bar
+
+ class Subclass(Test):
+
+ @classmethod
+ def setUpClass(cls):
+ super(Subclass, cls).setUpClass('foo', bar='baz')
+
+ Subclass.setUpClass()
+ self.assertEqual(Subclass.foo, 'foo')
+ self.assertEqual(Subclass.bar, 'baz')
+
+ def test_setup_multiple_inheritance(self):
+
+ # Test that skipping this class doesn't break the MRO chain and stop
+ # RequiredBase.setUpClass from running.
+ @absltest.skipThisClass('reason')
+ class Left(absltest.TestCase):
+ pass
+
+ class RequiredBase(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(RequiredBase, cls).setUpClass()
+ cls.foo = 'foo'
+
+ class Right(RequiredBase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(Right, cls).setUpClass()
+
+ # Test will fail unless Left.setUpClass() follows mro properly
+ # Right.setUpClass()
+ class Subclass(Left, Right):
+
+ @classmethod
+ def setUpClass(cls):
+ super(Subclass, cls).setUpClass()
+
+ class Test(Subclass):
+ pass
+
+ Test.setUpClass()
+ self.assertEqual(Test.foo, 'foo')
+
+ def test_skip_class(self):
+
+ @absltest.skipThisClass('reason')
+ class BaseTest(absltest.TestCase):
+
+ def test_foo(self):
+ _ = 1 / 0
+
+ class Test(BaseTest):
+
+ def test_foo(self):
+ self.assertEqual(1, 1)
+
+ with self.subTest('base class'):
+ ts = unittest.makeSuite(BaseTest)
+ self.assertEqual(1, ts.countTestCases())
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertTrue(res.wasSuccessful())
+ self.assertLen(res.skipped, 1)
+ self.assertEqual(0, res.testsRun)
+ self.assertEmpty(res.failures)
+ self.assertEmpty(res.errors)
+
+ with self.subTest('real test'):
+ ts = unittest.makeSuite(Test)
+ self.assertEqual(1, ts.countTestCases())
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertTrue(res.wasSuccessful())
+ self.assertEqual(1, res.testsRun)
+ self.assertEmpty(res.skipped)
+ self.assertEmpty(res.failures)
+ self.assertEmpty(res.errors)
+
+ def test_skip_class_unittest(self):
+
+ @absltest.skipThisClass('reason')
+ class Test(unittest.TestCase): # note: unittest not absltest
+
+ def test_foo(self):
+ _ = 1 / 0
+
+ ts = unittest.makeSuite(Test)
+ self.assertEqual(1, ts.countTestCases())
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertTrue(res.wasSuccessful())
+ self.assertLen(res.skipped, 1)
+ self.assertEqual(0, res.testsRun)
+ self.assertEmpty(res.failures)
+ self.assertEmpty(res.errors)
+
+
def _listdir_recursive(path):
for dirname, _, filenames in os.walk(path):
yield dirname