aboutsummaryrefslogtreecommitdiff
path: root/absl
diff options
context:
space:
mode:
authorKarol M. Langner <langner@google.com>2020-09-30 10:16:10 -0700
committerCopybara-Service <copybara-worker@google.com>2020-09-30 10:16:31 -0700
commit9a0552c6743d387df6ab565f8ccb9878dd14ece4 (patch)
treec76a5e4fab205b0d04f3073c2836e8d0dcc93bec /absl
parentdab1732ed9cfe85f97e940a24a4c3a5d52a530f7 (diff)
downloadabsl-py-9a0552c6743d387df6ab565f8ccb9878dd14ece4.tar.gz
In flagsaver, set multiple flags together before their validators run.
This resolves an issue where multi-flag validators rely on specific flag combinations. PiperOrigin-RevId: 334625442 Change-Id: I7e6b625637a70356df57a4d9cbb01c203f14df4c
Diffstat (limited to 'absl')
-rw-r--r--absl/CHANGELOG.md3
-rw-r--r--absl/flags/_flagvalues.py25
-rw-r--r--absl/flags/tests/_flagvalues_test.py55
-rw-r--r--absl/testing/BUILD1
-rwxr-xr-xabsl/testing/flagsaver.py4
-rwxr-xr-xabsl/testing/tests/flagsaver_test.py116
6 files changed, 166 insertions, 38 deletions
diff --git a/absl/CHANGELOG.md b/absl/CHANGELOG.md
index 608e5a3..8602377 100644
--- a/absl/CHANGELOG.md
+++ b/absl/CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
now suppressed and no longer reported in the xml_reporter.
* (logging) An exception is now raised instead of `logging.fatal` when logging
directories cannot be found.
+* (testing) Multiple flags are now set together before their validators run.
+ This resolves an issue where multi-flag validators rely on specific flag
+ combinations.
## 0.10.0 (2020-08-19)
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index 72a747c..c6a209f 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -499,16 +499,25 @@ class FlagValues(object):
def __setattr__(self, name, value):
"""Sets the 'value' attribute of the flag --name."""
- fl = self._flags()
- if name in self.__dict__['__hiddenflags']:
- raise AttributeError(name)
- if name not in fl:
- return self._set_unknown_flag(name, value)
- fl[name].value = value
- self._assert_validators(fl[name].validators)
- fl[name].using_default_value = False
+ self._set_attributes(**{name: value})
return value
+ def _set_attributes(self, **attributes):
+ """Sets multiple flag values together, triggers validators afterwards."""
+ fl = self._flags()
+ known_flags = set()
+ for name, value in six.iteritems(attributes):
+ if name in self.__dict__['__hiddenflags']:
+ raise AttributeError(name)
+ if name in fl:
+ fl[name].value = value
+ known_flags.add(name)
+ else:
+ self._set_unknown_flag(name, value)
+ for name in known_flags:
+ self._assert_validators(fl[name].validators)
+ fl[name].using_default_value = False
+
def validate_all_flags(self):
"""Verifies whether all flags pass validation.
diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py
index 658ec69..ed446f8 100644
--- a/absl/flags/tests/_flagvalues_test.py
+++ b/absl/flags/tests/_flagvalues_test.py
@@ -240,7 +240,7 @@ class FlagValuesTest(absltest.TestCase):
# Delete the changelist flag, its short name should still be registered.
del fv.changelist
module_or_id_changelist = testing_fn('changelist')
- self.assertEqual(module_or_id_changelist, None)
+ self.assertIsNone(module_or_id_changelist)
module_or_id_c = testing_fn('c')
self.assertEqual(module_or_id_c, current_module_or_id)
module_or_id_l = testing_fn('l')
@@ -333,21 +333,21 @@ class FlagValuesTest(absltest.TestCase):
def test_len(self):
fv = _flagvalues.FlagValues()
- self.assertEqual(0, len(fv))
+ self.assertEmpty(fv)
self.assertFalse(fv)
_defines.DEFINE_boolean('boolean', False, 'help', flag_values=fv)
- self.assertEqual(1, len(fv))
+ self.assertLen(fv, 1)
self.assertTrue(fv)
_defines.DEFINE_boolean(
'bool', False, 'help', short_name='b', flag_values=fv)
- self.assertEqual(3, len(fv))
+ self.assertLen(fv, 3)
self.assertTrue(fv)
def test_pickle(self):
fv = _flagvalues.FlagValues()
- with self.assertRaisesRegexp(TypeError, "can't pickle FlagValues"):
+ with self.assertRaisesRegex(TypeError, "can't pickle FlagValues"):
pickle.dumps(fv)
def test_copy(self):
@@ -355,8 +355,8 @@ class FlagValuesTest(absltest.TestCase):
_defines.DEFINE_integer('answer', 0, 'help', flag_values=fv)
fv(['', '--answer=1'])
- with self.assertRaisesRegexp(
- TypeError, 'FlagValues does not support shallow copies'):
+ with self.assertRaisesRegex(TypeError,
+ 'FlagValues does not support shallow copies'):
copy.copy(fv)
fv2 = copy.deepcopy(fv)
@@ -640,6 +640,7 @@ class FlagSubstrMatchingTests(parameterized.TestCase):
class SettingUnknownFlagTest(absltest.TestCase):
def setUp(self):
+ super(SettingUnknownFlagTest, self).setUp()
self.setter_called = 0
def set_undef(self, unused_name, unused_val):
@@ -679,9 +680,39 @@ class SettingUnknownFlagTest(absltest.TestCase):
new_flags.undefined_flag = 0
+class SetAttributesTest(absltest.TestCase):
+
+ def setUp(self):
+ super(SetAttributesTest, self).setUp()
+ self.new_flags = _flagvalues.FlagValues()
+ _defines.DEFINE_boolean(
+ 'defined_flag', None, '', flag_values=self.new_flags)
+ _defines.DEFINE_boolean(
+ 'another_defined_flag', None, '', flag_values=self.new_flags)
+ self.setter_called = 0
+
+ def set_undef(self, unused_name, unused_val):
+ self.setter_called += 1
+
+ def test_two_defined_flags(self):
+ self.new_flags._set_attributes(
+ defined_flag=False, another_defined_flag=False)
+ self.assertEqual(self.setter_called, 0)
+
+ def test_one_defined_one_undefined_flag(self):
+ with self.assertRaises(_exceptions.UnrecognizedFlagError):
+ self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)
+
+ def test_register_unknown_flag_setter(self):
+ self.new_flags._register_unknown_flag_setter(self.set_undef)
+ self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)
+ self.assertEqual(self.setter_called, 1)
+
+
class FlagsDashSyntaxTest(absltest.TestCase):
def setUp(self):
+ super(FlagsDashSyntaxTest, self).setUp()
self.fv = _flagvalues.FlagValues()
_defines.DEFINE_string(
'long_name', 'default', 'help', flag_values=self.fv, short_name='s')
@@ -754,7 +785,7 @@ class UnparseFlagsTest(absltest.TestCase):
fv.mark_as_parsed()
self.assertEqual('foo', fv.default_foo)
- self.assertEqual(None, fv.default_none)
+ self.assertIsNone(fv.default_none)
fv(['', '--default_foo=notFoo', '--default_none=notNone'])
self.assertEqual('notFoo', fv.default_foo)
@@ -762,7 +793,7 @@ class UnparseFlagsTest(absltest.TestCase):
fv.unparse_flags()
self.assertEqual('foo', fv['default_foo'].value)
- self.assertEqual(None, fv['default_none'].value)
+ self.assertIsNone(fv['default_none'].value)
fv(['', '--default_foo=alsoNotFoo', '--default_none=alsoNotNone'])
self.assertEqual('alsoNotFoo', fv.default_foo)
@@ -772,15 +803,15 @@ class UnparseFlagsTest(absltest.TestCase):
fv = _flagvalues.FlagValues()
_defines.DEFINE_multi_string('foo', None, 'help', flag_values=fv)
fv.mark_as_parsed()
- self.assertEqual(None, fv.foo)
+ self.assertIsNone(fv.foo)
fv(['', '--foo=aa'])
self.assertEqual(['aa'], fv.foo)
fv.unparse_flags()
- self.assertEqual(None, fv['foo'].value)
+ self.assertIsNone(fv['foo'].value)
fv(['', '--foo=bb', '--foo=cc'])
self.assertEqual(['bb', 'cc'], fv.foo)
fv.unparse_flags()
- self.assertEqual(None, fv['foo'].value)
+ self.assertIsNone(fv['foo'].value)
def test_multi_string_default_string(self):
fv = _flagvalues.FlagValues()
diff --git a/absl/testing/BUILD b/absl/testing/BUILD
index b101a1d..9cd28c9 100644
--- a/absl/testing/BUILD
+++ b/absl/testing/BUILD
@@ -50,7 +50,6 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//absl/flags",
- "@six_archive//:six",
],
)
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py
index 9a0e193..c33d56a 100755
--- a/absl/testing/flagsaver.py
+++ b/absl/testing/flagsaver.py
@@ -62,7 +62,6 @@ import functools
import inspect
from absl import flags
-import six
FLAGS = flags.FLAGS
@@ -156,8 +155,7 @@ class _FlagOverrider(object):
def __enter__(self):
self._saved_flag_values = save_flag_values(FLAGS)
try:
- for name, value in six.iteritems(self._overrides):
- setattr(FLAGS, name, value)
+ FLAGS._set_attributes(**self._overrides)
except:
# It may fail because of flag validators.
restore_flag_values(self._saved_flag_values, FLAGS)
diff --git a/absl/testing/tests/flagsaver_test.py b/absl/testing/tests/flagsaver_test.py
index 13fa1c3..ed428df 100755
--- a/absl/testing/tests/flagsaver_test.py
+++ b/absl/testing/tests/flagsaver_test.py
@@ -24,9 +24,21 @@ from absl.testing import flagsaver
flags.DEFINE_string('flagsaver_test_flag0', 'unchanged0', 'flag to test with')
flags.DEFINE_string('flagsaver_test_flag1', 'unchanged1', 'flag to test with')
+
flags.DEFINE_string('flagsaver_test_validated_flag', None, 'flag to test with')
flags.register_validator('flagsaver_test_validated_flag', lambda x: not x)
+flags.DEFINE_string('flagsaver_test_validated_flag1', None, 'flag to test with')
+flags.DEFINE_string('flagsaver_test_validated_flag2', None, 'flag to test with')
+
+
+@flags.multi_flags_validator(
+ ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
+def validate_test_flags(flag_dict):
+ return (flag_dict['flagsaver_test_validated_flag1'] ==
+ flag_dict['flagsaver_test_validated_flag2'])
+
+
FLAGS = flags.FLAGS
@@ -41,19 +53,6 @@ class _TestError(Exception):
class FlagSaverTest(absltest.TestCase):
- def setUp(self):
- # Save the value of the instance of FLAGS local to this module.
- global FLAGS # pylint: disable=global-statement
- self.flags = FLAGS
- # pylint: disable=g-bad-name
- FLAGS = flags.FlagValues()
- FLAGS.append_flag_values(self.flags)
- FLAGS.mark_as_parsed()
-
- def tearDown(self):
- global FLAGS # pylint: disable=global-statement
- FLAGS = self.flags # pylint: disable=g-bad-name
-
def test_context_manager_without_parameters(self):
with flagsaver.flagsaver():
FLAGS.flagsaver_test_flag0 = 'new value'
@@ -66,6 +65,42 @@ class FlagSaverTest(absltest.TestCase):
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
+ def test_context_manager_with_cross_validated_overrides_set_together(self):
+ # When the flags are set in the same flagsaver call their validators will
+ # be triggered only once the setting is done.
+ with flagsaver.flagsaver(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='new_value'):
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
+
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_context_manager_with_cross_validated_overrides_set_badly(self):
+
+ # Different values should violate the validator.
+ with self.assertRaisesRegex(flags.IllegalFlagValueError,
+ 'Flag validation failed'):
+ with flagsaver.flagsaver(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='other_value'):
+ pass
+
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_context_manager_with_cross_validated_overrides_set_separately(self):
+
+ # Setting just one flag will trip the validator as well.
+ with self.assertRaisesRegex(flags.IllegalFlagValueError,
+ 'Flag validation failed'):
+ with flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value'):
+ pass
+
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
def test_context_manager_with_exception(self):
with self.assertRaises(_TestError):
with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
@@ -83,7 +118,7 @@ class FlagSaverTest(absltest.TestCase):
pass
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
- self.assertEqual(None, FLAGS.flagsaver_test_validated_flag)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag)
def test_decorator_without_call(self):
@@ -133,6 +168,59 @@ class FlagSaverTest(absltest.TestCase):
# But... notice that the flag is now unchanged0.
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ def test_decorator_with_cross_validated_overrides_set_together(self):
+
+ # When the flags are set in the same flagsaver call their validators will
+ # be triggered only once the setting is done.
+ @flagsaver.flagsaver(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='new_value')
+ def mutate_flags_together():
+ return (FLAGS.flagsaver_test_validated_flag1,
+ FLAGS.flagsaver_test_validated_flag2)
+
+ self.assertEqual(('new_value', 'new_value'), mutate_flags_together())
+
+ # The flags have not changed outside the context of the function.
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_decorator_with_cross_validated_overrides_set_badly(self):
+
+ # Different values should violate the validator.
+ @flagsaver.flagsaver(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='other_value')
+ def mutate_flags_together_badly():
+ return (FLAGS.flagsaver_test_validated_flag1,
+ FLAGS.flagsaver_test_validated_flag2)
+
+ with self.assertRaisesRegex(flags.IllegalFlagValueError,
+ 'Flag validation failed'):
+ mutate_flags_together_badly()
+
+ # The flags have not changed outside the context of the exception.
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_decorator_with_cross_validated_overrides_set_separately(self):
+
+ # Setting the flags sequentially and not together will trip the validator,
+ # because it will be called at the end of each flagsaver call.
+ @flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value')
+ @flagsaver.flagsaver(flagsaver_test_validated_flag2='new_value')
+ def mutate_flags_separately():
+ return (FLAGS.flagsaver_test_validated_flag1,
+ FLAGS.flagsaver_test_validated_flag2)
+
+ with self.assertRaisesRegex(flags.IllegalFlagValueError,
+ 'Flag validation failed'):
+ mutate_flags_separately()
+
+ # The flags have not changed outside the context of the exception.
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
def test_save_flag_value(self):
# First save the flag values.
saved_flag_values = flagsaver.save_flag_values()