diff options
author | Karol M. Langner <langner@google.com> | 2020-09-30 10:16:10 -0700 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2020-09-30 10:16:31 -0700 |
commit | 9a0552c6743d387df6ab565f8ccb9878dd14ece4 (patch) | |
tree | c76a5e4fab205b0d04f3073c2836e8d0dcc93bec /absl/flags | |
parent | dab1732ed9cfe85f97e940a24a4c3a5d52a530f7 (diff) | |
download | absl-py-9a0552c6743d387df6ab565f8ccb9878dd14ece4.tar.gz |
In flagsaver, set multiple flags together before their validators run.
This resolves an issue where multi-flag validators rely on specific flag combinations.
PiperOrigin-RevId: 334625442
Change-Id: I7e6b625637a70356df57a4d9cbb01c203f14df4c
Diffstat (limited to 'absl/flags')
-rw-r--r-- | absl/flags/_flagvalues.py | 25 | ||||
-rw-r--r-- | absl/flags/tests/_flagvalues_test.py | 55 |
2 files changed, 60 insertions, 20 deletions
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index 72a747c..c6a209f 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -499,16 +499,25 @@ class FlagValues(object): def __setattr__(self, name, value): """Sets the 'value' attribute of the flag --name.""" - fl = self._flags() - if name in self.__dict__['__hiddenflags']: - raise AttributeError(name) - if name not in fl: - return self._set_unknown_flag(name, value) - fl[name].value = value - self._assert_validators(fl[name].validators) - fl[name].using_default_value = False + self._set_attributes(**{name: value}) return value + def _set_attributes(self, **attributes): + """Sets multiple flag values together, triggers validators afterwards.""" + fl = self._flags() + known_flags = set() + for name, value in six.iteritems(attributes): + if name in self.__dict__['__hiddenflags']: + raise AttributeError(name) + if name in fl: + fl[name].value = value + known_flags.add(name) + else: + self._set_unknown_flag(name, value) + for name in known_flags: + self._assert_validators(fl[name].validators) + fl[name].using_default_value = False + def validate_all_flags(self): """Verifies whether all flags pass validation. diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py index 658ec69..ed446f8 100644 --- a/absl/flags/tests/_flagvalues_test.py +++ b/absl/flags/tests/_flagvalues_test.py @@ -240,7 +240,7 @@ class FlagValuesTest(absltest.TestCase): # Delete the changelist flag, its short name should still be registered. del fv.changelist module_or_id_changelist = testing_fn('changelist') - self.assertEqual(module_or_id_changelist, None) + self.assertIsNone(module_or_id_changelist) module_or_id_c = testing_fn('c') self.assertEqual(module_or_id_c, current_module_or_id) module_or_id_l = testing_fn('l') @@ -333,21 +333,21 @@ class FlagValuesTest(absltest.TestCase): def test_len(self): fv = _flagvalues.FlagValues() - self.assertEqual(0, len(fv)) + self.assertEmpty(fv) self.assertFalse(fv) _defines.DEFINE_boolean('boolean', False, 'help', flag_values=fv) - self.assertEqual(1, len(fv)) + self.assertLen(fv, 1) self.assertTrue(fv) _defines.DEFINE_boolean( 'bool', False, 'help', short_name='b', flag_values=fv) - self.assertEqual(3, len(fv)) + self.assertLen(fv, 3) self.assertTrue(fv) def test_pickle(self): fv = _flagvalues.FlagValues() - with self.assertRaisesRegexp(TypeError, "can't pickle FlagValues"): + with self.assertRaisesRegex(TypeError, "can't pickle FlagValues"): pickle.dumps(fv) def test_copy(self): @@ -355,8 +355,8 @@ class FlagValuesTest(absltest.TestCase): _defines.DEFINE_integer('answer', 0, 'help', flag_values=fv) fv(['', '--answer=1']) - with self.assertRaisesRegexp( - TypeError, 'FlagValues does not support shallow copies'): + with self.assertRaisesRegex(TypeError, + 'FlagValues does not support shallow copies'): copy.copy(fv) fv2 = copy.deepcopy(fv) @@ -640,6 +640,7 @@ class FlagSubstrMatchingTests(parameterized.TestCase): class SettingUnknownFlagTest(absltest.TestCase): def setUp(self): + super(SettingUnknownFlagTest, self).setUp() self.setter_called = 0 def set_undef(self, unused_name, unused_val): @@ -679,9 +680,39 @@ class SettingUnknownFlagTest(absltest.TestCase): new_flags.undefined_flag = 0 +class SetAttributesTest(absltest.TestCase): + + def setUp(self): + super(SetAttributesTest, self).setUp() + self.new_flags = _flagvalues.FlagValues() + _defines.DEFINE_boolean( + 'defined_flag', None, '', flag_values=self.new_flags) + _defines.DEFINE_boolean( + 'another_defined_flag', None, '', flag_values=self.new_flags) + self.setter_called = 0 + + def set_undef(self, unused_name, unused_val): + self.setter_called += 1 + + def test_two_defined_flags(self): + self.new_flags._set_attributes( + defined_flag=False, another_defined_flag=False) + self.assertEqual(self.setter_called, 0) + + def test_one_defined_one_undefined_flag(self): + with self.assertRaises(_exceptions.UnrecognizedFlagError): + self.new_flags._set_attributes(defined_flag=False, undefined_flag=0) + + def test_register_unknown_flag_setter(self): + self.new_flags._register_unknown_flag_setter(self.set_undef) + self.new_flags._set_attributes(defined_flag=False, undefined_flag=0) + self.assertEqual(self.setter_called, 1) + + class FlagsDashSyntaxTest(absltest.TestCase): def setUp(self): + super(FlagsDashSyntaxTest, self).setUp() self.fv = _flagvalues.FlagValues() _defines.DEFINE_string( 'long_name', 'default', 'help', flag_values=self.fv, short_name='s') @@ -754,7 +785,7 @@ class UnparseFlagsTest(absltest.TestCase): fv.mark_as_parsed() self.assertEqual('foo', fv.default_foo) - self.assertEqual(None, fv.default_none) + self.assertIsNone(fv.default_none) fv(['', '--default_foo=notFoo', '--default_none=notNone']) self.assertEqual('notFoo', fv.default_foo) @@ -762,7 +793,7 @@ class UnparseFlagsTest(absltest.TestCase): fv.unparse_flags() self.assertEqual('foo', fv['default_foo'].value) - self.assertEqual(None, fv['default_none'].value) + self.assertIsNone(fv['default_none'].value) fv(['', '--default_foo=alsoNotFoo', '--default_none=alsoNotNone']) self.assertEqual('alsoNotFoo', fv.default_foo) @@ -772,15 +803,15 @@ class UnparseFlagsTest(absltest.TestCase): fv = _flagvalues.FlagValues() _defines.DEFINE_multi_string('foo', None, 'help', flag_values=fv) fv.mark_as_parsed() - self.assertEqual(None, fv.foo) + self.assertIsNone(fv.foo) fv(['', '--foo=aa']) self.assertEqual(['aa'], fv.foo) fv.unparse_flags() - self.assertEqual(None, fv['foo'].value) + self.assertIsNone(fv['foo'].value) fv(['', '--foo=bb', '--foo=cc']) self.assertEqual(['bb', 'cc'], fv.foo) fv.unparse_flags() - self.assertEqual(None, fv['foo'].value) + self.assertIsNone(fv['foo'].value) def test_multi_string_default_string(self): fv = _flagvalues.FlagValues() |