aboutsummaryrefslogtreecommitdiff
path: root/absl/flags
diff options
context:
space:
mode:
authorKarol M. Langner <langner@google.com>2020-09-30 10:16:10 -0700
committerCopybara-Service <copybara-worker@google.com>2020-09-30 10:16:31 -0700
commit9a0552c6743d387df6ab565f8ccb9878dd14ece4 (patch)
treec76a5e4fab205b0d04f3073c2836e8d0dcc93bec /absl/flags
parentdab1732ed9cfe85f97e940a24a4c3a5d52a530f7 (diff)
downloadabsl-py-9a0552c6743d387df6ab565f8ccb9878dd14ece4.tar.gz
In flagsaver, set multiple flags together before their validators run.
This resolves an issue where multi-flag validators rely on specific flag combinations. PiperOrigin-RevId: 334625442 Change-Id: I7e6b625637a70356df57a4d9cbb01c203f14df4c
Diffstat (limited to 'absl/flags')
-rw-r--r--absl/flags/_flagvalues.py25
-rw-r--r--absl/flags/tests/_flagvalues_test.py55
2 files changed, 60 insertions, 20 deletions
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index 72a747c..c6a209f 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -499,16 +499,25 @@ class FlagValues(object):
def __setattr__(self, name, value):
"""Sets the 'value' attribute of the flag --name."""
- fl = self._flags()
- if name in self.__dict__['__hiddenflags']:
- raise AttributeError(name)
- if name not in fl:
- return self._set_unknown_flag(name, value)
- fl[name].value = value
- self._assert_validators(fl[name].validators)
- fl[name].using_default_value = False
+ self._set_attributes(**{name: value})
return value
+ def _set_attributes(self, **attributes):
+ """Sets multiple flag values together, triggers validators afterwards."""
+ fl = self._flags()
+ known_flags = set()
+ for name, value in six.iteritems(attributes):
+ if name in self.__dict__['__hiddenflags']:
+ raise AttributeError(name)
+ if name in fl:
+ fl[name].value = value
+ known_flags.add(name)
+ else:
+ self._set_unknown_flag(name, value)
+ for name in known_flags:
+ self._assert_validators(fl[name].validators)
+ fl[name].using_default_value = False
+
def validate_all_flags(self):
"""Verifies whether all flags pass validation.
diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py
index 658ec69..ed446f8 100644
--- a/absl/flags/tests/_flagvalues_test.py
+++ b/absl/flags/tests/_flagvalues_test.py
@@ -240,7 +240,7 @@ class FlagValuesTest(absltest.TestCase):
# Delete the changelist flag, its short name should still be registered.
del fv.changelist
module_or_id_changelist = testing_fn('changelist')
- self.assertEqual(module_or_id_changelist, None)
+ self.assertIsNone(module_or_id_changelist)
module_or_id_c = testing_fn('c')
self.assertEqual(module_or_id_c, current_module_or_id)
module_or_id_l = testing_fn('l')
@@ -333,21 +333,21 @@ class FlagValuesTest(absltest.TestCase):
def test_len(self):
fv = _flagvalues.FlagValues()
- self.assertEqual(0, len(fv))
+ self.assertEmpty(fv)
self.assertFalse(fv)
_defines.DEFINE_boolean('boolean', False, 'help', flag_values=fv)
- self.assertEqual(1, len(fv))
+ self.assertLen(fv, 1)
self.assertTrue(fv)
_defines.DEFINE_boolean(
'bool', False, 'help', short_name='b', flag_values=fv)
- self.assertEqual(3, len(fv))
+ self.assertLen(fv, 3)
self.assertTrue(fv)
def test_pickle(self):
fv = _flagvalues.FlagValues()
- with self.assertRaisesRegexp(TypeError, "can't pickle FlagValues"):
+ with self.assertRaisesRegex(TypeError, "can't pickle FlagValues"):
pickle.dumps(fv)
def test_copy(self):
@@ -355,8 +355,8 @@ class FlagValuesTest(absltest.TestCase):
_defines.DEFINE_integer('answer', 0, 'help', flag_values=fv)
fv(['', '--answer=1'])
- with self.assertRaisesRegexp(
- TypeError, 'FlagValues does not support shallow copies'):
+ with self.assertRaisesRegex(TypeError,
+ 'FlagValues does not support shallow copies'):
copy.copy(fv)
fv2 = copy.deepcopy(fv)
@@ -640,6 +640,7 @@ class FlagSubstrMatchingTests(parameterized.TestCase):
class SettingUnknownFlagTest(absltest.TestCase):
def setUp(self):
+ super(SettingUnknownFlagTest, self).setUp()
self.setter_called = 0
def set_undef(self, unused_name, unused_val):
@@ -679,9 +680,39 @@ class SettingUnknownFlagTest(absltest.TestCase):
new_flags.undefined_flag = 0
+class SetAttributesTest(absltest.TestCase):
+
+ def setUp(self):
+ super(SetAttributesTest, self).setUp()
+ self.new_flags = _flagvalues.FlagValues()
+ _defines.DEFINE_boolean(
+ 'defined_flag', None, '', flag_values=self.new_flags)
+ _defines.DEFINE_boolean(
+ 'another_defined_flag', None, '', flag_values=self.new_flags)
+ self.setter_called = 0
+
+ def set_undef(self, unused_name, unused_val):
+ self.setter_called += 1
+
+ def test_two_defined_flags(self):
+ self.new_flags._set_attributes(
+ defined_flag=False, another_defined_flag=False)
+ self.assertEqual(self.setter_called, 0)
+
+ def test_one_defined_one_undefined_flag(self):
+ with self.assertRaises(_exceptions.UnrecognizedFlagError):
+ self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)
+
+ def test_register_unknown_flag_setter(self):
+ self.new_flags._register_unknown_flag_setter(self.set_undef)
+ self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)
+ self.assertEqual(self.setter_called, 1)
+
+
class FlagsDashSyntaxTest(absltest.TestCase):
def setUp(self):
+ super(FlagsDashSyntaxTest, self).setUp()
self.fv = _flagvalues.FlagValues()
_defines.DEFINE_string(
'long_name', 'default', 'help', flag_values=self.fv, short_name='s')
@@ -754,7 +785,7 @@ class UnparseFlagsTest(absltest.TestCase):
fv.mark_as_parsed()
self.assertEqual('foo', fv.default_foo)
- self.assertEqual(None, fv.default_none)
+ self.assertIsNone(fv.default_none)
fv(['', '--default_foo=notFoo', '--default_none=notNone'])
self.assertEqual('notFoo', fv.default_foo)
@@ -762,7 +793,7 @@ class UnparseFlagsTest(absltest.TestCase):
fv.unparse_flags()
self.assertEqual('foo', fv['default_foo'].value)
- self.assertEqual(None, fv['default_none'].value)
+ self.assertIsNone(fv['default_none'].value)
fv(['', '--default_foo=alsoNotFoo', '--default_none=alsoNotNone'])
self.assertEqual('alsoNotFoo', fv.default_foo)
@@ -772,15 +803,15 @@ class UnparseFlagsTest(absltest.TestCase):
fv = _flagvalues.FlagValues()
_defines.DEFINE_multi_string('foo', None, 'help', flag_values=fv)
fv.mark_as_parsed()
- self.assertEqual(None, fv.foo)
+ self.assertIsNone(fv.foo)
fv(['', '--foo=aa'])
self.assertEqual(['aa'], fv.foo)
fv.unparse_flags()
- self.assertEqual(None, fv['foo'].value)
+ self.assertIsNone(fv['foo'].value)
fv(['', '--foo=bb', '--foo=cc'])
self.assertEqual(['bb', 'cc'], fv.foo)
fv.unparse_flags()
- self.assertEqual(None, fv['foo'].value)
+ self.assertIsNone(fv['foo'].value)
def test_multi_string_default_string(self):
fv = _flagvalues.FlagValues()