diff options
author | Karol M. Langner <langner@google.com> | 2020-09-30 10:16:10 -0700 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2020-09-30 10:16:31 -0700 |
commit | 9a0552c6743d387df6ab565f8ccb9878dd14ece4 (patch) | |
tree | c76a5e4fab205b0d04f3073c2836e8d0dcc93bec /absl | |
parent | dab1732ed9cfe85f97e940a24a4c3a5d52a530f7 (diff) | |
download | absl-py-9a0552c6743d387df6ab565f8ccb9878dd14ece4.tar.gz |
In flagsaver, set multiple flags together before their validators run.
This resolves an issue where multi-flag validators rely on specific flag combinations.
PiperOrigin-RevId: 334625442
Change-Id: I7e6b625637a70356df57a4d9cbb01c203f14df4c
Diffstat (limited to 'absl')
-rw-r--r-- | absl/CHANGELOG.md | 3 | ||||
-rw-r--r-- | absl/flags/_flagvalues.py | 25 | ||||
-rw-r--r-- | absl/flags/tests/_flagvalues_test.py | 55 | ||||
-rw-r--r-- | absl/testing/BUILD | 1 | ||||
-rwxr-xr-x | absl/testing/flagsaver.py | 4 | ||||
-rwxr-xr-x | absl/testing/tests/flagsaver_test.py | 116 |
6 files changed, 166 insertions, 38 deletions
diff --git a/absl/CHANGELOG.md b/absl/CHANGELOG.md index 608e5a3..8602377 100644 --- a/absl/CHANGELOG.md +++ b/absl/CHANGELOG.md @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com). now suppressed and no longer reported in the xml_reporter. * (logging) An exception is now raised instead of `logging.fatal` when logging directories cannot be found. +* (testing) Multiple flags are now set together before their validators run. + This resolves an issue where multi-flag validators rely on specific flag + combinations. ## 0.10.0 (2020-08-19) diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index 72a747c..c6a209f 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -499,16 +499,25 @@ class FlagValues(object): def __setattr__(self, name, value): """Sets the 'value' attribute of the flag --name.""" - fl = self._flags() - if name in self.__dict__['__hiddenflags']: - raise AttributeError(name) - if name not in fl: - return self._set_unknown_flag(name, value) - fl[name].value = value - self._assert_validators(fl[name].validators) - fl[name].using_default_value = False + self._set_attributes(**{name: value}) return value + def _set_attributes(self, **attributes): + """Sets multiple flag values together, triggers validators afterwards.""" + fl = self._flags() + known_flags = set() + for name, value in six.iteritems(attributes): + if name in self.__dict__['__hiddenflags']: + raise AttributeError(name) + if name in fl: + fl[name].value = value + known_flags.add(name) + else: + self._set_unknown_flag(name, value) + for name in known_flags: + self._assert_validators(fl[name].validators) + fl[name].using_default_value = False + def validate_all_flags(self): """Verifies whether all flags pass validation. diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py index 658ec69..ed446f8 100644 --- a/absl/flags/tests/_flagvalues_test.py +++ b/absl/flags/tests/_flagvalues_test.py @@ -240,7 +240,7 @@ class FlagValuesTest(absltest.TestCase): # Delete the changelist flag, its short name should still be registered. del fv.changelist module_or_id_changelist = testing_fn('changelist') - self.assertEqual(module_or_id_changelist, None) + self.assertIsNone(module_or_id_changelist) module_or_id_c = testing_fn('c') self.assertEqual(module_or_id_c, current_module_or_id) module_or_id_l = testing_fn('l') @@ -333,21 +333,21 @@ class FlagValuesTest(absltest.TestCase): def test_len(self): fv = _flagvalues.FlagValues() - self.assertEqual(0, len(fv)) + self.assertEmpty(fv) self.assertFalse(fv) _defines.DEFINE_boolean('boolean', False, 'help', flag_values=fv) - self.assertEqual(1, len(fv)) + self.assertLen(fv, 1) self.assertTrue(fv) _defines.DEFINE_boolean( 'bool', False, 'help', short_name='b', flag_values=fv) - self.assertEqual(3, len(fv)) + self.assertLen(fv, 3) self.assertTrue(fv) def test_pickle(self): fv = _flagvalues.FlagValues() - with self.assertRaisesRegexp(TypeError, "can't pickle FlagValues"): + with self.assertRaisesRegex(TypeError, "can't pickle FlagValues"): pickle.dumps(fv) def test_copy(self): @@ -355,8 +355,8 @@ class FlagValuesTest(absltest.TestCase): _defines.DEFINE_integer('answer', 0, 'help', flag_values=fv) fv(['', '--answer=1']) - with self.assertRaisesRegexp( - TypeError, 'FlagValues does not support shallow copies'): + with self.assertRaisesRegex(TypeError, + 'FlagValues does not support shallow copies'): copy.copy(fv) fv2 = copy.deepcopy(fv) @@ -640,6 +640,7 @@ class FlagSubstrMatchingTests(parameterized.TestCase): class SettingUnknownFlagTest(absltest.TestCase): def setUp(self): + super(SettingUnknownFlagTest, self).setUp() self.setter_called = 0 def set_undef(self, unused_name, unused_val): @@ -679,9 +680,39 @@ class SettingUnknownFlagTest(absltest.TestCase): new_flags.undefined_flag = 0 +class SetAttributesTest(absltest.TestCase): + + def setUp(self): + super(SetAttributesTest, self).setUp() + self.new_flags = _flagvalues.FlagValues() + _defines.DEFINE_boolean( + 'defined_flag', None, '', flag_values=self.new_flags) + _defines.DEFINE_boolean( + 'another_defined_flag', None, '', flag_values=self.new_flags) + self.setter_called = 0 + + def set_undef(self, unused_name, unused_val): + self.setter_called += 1 + + def test_two_defined_flags(self): + self.new_flags._set_attributes( + defined_flag=False, another_defined_flag=False) + self.assertEqual(self.setter_called, 0) + + def test_one_defined_one_undefined_flag(self): + with self.assertRaises(_exceptions.UnrecognizedFlagError): + self.new_flags._set_attributes(defined_flag=False, undefined_flag=0) + + def test_register_unknown_flag_setter(self): + self.new_flags._register_unknown_flag_setter(self.set_undef) + self.new_flags._set_attributes(defined_flag=False, undefined_flag=0) + self.assertEqual(self.setter_called, 1) + + class FlagsDashSyntaxTest(absltest.TestCase): def setUp(self): + super(FlagsDashSyntaxTest, self).setUp() self.fv = _flagvalues.FlagValues() _defines.DEFINE_string( 'long_name', 'default', 'help', flag_values=self.fv, short_name='s') @@ -754,7 +785,7 @@ class UnparseFlagsTest(absltest.TestCase): fv.mark_as_parsed() self.assertEqual('foo', fv.default_foo) - self.assertEqual(None, fv.default_none) + self.assertIsNone(fv.default_none) fv(['', '--default_foo=notFoo', '--default_none=notNone']) self.assertEqual('notFoo', fv.default_foo) @@ -762,7 +793,7 @@ class UnparseFlagsTest(absltest.TestCase): fv.unparse_flags() self.assertEqual('foo', fv['default_foo'].value) - self.assertEqual(None, fv['default_none'].value) + self.assertIsNone(fv['default_none'].value) fv(['', '--default_foo=alsoNotFoo', '--default_none=alsoNotNone']) self.assertEqual('alsoNotFoo', fv.default_foo) @@ -772,15 +803,15 @@ class UnparseFlagsTest(absltest.TestCase): fv = _flagvalues.FlagValues() _defines.DEFINE_multi_string('foo', None, 'help', flag_values=fv) fv.mark_as_parsed() - self.assertEqual(None, fv.foo) + self.assertIsNone(fv.foo) fv(['', '--foo=aa']) self.assertEqual(['aa'], fv.foo) fv.unparse_flags() - self.assertEqual(None, fv['foo'].value) + self.assertIsNone(fv['foo'].value) fv(['', '--foo=bb', '--foo=cc']) self.assertEqual(['bb', 'cc'], fv.foo) fv.unparse_flags() - self.assertEqual(None, fv['foo'].value) + self.assertIsNone(fv['foo'].value) def test_multi_string_default_string(self): fv = _flagvalues.FlagValues() diff --git a/absl/testing/BUILD b/absl/testing/BUILD index b101a1d..9cd28c9 100644 --- a/absl/testing/BUILD +++ b/absl/testing/BUILD @@ -50,7 +50,6 @@ py_library( visibility = ["//visibility:public"], deps = [ "//absl/flags", - "@six_archive//:six", ], ) diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py index 9a0e193..c33d56a 100755 --- a/absl/testing/flagsaver.py +++ b/absl/testing/flagsaver.py @@ -62,7 +62,6 @@ import functools import inspect from absl import flags -import six FLAGS = flags.FLAGS @@ -156,8 +155,7 @@ class _FlagOverrider(object): def __enter__(self): self._saved_flag_values = save_flag_values(FLAGS) try: - for name, value in six.iteritems(self._overrides): - setattr(FLAGS, name, value) + FLAGS._set_attributes(**self._overrides) except: # It may fail because of flag validators. restore_flag_values(self._saved_flag_values, FLAGS) diff --git a/absl/testing/tests/flagsaver_test.py b/absl/testing/tests/flagsaver_test.py index 13fa1c3..ed428df 100755 --- a/absl/testing/tests/flagsaver_test.py +++ b/absl/testing/tests/flagsaver_test.py @@ -24,9 +24,21 @@ from absl.testing import flagsaver flags.DEFINE_string('flagsaver_test_flag0', 'unchanged0', 'flag to test with') flags.DEFINE_string('flagsaver_test_flag1', 'unchanged1', 'flag to test with') + flags.DEFINE_string('flagsaver_test_validated_flag', None, 'flag to test with') flags.register_validator('flagsaver_test_validated_flag', lambda x: not x) +flags.DEFINE_string('flagsaver_test_validated_flag1', None, 'flag to test with') +flags.DEFINE_string('flagsaver_test_validated_flag2', None, 'flag to test with') + + +@flags.multi_flags_validator( + ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2')) +def validate_test_flags(flag_dict): + return (flag_dict['flagsaver_test_validated_flag1'] == + flag_dict['flagsaver_test_validated_flag2']) + + FLAGS = flags.FLAGS @@ -41,19 +53,6 @@ class _TestError(Exception): class FlagSaverTest(absltest.TestCase): - def setUp(self): - # Save the value of the instance of FLAGS local to this module. - global FLAGS # pylint: disable=global-statement - self.flags = FLAGS - # pylint: disable=g-bad-name - FLAGS = flags.FlagValues() - FLAGS.append_flag_values(self.flags) - FLAGS.mark_as_parsed() - - def tearDown(self): - global FLAGS # pylint: disable=global-statement - FLAGS = self.flags # pylint: disable=g-bad-name - def test_context_manager_without_parameters(self): with flagsaver.flagsaver(): FLAGS.flagsaver_test_flag0 = 'new value' @@ -66,6 +65,42 @@ class FlagSaverTest(absltest.TestCase): self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) + def test_context_manager_with_cross_validated_overrides_set_together(self): + # When the flags are set in the same flagsaver call their validators will + # be triggered only once the setting is done. + with flagsaver.flagsaver( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='new_value'): + self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1) + self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2) + + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_context_manager_with_cross_validated_overrides_set_badly(self): + + # Different values should violate the validator. + with self.assertRaisesRegex(flags.IllegalFlagValueError, + 'Flag validation failed'): + with flagsaver.flagsaver( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='other_value'): + pass + + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_context_manager_with_cross_validated_overrides_set_separately(self): + + # Setting just one flag will trip the validator as well. + with self.assertRaisesRegex(flags.IllegalFlagValueError, + 'Flag validation failed'): + with flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value'): + pass + + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + def test_context_manager_with_exception(self): with self.assertRaises(_TestError): with flagsaver.flagsaver(flagsaver_test_flag0='new value'): @@ -83,7 +118,7 @@ class FlagSaverTest(absltest.TestCase): pass self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) - self.assertEqual(None, FLAGS.flagsaver_test_validated_flag) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag) def test_decorator_without_call(self): @@ -133,6 +168,59 @@ class FlagSaverTest(absltest.TestCase): # But... notice that the flag is now unchanged0. self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + def test_decorator_with_cross_validated_overrides_set_together(self): + + # When the flags are set in the same flagsaver call their validators will + # be triggered only once the setting is done. + @flagsaver.flagsaver( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='new_value') + def mutate_flags_together(): + return (FLAGS.flagsaver_test_validated_flag1, + FLAGS.flagsaver_test_validated_flag2) + + self.assertEqual(('new_value', 'new_value'), mutate_flags_together()) + + # The flags have not changed outside the context of the function. + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_decorator_with_cross_validated_overrides_set_badly(self): + + # Different values should violate the validator. + @flagsaver.flagsaver( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='other_value') + def mutate_flags_together_badly(): + return (FLAGS.flagsaver_test_validated_flag1, + FLAGS.flagsaver_test_validated_flag2) + + with self.assertRaisesRegex(flags.IllegalFlagValueError, + 'Flag validation failed'): + mutate_flags_together_badly() + + # The flags have not changed outside the context of the exception. + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_decorator_with_cross_validated_overrides_set_separately(self): + + # Setting the flags sequentially and not together will trip the validator, + # because it will be called at the end of each flagsaver call. + @flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value') + @flagsaver.flagsaver(flagsaver_test_validated_flag2='new_value') + def mutate_flags_separately(): + return (FLAGS.flagsaver_test_validated_flag1, + FLAGS.flagsaver_test_validated_flag2) + + with self.assertRaisesRegex(flags.IllegalFlagValueError, + 'Flag validation failed'): + mutate_flags_separately() + + # The flags have not changed outside the context of the exception. + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + def test_save_flag_value(self): # First save the flag values. saved_flag_values = flagsaver.save_flag_values() |