aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYilei Yang <yileiyang@google.com>2023-01-04 11:09:44 -0800
committerYilei Yang <yileiyang@google.com>2023-01-04 11:09:44 -0800
commit49ea060b60be905605a84cfa0be047144944444f (patch)
treeaa6e6953292a4149391e7795210195dfa00e02d5
parent1f1c14f1d14a5392a1930e93c81261f77827e0ca (diff)
parentfd32fea9dac2f3faa3516d4f9dca91625d886e90 (diff)
downloadabsl-py-49ea060b60be905605a84cfa0be047144944444f.tar.gz
Merge commit for internal changes.
-rw-r--r--CHANGELOG.md3
-rw-r--r--absl/flags/_flagvalues.py2
-rw-r--r--absl/testing/BUILD1
-rw-r--r--absl/testing/flagsaver.py190
-rw-r--r--absl/testing/tests/flagsaver_test.py681
5 files changed, 610 insertions, 267 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 56f832b..9ee4926 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
## Unreleased
### Changed
+* (testing) Added `@flagsaver.as_parsed`: this allows saving/restoring flags
+ using string values as if parsed from the command line and will also reflect
+ other flag states after command line parsing, e.g. `.present` is set.
* (logging) If no log dir is specified `logging.find_log_dir()` now falls back
to `tempfile.gettempdir()` instead of `/tmp/`.
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index 6661b78..fd0e631 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -792,8 +792,10 @@ class FlagValues:
continue
if flag is not None:
+ # LINT.IfChange
flag.parse(value)
flag.using_default_value = False
+ # LINT.ThenChange(../testing/flagsaver.py:flag_override_parsing)
else:
unparsed_names_and_args.append((name, arg))
diff --git a/absl/testing/BUILD b/absl/testing/BUILD
index d428792..3173c4b 100644
--- a/absl/testing/BUILD
+++ b/absl/testing/BUILD
@@ -212,6 +212,7 @@ py_test(
deps = [
":absltest",
":flagsaver",
+ ":parameterized",
"//absl/flags",
],
)
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py
index 774c698..e96c8c5 100644
--- a/absl/testing/flagsaver.py
+++ b/absl/testing/flagsaver.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
"""Decorator and context manager for saving and restoring flag values.
There are many ways to save and restore. Always use the most convenient method
@@ -49,6 +50,36 @@ Here are examples of each method. They all call ``do_stuff()`` while
finally:
flagsaver.restore_flag_values(saved_flag_values)
+ # Use the parsing version to emulate users providing the flags.
+ # Note that all flags must be provided as strings (unparsed).
+ @flagsaver.as_parsed(some_int_flag='123')
+ def some_func():
+ # Because the flag was parsed it is considered "present".
+ assert FLAGS.some_int_flag.present
+ do_stuff()
+
+ # flagsaver.as_parsed() can also be used as a context manager just like
+ # flagsaver.flagsaver()
+ with flagsaver.as_parsed(some_int_flag='123'):
+ do_stuff()
+
+ # The flagsaver.as_parsed() interface also supports FlagHolder objects.
+ @flagsaver.as_parsed((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, '23'))
+ def some_func():
+ do_stuff()
+
+ # Using as_parsed with a multi_X flag requires a sequence of strings.
+ @flagsaver.as_parsed(some_multi_int_flag=['123', '456'])
+ def some_func():
+ assert FLAGS.some_multi_int_flag.present
+ do_stuff()
+
+ # If a flag name includes non-identifier characters it can be specified like
+ # so:
+ @flagsaver.as_parsed(**{'i-like-dashes': 'true'})
+ def some_func():
+ do_stuff()
+
We save and restore a shallow copy of each Flag object's ``__dict__`` attribute.
This preserves all attributes of the flag, such as whether or not it was
overridden from its default value.
@@ -58,14 +89,16 @@ exception will be raised. However if you *add* a flag after saving flag values,
and then restore flag values, the added flag will be deleted with no errors.
"""
+import collections
import functools
import inspect
-from typing import overload, Any, Callable, Mapping, Tuple, TypeVar
+from typing import overload, Any, Callable, Mapping, Tuple, TypeVar, Type, Sequence, Union
from absl import flags
FLAGS = flags.FLAGS
+
# The type of pre/post wrapped functions.
_CallableT = TypeVar('_CallableT', bound=Callable)
@@ -83,8 +116,86 @@ def flagsaver(func: _CallableT) -> _CallableT:
def flagsaver(*args, **kwargs):
"""The main flagsaver interface. See module doc for usage."""
+ return _construct_overrider(_FlagOverrider, *args, **kwargs)
+
+
+@overload
+def as_parsed(*args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
+ **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
+ ...
+
+
+@overload
+def as_parsed(func: _CallableT) -> _CallableT:
+ ...
+
+
+def as_parsed(*args, **kwargs):
+ """Overrides flags by parsing strings, saves flag state similar to flagsaver.
+
+ This function can be used as either a decorator or context manager similar to
+ flagsaver.flagsaver(). However, where flagsaver.flagsaver() directly sets the
+ flags to new values, this function will parse the provided arguments as if
+ they were provided on the command line. Among other things, this will cause
+ `FLAGS['flag_name'].parsed == True`.
+
+ A note on unparsed input: For many flag types, the unparsed version will be
+ a single string. However for multi_x (multi_string, multi_integer, multi_enum)
+ the unparsed version will be a Sequence of strings.
+
+ Args:
+ *args: Tuples of FlagHolders and their unparsed value.
+ **kwargs: The keyword args are flag names, and the values are unparsed
+ values.
+
+ Returns:
+ _ParsingFlagOverrider that serves as a context manager or decorator. Will
+ save previous flag state and parse new flags, then on cleanup it will
+ restore the previous flag state.
+ """
+ return _construct_overrider(_ParsingFlagOverrider, *args, **kwargs)
+
+
+# NOTE: the order of these overload declarations matters. The type checker will
+# pick the first match which could be incorrect.
+@overload
+def _construct_overrider(
+ flag_overrider_cls: Type['_ParsingFlagOverrider'],
+ *args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
+ **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
+ ...
+
+
+@overload
+def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
+ *args: Tuple[flags.FlagHolder, Any],
+ **kwargs: Any) -> '_FlagOverrider':
+ ...
+
+
+@overload
+def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
+ func: _CallableT) -> _CallableT:
+ ...
+
+
+def _construct_overrider(flag_overrider_cls, *args, **kwargs):
+ """Handles the args/kwargs returning an instance of flag_overrider_cls.
+
+ If flag_overrider_cls is _FlagOverrider then values should be native python
+ types matching the python types. Otherwise if flag_overrider_cls is
+ _ParsingFlagOverrider the values should be strings or sequences of strings.
+
+ Args:
+ flag_overrider_cls: The class that will do the overriding.
+ *args: Tuples of FlagHolder and the new flag value.
+ **kwargs: Keword args mapping flag name to new flag value.
+
+ Returns:
+ A _FlagOverrider to be used as a decorator or context manager.
+ """
if not args:
- return _FlagOverrider(**kwargs)
+ return flag_overrider_cls(**kwargs)
# args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)`
if len(args) == 1 and callable(args[0]):
if kwargs:
@@ -93,7 +204,7 @@ def flagsaver(*args, **kwargs):
func = args[0]
if inspect.isclass(func):
raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
- return _wrap(func, {})
+ return _wrap(flag_overrider_cls, func, {})
# args can be a list of (FlagHolder, value) pairs.
# In which case they augment any specified kwargs.
for arg in args:
@@ -105,7 +216,7 @@ def flagsaver(*args, **kwargs):
if holder.name in kwargs:
raise ValueError('Cannot set --%s multiple times' % holder.name)
kwargs[holder.name] = value
- return _FlagOverrider(**kwargs)
+ return flag_overrider_cls(**kwargs)
def save_flag_values(
@@ -144,13 +255,27 @@ def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]],
flag_values[name].__dict__ = saved
-def _wrap(func: _CallableT, overrides: Mapping[str, Any]) -> _CallableT:
+@overload
+def _wrap(flag_overrider_cls: Type['_FlagOverrider'], func: _CallableT,
+ overrides: Mapping[str, Any]) -> _CallableT:
+ ...
+
+
+@overload
+def _wrap(flag_overrider_cls: Type['_ParsingFlagOverrider'], func: _CallableT,
+ overrides: Mapping[str, Union[str, Sequence[str]]]) -> _CallableT:
+ ...
+
+
+def _wrap(flag_overrider_cls, func, overrides):
"""Creates a wrapper function that saves/restores flag values.
Args:
+ flag_overrider_cls: The class that will be used as a context manager.
func: This will be called between saving flags and restoring flags.
overrides: Flag names mapped to their values. These flags will be set after
- saving the original flag state.
+ saving the original flag state. The type of the values depends on if
+ _FlagOverrider or _ParsingFlagOverrider was specified.
Returns:
A wrapped version of func.
@@ -159,7 +284,7 @@ def _wrap(func: _CallableT, overrides: Mapping[str, Any]) -> _CallableT:
@functools.wraps(func)
def _flagsaver_wrapper(*args, **kwargs):
"""Wrapper function that saves and restores flags."""
- with _FlagOverrider(**overrides):
+ with flag_overrider_cls(**overrides):
return func(*args, **kwargs)
return _flagsaver_wrapper
@@ -179,7 +304,7 @@ class _FlagOverrider(object):
def __call__(self, func: _CallableT) -> _CallableT:
if inspect.isclass(func):
raise TypeError('flagsaver cannot be applied to a class.')
- return _wrap(func, self._overrides)
+ return _wrap(self.__class__, func, self._overrides)
def __enter__(self):
self._saved_flag_values = save_flag_values(FLAGS)
@@ -194,6 +319,55 @@ class _FlagOverrider(object):
restore_flag_values(self._saved_flag_values, FLAGS)
+class _ParsingFlagOverrider(_FlagOverrider):
+ """Context manager for overriding flags.
+
+ Simulates command line parsing.
+
+ This is simlar to _FlagOverrider except that all **overrides should be
+ strings or sequences of strings, and when context is entered this class calls
+ .parse(value)
+
+ This results in the flags having .present set properly.
+ """
+
+ def __init__(self, **overrides: Union[str, Sequence[str]]):
+ for flag_name, new_value in overrides.items():
+ if isinstance(new_value, str):
+ continue
+ if (isinstance(new_value, collections.abc.Sequence) and
+ all(isinstance(single_value, str) for single_value in new_value)):
+ continue
+ raise TypeError(
+ f'flagsaver.as_parsed() cannot parse {flag_name}. Expected a single '
+ f'string or sequence of strings but {type(new_value)} was provided.')
+ super().__init__(**overrides)
+
+ def __enter__(self):
+ self._saved_flag_values = save_flag_values(FLAGS)
+ try:
+ for flag_name, unparsed_value in self._overrides.items():
+ # LINT.IfChange(flag_override_parsing)
+ FLAGS[flag_name].parse(unparsed_value)
+ FLAGS[flag_name].using_default_value = False
+ # LINT.ThenChange()
+
+ # Perform the validation on all modified flags. This is something that
+ # FLAGS._set_attributes() does for you in _FlagOverrider.
+ for flag_name in self._overrides:
+ FLAGS._assert_validators(FLAGS[flag_name].validators)
+
+ except KeyError as e:
+ # If a flag doesn't exist, an UnrecognizedFlagError is more specific.
+ restore_flag_values(self._saved_flag_values, FLAGS)
+ raise flags.UnrecognizedFlagError('Unknown command line flag.') from e
+
+ except:
+ # It may fail because of flag validators or general parsing issues.
+ restore_flag_values(self._saved_flag_values, FLAGS)
+ raise
+
+
def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]:
"""Returns a copy of the flag object's ``__dict__``.
diff --git a/absl/testing/tests/flagsaver_test.py b/absl/testing/tests/flagsaver_test.py
index e98cd06..b8f91a5 100644
--- a/absl/testing/tests/flagsaver_test.py
+++ b/absl/testing/tests/flagsaver_test.py
@@ -16,6 +16,7 @@
from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
+from absl.testing import parameterized
flags.DEFINE_string('flagsaver_test_flag0', 'unchanged0', 'flag to test with')
flags.DEFINE_string('flagsaver_test_flag1', 'unchanged1', 'flag to test with')
@@ -31,6 +32,9 @@ INT_FLAG = flags.DEFINE_integer(
STR_FLAG = flags.DEFINE_string(
'flagsaver_test_str_flag', default='str default', help='help')
+MULTI_INT_FLAG = flags.DEFINE_multi_integer('flagsaver_test_multi_int_flag',
+ None, 'flag to test with')
+
@flags.multi_flags_validator(
('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
@@ -51,194 +55,83 @@ class _TestError(Exception):
"""Exception class for use in these tests."""
-class FlagSaverTest(absltest.TestCase):
+class CommonUsageTest(absltest.TestCase):
+ """These test cases cover the most common usages of flagsaver."""
- def test_context_manager_without_parameters(self):
- with flagsaver.flagsaver():
- FLAGS.flagsaver_test_flag0 = 'new value'
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-
- def test_context_manager_with_overrides(self):
- with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
- self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag1 = 'another value'
+ def test_as_parsed_context_manager(self):
+ # Precondition check, we expect all the flags to start as their default.
+ self.assertEqual('str default', STR_FLAG.value)
+ self.assertFalse(STR_FLAG.present)
+ self.assertEqual(1, INT_FLAG.value)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
- def test_context_manager_with_flagholders(self):
- with flagsaver.flagsaver((INT_FLAG, 3), (STR_FLAG, 'new value')):
- self.assertEqual('new value', STR_FLAG.value)
- self.assertEqual(3, INT_FLAG.value)
- FLAGS.flagsaver_test_flag1 = 'another value'
- self.assertEqual(INT_FLAG.value, INT_FLAG.default)
- self.assertEqual(STR_FLAG.value, STR_FLAG.default)
- self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
-
- def test_context_manager_with_overrides_and_flagholders(self):
- with flagsaver.flagsaver((INT_FLAG, 3), flagsaver_test_flag0='new value'):
- self.assertEqual(STR_FLAG.default, STR_FLAG.value)
- self.assertEqual(3, INT_FLAG.value)
- FLAGS.flagsaver_test_flag0 = 'new value'
- self.assertEqual(INT_FLAG.value, INT_FLAG.default)
- self.assertEqual(STR_FLAG.value, STR_FLAG.default)
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-
- def test_context_manager_with_cross_validated_overrides_set_together(self):
- # When the flags are set in the same flagsaver call their validators will
- # be triggered only once the setting is done.
- with flagsaver.flagsaver(
- flagsaver_test_validated_flag1='new_value',
- flagsaver_test_validated_flag2='new_value'):
- self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
- self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
-
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_context_manager_with_cross_validated_overrides_set_badly(self):
-
- # Different values should violate the validator.
- with self.assertRaisesRegex(flags.IllegalFlagValueError,
- 'Flag validation failed'):
- with flagsaver.flagsaver(
- flagsaver_test_validated_flag1='new_value',
- flagsaver_test_validated_flag2='other_value'):
- pass
-
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_context_manager_with_cross_validated_overrides_set_separately(self):
-
- # Setting just one flag will trip the validator as well.
- with self.assertRaisesRegex(flags.IllegalFlagValueError,
- 'Flag validation failed'):
- with flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value'):
- pass
-
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_context_manager_with_exception(self):
- with self.assertRaises(_TestError):
- with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
- self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag1 = 'another value'
- raise _TestError('oops')
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
-
- def test_context_manager_with_validation_exception(self):
- with self.assertRaises(flags.IllegalFlagValueError):
- with flagsaver.flagsaver(
- flagsaver_test_flag0='new value',
- flagsaver_test_validated_flag='new value'):
- pass
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag)
-
- def test_decorator_without_call(self):
-
- @flagsaver.flagsaver
- def mutate_flags(value):
- """Test function that mutates a flag."""
- # The undecorated method mutates --flagsaver_test_flag0 to the given value
- # and then returns the value of that flag. If the @flagsaver.flagsaver
- # decorator works as designed, then this mutation will be reverted after
- # this method returns.
- FLAGS.flagsaver_test_flag0 = value
- return FLAGS.flagsaver_test_flag0
-
- # mutate_flags returns the flag value before it gets restored by
- # the flagsaver decorator. So we check that flag value was
- # actually changed in the method's scope.
- self.assertEqual('new value', mutate_flags('new value'))
- # But... notice that the flag is now unchanged0.
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-
- def test_decorator_without_parameters(self):
-
- @flagsaver.flagsaver()
- def mutate_flags(value):
- FLAGS.flagsaver_test_flag0 = value
- return FLAGS.flagsaver_test_flag0
-
- self.assertEqual('new value', mutate_flags('new value'))
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-
- def test_decorator_with_overrides(self):
+ # Flagsaver will also save the state of flags that have been modified.
+ FLAGS.flagsaver_test_flag1 = 'outside flagsaver'
+
+ # Save all existing flag state, and set some flags as if they were parsed on
+ # the command line. Because of this, the new values must be provided as str,
+ # even if the flag type is something other than string.
+ with flagsaver.as_parsed(
+ (STR_FLAG, 'new string value'), # Override using flagholder object.
+ (INT_FLAG, '123'), # Override an int flag (NOTE: must specify as str).
+ flagsaver_test_flag0='new value', # Override using flag name.
+ ):
+ # All the flags have their overridden values.
+ self.assertEqual('new string value', STR_FLAG.value)
+ self.assertTrue(STR_FLAG.present)
+ self.assertEqual(123, INT_FLAG.value)
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+ # Even if we change other flags, they will reset on context exit.
+ FLAGS.flagsaver_test_flag1 = 'new value 1'
- @flagsaver.flagsaver(flagsaver_test_flag0='new value')
- def mutate_flags():
- """Test function expecting new value."""
- # If the @flagsaver.decorator decorator works as designed,
- # then the value of the flag should be changed in the scope of
- # the method but the change will be reverted after this method
- # returns.
- return FLAGS.flagsaver_test_flag0
-
- # mutate_flags returns the flag value before it gets restored by
- # the flagsaver decorator. So we check that flag value was
- # actually changed in the method's scope.
- self.assertEqual('new value', mutate_flags())
- # But... notice that the flag is now unchanged0.
+ # The flags have all reset to their pre-flagsaver values.
+ self.assertEqual('str default', STR_FLAG.value)
+ self.assertFalse(STR_FLAG.present)
+ self.assertEqual(1, INT_FLAG.value)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-
- def test_decorator_with_cross_validated_overrides_set_together(self):
-
- # When the flags are set in the same flagsaver call their validators will
- # be triggered only once the setting is done.
- @flagsaver.flagsaver(
- flagsaver_test_validated_flag1='new_value',
- flagsaver_test_validated_flag2='new_value')
- def mutate_flags_together():
- return (FLAGS.flagsaver_test_validated_flag1,
- FLAGS.flagsaver_test_validated_flag2)
-
- self.assertEqual(('new_value', 'new_value'), mutate_flags_together())
-
- # The flags have not changed outside the context of the function.
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_decorator_with_cross_validated_overrides_set_badly(self):
-
- # Different values should violate the validator.
- @flagsaver.flagsaver(
- flagsaver_test_validated_flag1='new_value',
- flagsaver_test_validated_flag2='other_value')
- def mutate_flags_together_badly():
- return (FLAGS.flagsaver_test_validated_flag1,
- FLAGS.flagsaver_test_validated_flag2)
-
- with self.assertRaisesRegex(flags.IllegalFlagValueError,
- 'Flag validation failed'):
- mutate_flags_together_badly()
-
- # The flags have not changed outside the context of the exception.
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_decorator_with_cross_validated_overrides_set_separately(self):
-
- # Setting the flags sequentially and not together will trip the validator,
- # because it will be called at the end of each flagsaver call.
- @flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value')
- @flagsaver.flagsaver(flagsaver_test_validated_flag2='new_value')
- def mutate_flags_separately():
- return (FLAGS.flagsaver_test_validated_flag1,
- FLAGS.flagsaver_test_validated_flag2)
-
- with self.assertRaisesRegex(flags.IllegalFlagValueError,
- 'Flag validation failed'):
- mutate_flags_separately()
-
- # The flags have not changed outside the context of the exception.
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_save_flag_value(self):
+ self.assertEqual('outside flagsaver', FLAGS.flagsaver_test_flag1)
+
+ def test_as_parsed_decorator(self):
+ # flagsaver.as_parsed can also be used as a decorator.
+ @flagsaver.as_parsed((INT_FLAG, '123'))
+ def do_something_with_flags():
+ self.assertEqual(123, INT_FLAG.value)
+ self.assertTrue(INT_FLAG.present)
+
+ do_something_with_flags()
+ self.assertEqual(1, INT_FLAG.value)
+ self.assertFalse(INT_FLAG.present)
+
+ def test_flagsaver_flagsaver(self):
+ # If you don't want the flags to go through parsing, you can instead use
+ # flagsaver.flagsaver(). With this method, you provide the native python
+ # value you'd like the flags to take on. Otherwise it functions similar to
+ # flagsaver.as_parsed().
+ @flagsaver.flagsaver((INT_FLAG, 345))
+ def do_something_with_flags():
+ self.assertEqual(345, INT_FLAG.value)
+ # Note that because this flag was never parsed, it will not register as
+ # .present unless you manually set that attribute.
+ self.assertFalse(INT_FLAG.present)
+ # If you do chose to modify things about the flag (such as .present) those
+ # changes will still be cleaned up when flagsaver.flagsaver() exits.
+ INT_FLAG.present = True
+
+ self.assertEqual(1, INT_FLAG.value)
+ # flagsaver.flagsaver() restored INT_FLAG.present to the state it was in
+ # before entering the context.
+ self.assertFalse(INT_FLAG.present)
+
+
+class SaveFlagValuesTest(absltest.TestCase):
+ """Test flagsaver.save_flag_values() and flagsaver.restore_flag_values().
+
+ In this test, we insure that *all* properties of flags get restored. In other
+ tests we only try changing the flag value.
+ """
+
+ def test_assign_value(self):
# First save the flag values.
saved_flag_values = flagsaver.save_flag_values()
@@ -250,7 +143,7 @@ class FlagSaverTest(absltest.TestCase):
flagsaver.restore_flag_values(saved_flag_values)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- def test_save_flag_default(self):
+ def test_set_default(self):
# First save the flag.
saved_flag_values = flagsaver.save_flag_values()
@@ -262,7 +155,7 @@ class FlagSaverTest(absltest.TestCase):
flagsaver.restore_flag_values(saved_flag_values)
self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].default)
- def test_restore_after_parse(self):
+ def test_parse(self):
# First save the flag.
saved_flag_values = flagsaver.save_flag_values()
@@ -278,9 +171,72 @@ class FlagSaverTest(absltest.TestCase):
self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].value)
self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present)
- def test_decorator_with_exception(self):
+ def test_assign_validators(self):
+ # First save the flag.
+ saved_flag_values = flagsaver.save_flag_values()
+
+ # Sanity check that a validator already exists.
+ self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1)
+ original_validators = list(FLAGS['flagsaver_test_flag0'].validators)
+
+ def no_space(value):
+ return ' ' not in value
+
+ # Add a new validator.
+ flags.register_validator('flagsaver_test_flag0', no_space)
+ self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2)
+
+ # Now restore the flag to its original value.
+ flagsaver.restore_flag_values(saved_flag_values)
+ self.assertEqual(
+ original_validators, FLAGS['flagsaver_test_flag0'].validators
+ )
+
+
+@parameterized.named_parameters(
+ dict(
+ testcase_name='flagsaver.flagsaver',
+ flagsaver_method=flagsaver.flagsaver,
+ ),
+ dict(
+ testcase_name='flagsaver.as_parsed',
+ flagsaver_method=flagsaver.as_parsed,
+ ),
+)
+class NoOverridesTest(parameterized.TestCase):
+ """Test flagsaver.flagsaver and flagsaver.as_parsed without overrides."""
+
+ def test_context_manager_with_call(self, flagsaver_method):
+ with flagsaver_method():
+ FLAGS.flagsaver_test_flag0 = 'new value'
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_context_manager_with_exception(self, flagsaver_method):
+ with self.assertRaises(_TestError):
+ with flagsaver_method():
+ FLAGS.flagsaver_test_flag0 = 'new value'
+ # Simulate a failed test.
+ raise _TestError('something happened')
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_decorator_without_call(self, flagsaver_method):
+ @flagsaver_method
+ def mutate_flags():
+ FLAGS.flagsaver_test_flag0 = 'new value'
+
+ mutate_flags()
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_decorator_with_call(self, flagsaver_method):
+ @flagsaver_method()
+ def mutate_flags():
+ FLAGS.flagsaver_test_flag0 = 'new value'
- @flagsaver.flagsaver
+ mutate_flags()
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_decorator_with_exception(self, flagsaver_method):
+ @flagsaver_method()
def raise_exception():
FLAGS.flagsaver_test_flag0 = 'new value'
# Simulate a failed test.
@@ -290,62 +246,262 @@ class FlagSaverTest(absltest.TestCase):
self.assertRaises(_TestError, raise_exception)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- def test_validator_list_is_restored(self):
- self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1)
- original_validators = list(FLAGS['flagsaver_test_flag0'].validators)
+@parameterized.named_parameters(
+ dict(
+ testcase_name='flagsaver.flagsaver',
+ flagsaver_method=flagsaver.flagsaver,
+ ),
+ dict(
+ testcase_name='flagsaver.as_parsed',
+ flagsaver_method=flagsaver.as_parsed,
+ ),
+)
+class TestStringFlagOverrides(parameterized.TestCase):
+ """Test flagsaver.flagsaver and flagsaver.as_parsed with string overrides.
+
+ Note that these tests can be parameterized because both .flagsaver and
+ .as_parsed expect a str input when overriding a string flag. For non-string
+ flags these two flagsaver methods have separate tests elsewhere in this file.
+
+ Each test is one class of overrides, executed twice. Once as a context
+ manager, and once as a decorator on a mutate_flags() method.
+ """
+
+ def test_keyword_overrides(self, flagsaver_method):
+ # Context manager:
+ with flagsaver_method(flagsaver_test_flag0='new value'):
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- @flagsaver.flagsaver
- def modify_validators():
+ # Decorator:
+ @flagsaver_method(flagsaver_test_flag0='new value')
+ def mutate_flags():
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
- def no_space(value):
- return ' ' not in value
+ mutate_flags()
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- flags.register_validator('flagsaver_test_flag0', no_space)
- self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2)
+ def test_flagholder_overrides(self, flagsaver_method):
+ with flagsaver_method((STR_FLAG, 'new value')):
+ self.assertEqual('new value', STR_FLAG.value)
+ self.assertEqual('str default', STR_FLAG.value)
- modify_validators()
- self.assertEqual(original_validators,
- FLAGS['flagsaver_test_flag0'].validators)
+ @flagsaver_method((STR_FLAG, 'new value'))
+ def mutate_flags():
+ self.assertEqual('new value', STR_FLAG.value)
+ mutate_flags()
+ self.assertEqual('str default', STR_FLAG.value)
-class FlagSaverDecoratorUsageTest(absltest.TestCase):
+ def test_keyword_and_flagholder_overrides(self, flagsaver_method):
+ with flagsaver_method(
+ (STR_FLAG, 'another value'), flagsaver_test_flag0='new value'
+ ):
+ self.assertEqual('another value', STR_FLAG.value)
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+ self.assertEqual('str default', STR_FLAG.value)
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- @flagsaver.flagsaver
- def test_mutate1(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
+ @flagsaver_method(
+ (STR_FLAG, 'another value'), flagsaver_test_flag0='new value'
+ )
+ def mutate_flags():
+ self.assertEqual('another value', STR_FLAG.value)
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+
+ mutate_flags()
+ self.assertEqual('str default', STR_FLAG.value)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
- @flagsaver.flagsaver
- def test_mutate2(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
+ def test_cross_validated_overrides_set_together(self, flagsaver_method):
+ # When the flags are set in the same flagsaver call their validators will
+ # be triggered only once the setting is done.
+ with flagsaver_method(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='new_value',
+ ):
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ @flagsaver_method(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='new_value',
+ )
+ def mutate_flags():
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
+
+ mutate_flags()
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_cross_validated_overrides_set_badly(self, flagsaver_method):
+ # Different values should violate the validator.
+ with self.assertRaisesRegex(
+ flags.IllegalFlagValueError, 'Flag validation failed'
+ ):
+ with flagsaver_method(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='other_value',
+ ):
+ pass
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ @flagsaver_method(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='other_value',
+ )
+ def mutate_flags():
+ pass
+
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError, 'Flag validation failed', mutate_flags
+ )
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_cross_validated_overrides_set_separately(self, flagsaver_method):
+ # Setting just one flag will trip the validator as well.
+ with self.assertRaisesRegex(
+ flags.IllegalFlagValueError, 'Flag validation failed'
+ ):
+ with flagsaver_method(flagsaver_test_validated_flag1='new_value'):
+ pass
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ @flagsaver_method(flagsaver_test_validated_flag1='new_value')
+ def mutate_flags():
+ pass
+
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError, 'Flag validation failed', mutate_flags
+ )
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_validation_exception(self, flagsaver_method):
+ with self.assertRaises(flags.IllegalFlagValueError):
+ with flagsaver_method(
+ flagsaver_test_flag0='new value',
+ flagsaver_test_validated_flag='new value',
+ ):
+ pass
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag)
- @flagsaver.flagsaver
- def test_mutate3(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
+ @flagsaver_method(
+ flagsaver_test_flag0='new value',
+ flagsaver_test_validated_flag='new value',
+ )
+ def mutate_flags():
+ pass
+
+ self.assertRaises(flags.IllegalFlagValueError, mutate_flags)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag)
- @flagsaver.flagsaver
- def test_mutate4(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
+ def test_unknown_flag_raises_exception(self, flagsaver_method):
+ self.assertNotIn('this_flag_does_not_exist', FLAGS)
+
+ # Flagsaver raises an error when trying to override a non-existent flag.
+ with self.assertRaises(flags.UnrecognizedFlagError):
+ with flagsaver_method(
+ flagsaver_test_flag0='new value', this_flag_does_not_exist='new value'
+ ):
+ pass
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
+ @flagsaver_method(
+ flagsaver_test_flag0='new value', this_flag_does_not_exist='new value'
+ )
+ def mutate_flags():
+ pass
+
+ self.assertRaises(flags.UnrecognizedFlagError, mutate_flags)
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-class FlagSaverSetUpTearDownUsageTest(absltest.TestCase):
+ # Make sure flagsaver didn't create the flag at any point.
+ self.assertNotIn('this_flag_does_not_exist', FLAGS)
+
+
+class AsParsedTest(absltest.TestCase):
+
+ def test_parse_context_manager_sets_present_and_using_default(self):
+ self.assertFalse(INT_FLAG.present)
+ self.assertFalse(STR_FLAG.present)
+ # Note that .using_default_value isn't available on the FlagHolder directly.
+ self.assertTrue(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertTrue(FLAGS[STR_FLAG.name].using_default_value)
+
+ with flagsaver.as_parsed((INT_FLAG, '123'),
+ flagsaver_test_str_flag='new value'):
+ self.assertTrue(INT_FLAG.present)
+ self.assertTrue(STR_FLAG.present)
+ self.assertFalse(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertFalse(FLAGS[STR_FLAG.name].using_default_value)
+
+ self.assertFalse(INT_FLAG.present)
+ self.assertFalse(STR_FLAG.present)
+ self.assertTrue(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertTrue(FLAGS[STR_FLAG.name].using_default_value)
+
+ def test_parse_decorator_sets_present_and_using_default(self):
+ self.assertFalse(INT_FLAG.present)
+ self.assertFalse(STR_FLAG.present)
+ # Note that .using_default_value isn't available on the FlagHolder directly.
+ self.assertTrue(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertTrue(FLAGS[STR_FLAG.name].using_default_value)
+
+ @flagsaver.as_parsed((INT_FLAG, '123'), flagsaver_test_str_flag='new value')
+ def some_func():
+ self.assertTrue(INT_FLAG.present)
+ self.assertTrue(STR_FLAG.present)
+ self.assertFalse(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertFalse(FLAGS[STR_FLAG.name].using_default_value)
+
+ some_func()
+ self.assertFalse(INT_FLAG.present)
+ self.assertFalse(STR_FLAG.present)
+ self.assertTrue(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertTrue(FLAGS[STR_FLAG.name].using_default_value)
+
+ def test_parse_decorator_with_multi_int_flag(self):
+ self.assertFalse(MULTI_INT_FLAG.present)
+ self.assertIsNone(MULTI_INT_FLAG.value)
+
+ @flagsaver.as_parsed((MULTI_INT_FLAG, ['123', '456']))
+ def assert_flags_updated():
+ self.assertTrue(MULTI_INT_FLAG.present)
+ self.assertCountEqual([123, 456], MULTI_INT_FLAG.value)
+
+ assert_flags_updated()
+ self.assertFalse(MULTI_INT_FLAG.present)
+ self.assertIsNone(MULTI_INT_FLAG.value)
+
+ def test_parse_raises_type_error(self):
+ with self.assertRaisesRegex(
+ TypeError,
+ r'flagsaver\.as_parsed\(\) cannot parse flagsaver_test_int_flag\. '
+ r'Expected a single string or sequence of strings but .*int.* was '
+ r'provided\.'):
+ manager = flagsaver.as_parsed(flagsaver_test_int_flag=123)
+ del manager
+
+
+class SetUpTearDownTest(absltest.TestCase):
+ """Example using a single flagsaver in setUp."""
def setUp(self):
+ super().setUp()
self.saved_flag_values = flagsaver.save_flag_values()
def tearDown(self):
+ super().tearDown()
flagsaver.restore_flag_values(self.saved_flag_values)
def test_mutate1(self):
@@ -360,28 +516,26 @@ class FlagSaverSetUpTearDownUsageTest(absltest.TestCase):
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
FLAGS.flagsaver_test_flag0 = 'changed0'
- def test_mutate3(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
-
- def test_mutate4(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
-
-class FlagSaverBadUsageTest(absltest.TestCase):
- """Tests that certain kinds of improper usages raise errors."""
-
- def test_flag_saver_on_class(self):
+@parameterized.named_parameters(
+ dict(
+ testcase_name='flagsaver.flagsaver',
+ flagsaver_method=flagsaver.flagsaver,
+ ),
+ dict(
+ testcase_name='flagsaver.as_parsed',
+ flagsaver_method=flagsaver.as_parsed,
+ ),
+)
+class BadUsageTest(parameterized.TestCase):
+ """Tests that improper usage (such as decorating a class) raise errors."""
+
+ def test_flag_saver_on_class(self, flagsaver_method):
with self.assertRaises(TypeError):
# WRONG. Don't do this.
# Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
- @flagsaver.flagsaver
+ @flagsaver_method
class FooTest(absltest.TestCase):
def test_tautology(self):
@@ -389,12 +543,12 @@ class FlagSaverBadUsageTest(absltest.TestCase):
del FooTest
- def test_flag_saver_call_on_class(self):
+ def test_flag_saver_call_on_class(self, flagsaver_method):
with self.assertRaises(TypeError):
# WRONG. Don't do this.
# Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
- @flagsaver.flagsaver()
+ @flagsaver_method()
class FooTest(absltest.TestCase):
def test_tautology(self):
@@ -402,12 +556,12 @@ class FlagSaverBadUsageTest(absltest.TestCase):
del FooTest
- def test_flag_saver_with_overrides_on_class(self):
+ def test_flag_saver_with_overrides_on_class(self, flagsaver_method):
with self.assertRaises(TypeError):
# WRONG. Don't do this.
# Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
- @flagsaver.flagsaver(foo='bar')
+ @flagsaver_method(foo='bar')
class FooTest(absltest.TestCase):
def test_tautology(self):
@@ -415,48 +569,57 @@ class FlagSaverBadUsageTest(absltest.TestCase):
del FooTest
- def test_multiple_positional_parameters(self):
+ def test_multiple_positional_parameters(self, flagsaver_method):
with self.assertRaises(ValueError):
func_a = lambda: None
func_b = lambda: None
- flagsaver.flagsaver(func_a, func_b)
+ flagsaver_method(func_a, func_b)
- def test_both_positional_and_keyword_parameters(self):
+ def test_both_positional_and_keyword_parameters(self, flagsaver_method):
with self.assertRaises(ValueError):
func_a = lambda: None
- flagsaver.flagsaver(func_a, flagsaver_test_flag0='new value')
+ flagsaver_method(func_a, flagsaver_test_flag0='new value')
- def test_duplicate_holder_parameters(self):
+ def test_duplicate_holder_parameters(self, flagsaver_method):
with self.assertRaises(ValueError):
- flagsaver.flagsaver((INT_FLAG, 45), (INT_FLAG, 45))
+ flagsaver_method((INT_FLAG, 45), (INT_FLAG, 45))
- def test_duplicate_holder_and_kw_parameter(self):
+ def test_duplicate_holder_and_kw_parameter(self, flagsaver_method):
with self.assertRaises(ValueError):
- flagsaver.flagsaver((INT_FLAG, 45), **{INT_FLAG.name: 45})
+ flagsaver_method((INT_FLAG, 45), **{INT_FLAG.name: 45})
- def test_both_positional_and_holder_parameters(self):
+ def test_both_positional_and_holder_parameters(self, flagsaver_method):
with self.assertRaises(ValueError):
func_a = lambda: None
- flagsaver.flagsaver(func_a, (INT_FLAG, 45))
+ flagsaver_method(func_a, (INT_FLAG, 45))
- def test_holder_parameters_wrong_shape(self):
+ def test_holder_parameters_wrong_shape(self, flagsaver_method):
with self.assertRaises(ValueError):
- flagsaver.flagsaver(INT_FLAG)
+ flagsaver_method(INT_FLAG)
- def test_holder_parameters_tuple_too_long(self):
+ def test_holder_parameters_tuple_too_long(self, flagsaver_method):
with self.assertRaises(ValueError):
# Even if it is a bool flag, it should be a tuple
- flagsaver.flagsaver((INT_FLAG, 4, 5))
+ flagsaver_method((INT_FLAG, 4, 5))
- def test_holder_parameters_tuple_wrong_type(self):
+ def test_holder_parameters_tuple_wrong_type(self, flagsaver_method):
with self.assertRaises(ValueError):
# Even if it is a bool flag, it should be a tuple
- flagsaver.flagsaver((4, INT_FLAG))
+ flagsaver_method((4, INT_FLAG))
- def test_both_wrong_positional_parameters(self):
+ def test_both_wrong_positional_parameters(self, flagsaver_method):
with self.assertRaises(ValueError):
func_a = lambda: None
- flagsaver.flagsaver(func_a, STR_FLAG, '45')
+ flagsaver_method(func_a, STR_FLAG, '45')
+
+ def test_context_manager_no_call(self, flagsaver_method):
+ # The exact exception that's raised appears to be system specific.
+ with self.assertRaises((AttributeError, TypeError)):
+ # Wrong. You must call the flagsaver method before using it as a CM.
+ with flagsaver_method:
+ # We don't expect to get here. A type error should happen when
+ # attempting to enter the context manager.
+ pass
if __name__ == '__main__':