diff options
author | Abseil Team <absl-team@google.com> | 2020-12-14 13:02:02 -0800 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2020-12-14 13:02:30 -0800 |
commit | 94670de3dcf271f95d000cd8cdad754014a4cb5d (patch) | |
tree | 05d8ab4ffa1a70fdd88ce16a60f0ffb15fb0a5f4 /absl/testing | |
parent | d61b0b6bda1902f645e5bbbc3f138c142767befa (diff) | |
download | absl-py-94670de3dcf271f95d000cd8cdad754014a4cb5d.tar.gz |
Support using flagholder in flagsaver.
`(HOLDER, value)` pairs can now be specified in positional arguments.
It is equivalent to specifying `**{HOLDER.name: value}`
We can mix and match holder and non-holder overrides. So usages like
`flagsaver((HOLDER1, value1), (HOLDER2, value2), flag_name=value)`
are legal as well.
PiperOrigin-RevId: 347450631
Change-Id: I45bdf7bd56ad1d65d62ff34536ef09e47fce7ae8
Diffstat (limited to 'absl/testing')
-rw-r--r-- | absl/testing/flagsaver.py | 23 | ||||
-rw-r--r-- | absl/testing/tests/flagsaver_test.py | 63 |
2 files changed, 77 insertions, 9 deletions
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py index c33d56a..7fe95fe 100644 --- a/absl/testing/flagsaver.py +++ b/absl/testing/flagsaver.py @@ -27,6 +27,11 @@ is temporarily set to 'foo'. def some_func(): do_stuff() + # Use a decorator which can optionally override flags with flagholders. + @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23)) + def some_func(): + do_stuff() + # Use a decorator which does not override flags itself. @flagsaver.flagsaver def some_func(): @@ -70,7 +75,8 @@ def flagsaver(*args, **kwargs): """The main flagsaver interface. See module doc for usage.""" if not args: return _FlagOverrider(**kwargs) - elif len(args) == 1: + # args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)` + if len(args) == 1 and callable(args[0]): if kwargs: raise ValueError( "It's invalid to specify both positional and keyword parameters.") @@ -78,9 +84,18 @@ def flagsaver(*args, **kwargs): if inspect.isclass(func): raise TypeError('@flagsaver.flagsaver cannot be applied to a class.') return _wrap(func, {}) - else: - raise ValueError( - "It's invalid to specify more than one positional parameters.") + # args can be a list of (FlagHolder, value) pairs. + # In which case they augment any specified kwargs. + for arg in args: + if not isinstance(arg, tuple) or len(arg) != 2: + raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,)) + holder, value = arg + if not isinstance(holder, flags.FlagHolder): + raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,)) + if holder.name in kwargs: + raise ValueError('Cannot set --%s multiple times' % holder.name) + kwargs[holder.name] = value + return _FlagOverrider(**kwargs) def save_flag_values(flag_values=FLAGS): diff --git a/absl/testing/tests/flagsaver_test.py b/absl/testing/tests/flagsaver_test.py index ed428df..3439a32 100644 --- a/absl/testing/tests/flagsaver_test.py +++ b/absl/testing/tests/flagsaver_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for flagsaver.""" from __future__ import absolute_import @@ -31,6 +30,11 @@ flags.register_validator('flagsaver_test_validated_flag', lambda x: not x) flags.DEFINE_string('flagsaver_test_validated_flag1', None, 'flag to test with') flags.DEFINE_string('flagsaver_test_validated_flag2', None, 'flag to test with') +INT_FLAG = flags.DEFINE_integer( + 'flagsaver_test_int_flag', default=1, help='help') +STR_FLAG = flags.DEFINE_string( + 'flagsaver_test_str_flag', default='str default', help='help') + @flags.multi_flags_validator( ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2')) @@ -65,6 +69,24 @@ class FlagSaverTest(absltest.TestCase): self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) + def test_context_manager_with_flagholders(self): + with flagsaver.flagsaver((INT_FLAG, 3), (STR_FLAG, 'new value')): + self.assertEqual('new value', STR_FLAG.value) + self.assertEqual(3, INT_FLAG.value) + FLAGS.flagsaver_test_flag1 = 'another value' + self.assertEqual(INT_FLAG.value, INT_FLAG.default) + self.assertEqual(STR_FLAG.value, STR_FLAG.default) + self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) + + def test_context_manager_with_overrides_and_flagholders(self): + with flagsaver.flagsaver((INT_FLAG, 3), flagsaver_test_flag0='new value'): + self.assertEqual(STR_FLAG.default, STR_FLAG.value) + self.assertEqual(3, INT_FLAG.value) + FLAGS.flagsaver_test_flag0 = 'new value' + self.assertEqual(INT_FLAG.value, INT_FLAG.default) + self.assertEqual(STR_FLAG.value, STR_FLAG.default) + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + def test_context_manager_with_cross_validated_overrides_set_together(self): # When the flags are set in the same flagsaver call their validators will # be triggered only once the setting is done. @@ -135,8 +157,7 @@ class FlagSaverTest(absltest.TestCase): # mutate_flags returns the flag value before it gets restored by # the flagsaver decorator. So we check that flag value was # actually changed in the method's scope. - self.assertEqual('new value', - mutate_flags('new value')) + self.assertEqual('new value', mutate_flags('new value')) # But... notice that the flag is now unchanged0. self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) @@ -288,8 +309,8 @@ class FlagSaverTest(absltest.TestCase): self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2) modify_validators() - self.assertEqual( - original_validators, FLAGS['flagsaver_test_flag0'].validators) + self.assertEqual(original_validators, + FLAGS['flagsaver_test_flag0'].validators) class FlagSaverDecoratorUsageTest(absltest.TestCase): @@ -409,6 +430,38 @@ class FlagSaverBadUsageTest(absltest.TestCase): func_a = lambda: None flagsaver.flagsaver(func_a, flagsaver_test_flag0='new value') + def test_duplicate_holder_parameters(self): + with self.assertRaises(ValueError): + flagsaver.flagsaver((INT_FLAG, 45), (INT_FLAG, 45)) + + def test_duplicate_holder_and_kw_parameter(self): + with self.assertRaises(ValueError): + flagsaver.flagsaver((INT_FLAG, 45), **{INT_FLAG.name: 45}) + + def test_both_positional_and_holder_parameters(self): + with self.assertRaises(ValueError): + func_a = lambda: None + flagsaver.flagsaver(func_a, (INT_FLAG, 45)) + + def test_holder_parameters_wrong_shape(self): + with self.assertRaises(ValueError): + flagsaver.flagsaver(INT_FLAG) + + def test_holder_parameters_tuple_too_long(self): + with self.assertRaises(ValueError): + # Even if it is a bool flag, it should be a tuple + flagsaver.flagsaver((INT_FLAG, 4, 5)) + + def test_holder_parameters_tuple_wrong_type(self): + with self.assertRaises(ValueError): + # Even if it is a bool flag, it should be a tuple + flagsaver.flagsaver((4, INT_FLAG)) + + def test_both_wrong_positional_parameters(self): + with self.assertRaises(ValueError): + func_a = lambda: None + flagsaver.flagsaver(func_a, STR_FLAG, '45') + if __name__ == '__main__': absltest.main() |