aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYifan Hong <elsk@google.com>2023-12-08 20:05:37 +0000
committerAutomerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>2023-12-08 20:05:37 +0000
commitedf79a69401345d1bad25c758b7ffd6598a3b424 (patch)
treecb823973d78e8ea66a74b7e4942c18c9b7d076f4
parentbd7cdd4d3776c9d08bbfea96d711958b24a9fe9c (diff)
parentadbe1e13891188cf53841ca670d21f5479119973 (diff)
downloadabsl-py-edf79a69401345d1bad25c758b7ffd6598a3b424.tar.gz
Merge tag 'upstream/v1.4.0' into main am: adbe1e1389
Original change: https://android-review.googlesource.com/c/platform/external/python/absl-py/+/2865107 Change-Id: Ie152f87adbe2e84c2525851dbacf32d63133d379 Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r--CHANGELOG.md18
-rw-r--r--METADATA4
-rw-r--r--absl/flags/__init__.pyi3
-rw-r--r--absl/flags/_defines.py10
-rw-r--r--absl/flags/_flag.pyi3
-rw-r--r--absl/flags/_flagvalues.py6
-rw-r--r--absl/flags/tests/flags_test.py11
-rw-r--r--absl/logging/__init__.py7
-rw-r--r--absl/logging/tests/logging_test.py6
-rw-r--r--absl/testing/BUILD1
-rw-r--r--absl/testing/absltest.py7
-rw-r--r--absl/testing/flagsaver.py232
-rw-r--r--absl/testing/tests/flagsaver_test.py681
-rw-r--r--setup.py2
14 files changed, 699 insertions, 292 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ae82a55..c8006e9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,24 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
Nothing notable unreleased.
+## 1.4.0 (2023-01-11)
+
+### New
+
+* (testing) Added `@flagsaver.as_parsed`: this allows saving/restoring flags
+ using string values as if parsed from the command line and will also reflect
+ other flag states after command line parsing, e.g. `.present` is set.
+
+### Changed
+
+* (logging) If no log dir is specified `logging.find_log_dir()` now falls back
+ to `tempfile.gettempdir()` instead of `/tmp/`.
+
+### Fixed
+
+* (flags) Additional kwargs (e.g. `short_name=`) to `DEFINE_multi_enum_class`
+ are now correctly passed to the underlying `Flag` object.
+
## 1.3.0 (2022-10-11)
### Added
diff --git a/METADATA b/METADATA
index 78080bb..1dc4ea0 100644
--- a/METADATA
+++ b/METADATA
@@ -16,8 +16,8 @@ third_party {
version: "v1.3.0"
license_type: NOTICE
last_upgrade_date {
- year: 2022
+ year: 2024
month: 11
- day: 2
+ day: 8
}
}
diff --git a/absl/flags/__init__.pyi b/absl/flags/__init__.pyi
index 4eee59e..7bf6842 100644
--- a/absl/flags/__init__.pyi
+++ b/absl/flags/__init__.pyi
@@ -52,6 +52,9 @@ mark_flags_as_required = _validators.mark_flags_as_required
mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive
mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive
+# Flag modifiers.
+set_default = _defines.set_default
+
# Key flag related functions.
declare_key_flag = _defines.declare_key_flag
adopt_module_key_flags = _defines.adopt_module_key_flags
diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py
index dce53ea..61354e9 100644
--- a/absl/flags/_defines.py
+++ b/absl/flags/_defines.py
@@ -859,11 +859,17 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
"""
return DEFINE_flag(
_flag.MultiEnumClassFlag(
- name, default, help, enum_class, case_sensitive=case_sensitive),
+ name,
+ default,
+ help,
+ enum_class,
+ case_sensitive=case_sensitive,
+ **args,
+ ),
flag_values,
module_name,
required=required,
- **args)
+ )
def DEFINE_alias( # pylint: disable=invalid-name
diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi
index 9b4a3d3..3506644 100644
--- a/absl/flags/_flag.pyi
+++ b/absl/flags/_flag.pyi
@@ -20,7 +20,7 @@ import functools
from absl.flags import _argument_parser
import enum
-from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence
+from typing import Callable, Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence
_T = TypeVar('_T')
_ET = TypeVar('_ET', bound=enum.Enum)
@@ -44,6 +44,7 @@ class Flag(Generic[_T]):
using_default_value = ... # type: bool
allow_overwrite = ... # type: bool
allow_using_method_names = ... # type: bool
+ validators = ... # type: List[Callable[[Any], bool]]
def __init__(self,
parser: _argument_parser.ArgumentParser[_T],
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index 937dc6c..fd0e631 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -411,7 +411,9 @@ class FlagValues:
"""Registers a new flag variable."""
fl = self._flags()
if not isinstance(flag, _flag.Flag):
- raise _exceptions.IllegalFlagValueError(flag)
+ raise _exceptions.IllegalFlagValueError(
+ f'Expect Flag instances, found type {type(flag)}. '
+ "Maybe you didn't mean to use FlagValue.__setitem__?")
if not isinstance(name, str):
raise _exceptions.Error('Flag name must be a string')
if not name:
@@ -790,8 +792,10 @@ class FlagValues:
continue
if flag is not None:
+ # LINT.IfChange
flag.parse(value)
flag.using_default_value = False
+ # LINT.ThenChange(../testing/flagsaver.py:flag_override_parsing)
else:
unparsed_names_and_args.append((name, arg))
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py
index 77ed307..7cacbc8 100644
--- a/absl/flags/tests/flags_test.py
+++ b/absl/flags/tests/flags_test.py
@@ -1591,6 +1591,17 @@ class MultiEnumFlagsTest(absltest.TestCase):
class MultiEnumClassFlagsTest(absltest.TestCase):
+ def test_short_name(self):
+ fv = flags.FlagValues()
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ None,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv,
+ short_name='me')
+ self.assertEqual(fv['fruit'].short_name, 'me')
+
def test_define_results_in_registered_flag_with_none(self):
fv = flags.FlagValues()
enum_defaults = None
diff --git a/absl/logging/__init__.py b/absl/logging/__init__.py
index c0ba4b0..f4e7967 100644
--- a/absl/logging/__init__.py
+++ b/absl/logging/__init__.py
@@ -86,7 +86,9 @@ import os
import socket
import struct
import sys
+import tempfile
import threading
+import tempfile
import time
import timeit
import traceback
@@ -707,6 +709,9 @@ def find_log_dir(log_dir=None):
OSError: raised in Python 2 when it cannot find a log directory.
"""
# Get a list of possible log dirs (will try to use them in order).
+ # NOTE: Google's internal implementation has a special handling for Google
+ # machines, which uses a list of directories. Hence the following uses `dirs`
+ # instead of a single directory.
if log_dir:
# log_dir was explicitly specified as an arg, so use it and it alone.
dirs = [log_dir]
@@ -715,7 +720,7 @@ def find_log_dir(log_dir=None):
# behavior of the same flag in logging.cc).
dirs = [FLAGS['log_dir'].value]
else:
- dirs = ['/tmp/', './']
+ dirs = [tempfile.gettempdir()]
# Find the first usable log dir.
for d in dirs:
diff --git a/absl/logging/tests/logging_test.py b/absl/logging/tests/logging_test.py
index e5c4fcc..1c337f9 100644
--- a/absl/logging/tests/logging_test.py
+++ b/absl/logging/tests/logging_test.py
@@ -706,7 +706,7 @@ class LoggingTest(absltest.TestCase):
os.path.isdir.return_value = True
os.access.return_value = True
log_dir = logging.find_log_dir()
- self.assertEqual('/tmp/', log_dir)
+ self.assertEqual(tempfile.gettempdir(), log_dir)
@flagsaver.flagsaver(log_dir='')
def test_find_log_dir_with_tmp(self):
@@ -714,10 +714,10 @@ class LoggingTest(absltest.TestCase):
mock.patch.object(os.path, 'exists'), \
mock.patch.object(os.path, 'isdir'):
os.path.exists.return_value = False
- os.path.isdir.side_effect = lambda path: path == '/tmp/'
+ os.path.isdir.side_effect = lambda path: path == tempfile.gettempdir()
os.access.return_value = True
log_dir = logging.find_log_dir()
- self.assertEqual('/tmp/', log_dir)
+ self.assertEqual(tempfile.gettempdir(), log_dir)
def test_find_log_dir_with_nothing(self):
with mock.patch.object(os.path, 'exists'), \
diff --git a/absl/testing/BUILD b/absl/testing/BUILD
index d428792..3173c4b 100644
--- a/absl/testing/BUILD
+++ b/absl/testing/BUILD
@@ -212,6 +212,7 @@ py_test(
deps = [
":absltest",
":flagsaver",
+ ":parameterized",
"//absl/flags",
],
)
diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py
index 9071f8f..1bbcee7 100644
--- a/absl/testing/absltest.py
+++ b/absl/testing/absltest.py
@@ -533,7 +533,10 @@ class _TempFile(object):
# currently `Any` to avoid [bad-return-type] errors in the open_* methods.
@contextlib.contextmanager
def _open(
- self, mode: str, encoding: str = 'utf8', errors: str = 'strict'
+ self,
+ mode: str,
+ encoding: Optional[str] = 'utf8',
+ errors: Optional[str] = 'strict',
) -> Iterator[Any]:
with io.open(
self.full_path, mode=mode, encoding=encoding, errors=errors) as fp:
@@ -638,7 +641,7 @@ class TestCase(unittest.TestCase):
self.assertTrue(os.path.exists(expected_paths[1]))
self.assertEqual('foo', out_log.read_text())
- See also: :meth:`create_tempdir` for creating temporary files.
+ See also: :meth:`create_tempfile` for creating temporary files.
Args:
name: Optional name of the directory. If not given, a unique
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py
index 37926d7..e96c8c5 100644
--- a/absl/testing/flagsaver.py
+++ b/absl/testing/flagsaver.py
@@ -50,6 +50,36 @@ Here are examples of each method. They all call ``do_stuff()`` while
finally:
flagsaver.restore_flag_values(saved_flag_values)
+ # Use the parsing version to emulate users providing the flags.
+ # Note that all flags must be provided as strings (unparsed).
+ @flagsaver.as_parsed(some_int_flag='123')
+ def some_func():
+ # Because the flag was parsed it is considered "present".
+ assert FLAGS.some_int_flag.present
+ do_stuff()
+
+ # flagsaver.as_parsed() can also be used as a context manager just like
+ # flagsaver.flagsaver()
+ with flagsaver.as_parsed(some_int_flag='123'):
+ do_stuff()
+
+ # The flagsaver.as_parsed() interface also supports FlagHolder objects.
+ @flagsaver.as_parsed((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, '23'))
+ def some_func():
+ do_stuff()
+
+ # Using as_parsed with a multi_X flag requires a sequence of strings.
+ @flagsaver.as_parsed(some_multi_int_flag=['123', '456'])
+ def some_func():
+ assert FLAGS.some_multi_int_flag.present
+ do_stuff()
+
+ # If a flag name includes non-identifier characters it can be specified like
+ # so:
+ @flagsaver.as_parsed(**{'i-like-dashes': 'true'})
+ def some_func():
+ do_stuff()
+
We save and restore a shallow copy of each Flag object's ``__dict__`` attribute.
This preserves all attributes of the flag, such as whether or not it was
overridden from its default value.
@@ -59,18 +89,113 @@ exception will be raised. However if you *add* a flag after saving flag values,
and then restore flag values, the added flag will be deleted with no errors.
"""
+import collections
import functools
import inspect
+from typing import overload, Any, Callable, Mapping, Tuple, TypeVar, Type, Sequence, Union
from absl import flags
FLAGS = flags.FLAGS
+# The type of pre/post wrapped functions.
+_CallableT = TypeVar('_CallableT', bound=Callable)
+
+
+@overload
+def flagsaver(*args: Tuple[flags.FlagHolder, Any],
+ **kwargs: Any) -> '_FlagOverrider':
+ ...
+
+
+@overload
+def flagsaver(func: _CallableT) -> _CallableT:
+ ...
+
+
def flagsaver(*args, **kwargs):
"""The main flagsaver interface. See module doc for usage."""
+ return _construct_overrider(_FlagOverrider, *args, **kwargs)
+
+
+@overload
+def as_parsed(*args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
+ **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
+ ...
+
+
+@overload
+def as_parsed(func: _CallableT) -> _CallableT:
+ ...
+
+
+def as_parsed(*args, **kwargs):
+ """Overrides flags by parsing strings, saves flag state similar to flagsaver.
+
+ This function can be used as either a decorator or context manager similar to
+ flagsaver.flagsaver(). However, where flagsaver.flagsaver() directly sets the
+ flags to new values, this function will parse the provided arguments as if
+ they were provided on the command line. Among other things, this will cause
+ `FLAGS['flag_name'].parsed == True`.
+
+ A note on unparsed input: For many flag types, the unparsed version will be
+ a single string. However for multi_x (multi_string, multi_integer, multi_enum)
+ the unparsed version will be a Sequence of strings.
+
+ Args:
+ *args: Tuples of FlagHolders and their unparsed value.
+ **kwargs: The keyword args are flag names, and the values are unparsed
+ values.
+
+ Returns:
+ _ParsingFlagOverrider that serves as a context manager or decorator. Will
+ save previous flag state and parse new flags, then on cleanup it will
+ restore the previous flag state.
+ """
+ return _construct_overrider(_ParsingFlagOverrider, *args, **kwargs)
+
+
+# NOTE: the order of these overload declarations matters. The type checker will
+# pick the first match which could be incorrect.
+@overload
+def _construct_overrider(
+ flag_overrider_cls: Type['_ParsingFlagOverrider'],
+ *args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
+ **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
+ ...
+
+
+@overload
+def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
+ *args: Tuple[flags.FlagHolder, Any],
+ **kwargs: Any) -> '_FlagOverrider':
+ ...
+
+
+@overload
+def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
+ func: _CallableT) -> _CallableT:
+ ...
+
+
+def _construct_overrider(flag_overrider_cls, *args, **kwargs):
+ """Handles the args/kwargs returning an instance of flag_overrider_cls.
+
+ If flag_overrider_cls is _FlagOverrider then values should be native python
+ types matching the python types. Otherwise if flag_overrider_cls is
+ _ParsingFlagOverrider the values should be strings or sequences of strings.
+
+ Args:
+ flag_overrider_cls: The class that will do the overriding.
+ *args: Tuples of FlagHolder and the new flag value.
+ **kwargs: Keword args mapping flag name to new flag value.
+
+ Returns:
+ A _FlagOverrider to be used as a decorator or context manager.
+ """
if not args:
- return _FlagOverrider(**kwargs)
+ return flag_overrider_cls(**kwargs)
# args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)`
if len(args) == 1 and callable(args[0]):
if kwargs:
@@ -79,7 +204,7 @@ def flagsaver(*args, **kwargs):
func = args[0]
if inspect.isclass(func):
raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
- return _wrap(func, {})
+ return _wrap(flag_overrider_cls, func, {})
# args can be a list of (FlagHolder, value) pairs.
# In which case they augment any specified kwargs.
for arg in args:
@@ -91,15 +216,17 @@ def flagsaver(*args, **kwargs):
if holder.name in kwargs:
raise ValueError('Cannot set --%s multiple times' % holder.name)
kwargs[holder.name] = value
- return _FlagOverrider(**kwargs)
+ return flag_overrider_cls(**kwargs)
-def save_flag_values(flag_values=FLAGS):
+def save_flag_values(
+ flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]:
"""Returns copy of flag values as a dict.
Args:
- flag_values: FlagValues, the FlagValues instance with which the flag will
- be saved. This should almost never need to be overridden.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ saved. This should almost never need to be overridden.
+
Returns:
Dictionary mapping keys to values. Keys are flag names, values are
corresponding ``__dict__`` members. E.g. ``{'key': value_dict, ...}``.
@@ -107,13 +234,14 @@ def save_flag_values(flag_values=FLAGS):
return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}
-def restore_flag_values(saved_flag_values, flag_values=FLAGS):
+def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]],
+ flag_values: flags.FlagValues = FLAGS):
"""Restores flag values based on the dictionary of flag values.
Args:
saved_flag_values: {'flag_name': value_dict, ...}
- flag_values: FlagValues, the FlagValues instance from which the flag will
- be restored. This should almost never need to be overridden.
+ flag_values: FlagValues, the FlagValues instance from which the flag will be
+ restored. This should almost never need to be overridden.
"""
new_flag_names = list(flag_values)
for name in new_flag_names:
@@ -127,23 +255,38 @@ def restore_flag_values(saved_flag_values, flag_values=FLAGS):
flag_values[name].__dict__ = saved
-def _wrap(func, overrides):
+@overload
+def _wrap(flag_overrider_cls: Type['_FlagOverrider'], func: _CallableT,
+ overrides: Mapping[str, Any]) -> _CallableT:
+ ...
+
+
+@overload
+def _wrap(flag_overrider_cls: Type['_ParsingFlagOverrider'], func: _CallableT,
+ overrides: Mapping[str, Union[str, Sequence[str]]]) -> _CallableT:
+ ...
+
+
+def _wrap(flag_overrider_cls, func, overrides):
"""Creates a wrapper function that saves/restores flag values.
Args:
- func: function object - This will be called between saving flags and
- restoring flags.
- overrides: {str: object} - Flag names mapped to their values. These flags
- will be set after saving the original flag state.
+ flag_overrider_cls: The class that will be used as a context manager.
+ func: This will be called between saving flags and restoring flags.
+ overrides: Flag names mapped to their values. These flags will be set after
+ saving the original flag state. The type of the values depends on if
+ _FlagOverrider or _ParsingFlagOverrider was specified.
Returns:
- return value from func()
+ A wrapped version of func.
"""
+
@functools.wraps(func)
def _flagsaver_wrapper(*args, **kwargs):
"""Wrapper function that saves and restores flags."""
- with _FlagOverrider(**overrides):
+ with flag_overrider_cls(**overrides):
return func(*args, **kwargs)
+
return _flagsaver_wrapper
@@ -154,14 +297,14 @@ class _FlagOverrider(object):
completes.
"""
- def __init__(self, **overrides):
+ def __init__(self, **overrides: Any):
self._overrides = overrides
self._saved_flag_values = None
- def __call__(self, func):
+ def __call__(self, func: _CallableT) -> _CallableT:
if inspect.isclass(func):
raise TypeError('flagsaver cannot be applied to a class.')
- return _wrap(func, self._overrides)
+ return _wrap(self.__class__, func, self._overrides)
def __enter__(self):
self._saved_flag_values = save_flag_values(FLAGS)
@@ -176,7 +319,56 @@ class _FlagOverrider(object):
restore_flag_values(self._saved_flag_values, FLAGS)
-def _copy_flag_dict(flag):
+class _ParsingFlagOverrider(_FlagOverrider):
+ """Context manager for overriding flags.
+
+ Simulates command line parsing.
+
+ This is simlar to _FlagOverrider except that all **overrides should be
+ strings or sequences of strings, and when context is entered this class calls
+ .parse(value)
+
+ This results in the flags having .present set properly.
+ """
+
+ def __init__(self, **overrides: Union[str, Sequence[str]]):
+ for flag_name, new_value in overrides.items():
+ if isinstance(new_value, str):
+ continue
+ if (isinstance(new_value, collections.abc.Sequence) and
+ all(isinstance(single_value, str) for single_value in new_value)):
+ continue
+ raise TypeError(
+ f'flagsaver.as_parsed() cannot parse {flag_name}. Expected a single '
+ f'string or sequence of strings but {type(new_value)} was provided.')
+ super().__init__(**overrides)
+
+ def __enter__(self):
+ self._saved_flag_values = save_flag_values(FLAGS)
+ try:
+ for flag_name, unparsed_value in self._overrides.items():
+ # LINT.IfChange(flag_override_parsing)
+ FLAGS[flag_name].parse(unparsed_value)
+ FLAGS[flag_name].using_default_value = False
+ # LINT.ThenChange()
+
+ # Perform the validation on all modified flags. This is something that
+ # FLAGS._set_attributes() does for you in _FlagOverrider.
+ for flag_name in self._overrides:
+ FLAGS._assert_validators(FLAGS[flag_name].validators)
+
+ except KeyError as e:
+ # If a flag doesn't exist, an UnrecognizedFlagError is more specific.
+ restore_flag_values(self._saved_flag_values, FLAGS)
+ raise flags.UnrecognizedFlagError('Unknown command line flag.') from e
+
+ except:
+ # It may fail because of flag validators or general parsing issues.
+ restore_flag_values(self._saved_flag_values, FLAGS)
+ raise
+
+
+def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]:
"""Returns a copy of the flag object's ``__dict__``.
It's mostly a shallow copy of the ``__dict__``, except it also does a shallow
diff --git a/absl/testing/tests/flagsaver_test.py b/absl/testing/tests/flagsaver_test.py
index e98cd06..b8f91a5 100644
--- a/absl/testing/tests/flagsaver_test.py
+++ b/absl/testing/tests/flagsaver_test.py
@@ -16,6 +16,7 @@
from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
+from absl.testing import parameterized
flags.DEFINE_string('flagsaver_test_flag0', 'unchanged0', 'flag to test with')
flags.DEFINE_string('flagsaver_test_flag1', 'unchanged1', 'flag to test with')
@@ -31,6 +32,9 @@ INT_FLAG = flags.DEFINE_integer(
STR_FLAG = flags.DEFINE_string(
'flagsaver_test_str_flag', default='str default', help='help')
+MULTI_INT_FLAG = flags.DEFINE_multi_integer('flagsaver_test_multi_int_flag',
+ None, 'flag to test with')
+
@flags.multi_flags_validator(
('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
@@ -51,194 +55,83 @@ class _TestError(Exception):
"""Exception class for use in these tests."""
-class FlagSaverTest(absltest.TestCase):
+class CommonUsageTest(absltest.TestCase):
+ """These test cases cover the most common usages of flagsaver."""
- def test_context_manager_without_parameters(self):
- with flagsaver.flagsaver():
- FLAGS.flagsaver_test_flag0 = 'new value'
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-
- def test_context_manager_with_overrides(self):
- with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
- self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag1 = 'another value'
+ def test_as_parsed_context_manager(self):
+ # Precondition check, we expect all the flags to start as their default.
+ self.assertEqual('str default', STR_FLAG.value)
+ self.assertFalse(STR_FLAG.present)
+ self.assertEqual(1, INT_FLAG.value)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
- def test_context_manager_with_flagholders(self):
- with flagsaver.flagsaver((INT_FLAG, 3), (STR_FLAG, 'new value')):
- self.assertEqual('new value', STR_FLAG.value)
- self.assertEqual(3, INT_FLAG.value)
- FLAGS.flagsaver_test_flag1 = 'another value'
- self.assertEqual(INT_FLAG.value, INT_FLAG.default)
- self.assertEqual(STR_FLAG.value, STR_FLAG.default)
- self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
-
- def test_context_manager_with_overrides_and_flagholders(self):
- with flagsaver.flagsaver((INT_FLAG, 3), flagsaver_test_flag0='new value'):
- self.assertEqual(STR_FLAG.default, STR_FLAG.value)
- self.assertEqual(3, INT_FLAG.value)
- FLAGS.flagsaver_test_flag0 = 'new value'
- self.assertEqual(INT_FLAG.value, INT_FLAG.default)
- self.assertEqual(STR_FLAG.value, STR_FLAG.default)
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-
- def test_context_manager_with_cross_validated_overrides_set_together(self):
- # When the flags are set in the same flagsaver call their validators will
- # be triggered only once the setting is done.
- with flagsaver.flagsaver(
- flagsaver_test_validated_flag1='new_value',
- flagsaver_test_validated_flag2='new_value'):
- self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
- self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
-
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_context_manager_with_cross_validated_overrides_set_badly(self):
-
- # Different values should violate the validator.
- with self.assertRaisesRegex(flags.IllegalFlagValueError,
- 'Flag validation failed'):
- with flagsaver.flagsaver(
- flagsaver_test_validated_flag1='new_value',
- flagsaver_test_validated_flag2='other_value'):
- pass
-
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_context_manager_with_cross_validated_overrides_set_separately(self):
-
- # Setting just one flag will trip the validator as well.
- with self.assertRaisesRegex(flags.IllegalFlagValueError,
- 'Flag validation failed'):
- with flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value'):
- pass
-
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_context_manager_with_exception(self):
- with self.assertRaises(_TestError):
- with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
- self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag1 = 'another value'
- raise _TestError('oops')
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
-
- def test_context_manager_with_validation_exception(self):
- with self.assertRaises(flags.IllegalFlagValueError):
- with flagsaver.flagsaver(
- flagsaver_test_flag0='new value',
- flagsaver_test_validated_flag='new value'):
- pass
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag)
-
- def test_decorator_without_call(self):
-
- @flagsaver.flagsaver
- def mutate_flags(value):
- """Test function that mutates a flag."""
- # The undecorated method mutates --flagsaver_test_flag0 to the given value
- # and then returns the value of that flag. If the @flagsaver.flagsaver
- # decorator works as designed, then this mutation will be reverted after
- # this method returns.
- FLAGS.flagsaver_test_flag0 = value
- return FLAGS.flagsaver_test_flag0
-
- # mutate_flags returns the flag value before it gets restored by
- # the flagsaver decorator. So we check that flag value was
- # actually changed in the method's scope.
- self.assertEqual('new value', mutate_flags('new value'))
- # But... notice that the flag is now unchanged0.
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-
- def test_decorator_without_parameters(self):
-
- @flagsaver.flagsaver()
- def mutate_flags(value):
- FLAGS.flagsaver_test_flag0 = value
- return FLAGS.flagsaver_test_flag0
-
- self.assertEqual('new value', mutate_flags('new value'))
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-
- def test_decorator_with_overrides(self):
+ # Flagsaver will also save the state of flags that have been modified.
+ FLAGS.flagsaver_test_flag1 = 'outside flagsaver'
+
+ # Save all existing flag state, and set some flags as if they were parsed on
+ # the command line. Because of this, the new values must be provided as str,
+ # even if the flag type is something other than string.
+ with flagsaver.as_parsed(
+ (STR_FLAG, 'new string value'), # Override using flagholder object.
+ (INT_FLAG, '123'), # Override an int flag (NOTE: must specify as str).
+ flagsaver_test_flag0='new value', # Override using flag name.
+ ):
+ # All the flags have their overridden values.
+ self.assertEqual('new string value', STR_FLAG.value)
+ self.assertTrue(STR_FLAG.present)
+ self.assertEqual(123, INT_FLAG.value)
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+ # Even if we change other flags, they will reset on context exit.
+ FLAGS.flagsaver_test_flag1 = 'new value 1'
- @flagsaver.flagsaver(flagsaver_test_flag0='new value')
- def mutate_flags():
- """Test function expecting new value."""
- # If the @flagsaver.decorator decorator works as designed,
- # then the value of the flag should be changed in the scope of
- # the method but the change will be reverted after this method
- # returns.
- return FLAGS.flagsaver_test_flag0
-
- # mutate_flags returns the flag value before it gets restored by
- # the flagsaver decorator. So we check that flag value was
- # actually changed in the method's scope.
- self.assertEqual('new value', mutate_flags())
- # But... notice that the flag is now unchanged0.
+ # The flags have all reset to their pre-flagsaver values.
+ self.assertEqual('str default', STR_FLAG.value)
+ self.assertFalse(STR_FLAG.present)
+ self.assertEqual(1, INT_FLAG.value)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-
- def test_decorator_with_cross_validated_overrides_set_together(self):
-
- # When the flags are set in the same flagsaver call their validators will
- # be triggered only once the setting is done.
- @flagsaver.flagsaver(
- flagsaver_test_validated_flag1='new_value',
- flagsaver_test_validated_flag2='new_value')
- def mutate_flags_together():
- return (FLAGS.flagsaver_test_validated_flag1,
- FLAGS.flagsaver_test_validated_flag2)
-
- self.assertEqual(('new_value', 'new_value'), mutate_flags_together())
-
- # The flags have not changed outside the context of the function.
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_decorator_with_cross_validated_overrides_set_badly(self):
-
- # Different values should violate the validator.
- @flagsaver.flagsaver(
- flagsaver_test_validated_flag1='new_value',
- flagsaver_test_validated_flag2='other_value')
- def mutate_flags_together_badly():
- return (FLAGS.flagsaver_test_validated_flag1,
- FLAGS.flagsaver_test_validated_flag2)
-
- with self.assertRaisesRegex(flags.IllegalFlagValueError,
- 'Flag validation failed'):
- mutate_flags_together_badly()
-
- # The flags have not changed outside the context of the exception.
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_decorator_with_cross_validated_overrides_set_separately(self):
-
- # Setting the flags sequentially and not together will trip the validator,
- # because it will be called at the end of each flagsaver call.
- @flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value')
- @flagsaver.flagsaver(flagsaver_test_validated_flag2='new_value')
- def mutate_flags_separately():
- return (FLAGS.flagsaver_test_validated_flag1,
- FLAGS.flagsaver_test_validated_flag2)
-
- with self.assertRaisesRegex(flags.IllegalFlagValueError,
- 'Flag validation failed'):
- mutate_flags_separately()
-
- # The flags have not changed outside the context of the exception.
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
- self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
-
- def test_save_flag_value(self):
+ self.assertEqual('outside flagsaver', FLAGS.flagsaver_test_flag1)
+
+ def test_as_parsed_decorator(self):
+ # flagsaver.as_parsed can also be used as a decorator.
+ @flagsaver.as_parsed((INT_FLAG, '123'))
+ def do_something_with_flags():
+ self.assertEqual(123, INT_FLAG.value)
+ self.assertTrue(INT_FLAG.present)
+
+ do_something_with_flags()
+ self.assertEqual(1, INT_FLAG.value)
+ self.assertFalse(INT_FLAG.present)
+
+ def test_flagsaver_flagsaver(self):
+ # If you don't want the flags to go through parsing, you can instead use
+ # flagsaver.flagsaver(). With this method, you provide the native python
+ # value you'd like the flags to take on. Otherwise it functions similar to
+ # flagsaver.as_parsed().
+ @flagsaver.flagsaver((INT_FLAG, 345))
+ def do_something_with_flags():
+ self.assertEqual(345, INT_FLAG.value)
+ # Note that because this flag was never parsed, it will not register as
+ # .present unless you manually set that attribute.
+ self.assertFalse(INT_FLAG.present)
+ # If you do chose to modify things about the flag (such as .present) those
+ # changes will still be cleaned up when flagsaver.flagsaver() exits.
+ INT_FLAG.present = True
+
+ self.assertEqual(1, INT_FLAG.value)
+ # flagsaver.flagsaver() restored INT_FLAG.present to the state it was in
+ # before entering the context.
+ self.assertFalse(INT_FLAG.present)
+
+
+class SaveFlagValuesTest(absltest.TestCase):
+ """Test flagsaver.save_flag_values() and flagsaver.restore_flag_values().
+
+ In this test, we insure that *all* properties of flags get restored. In other
+ tests we only try changing the flag value.
+ """
+
+ def test_assign_value(self):
# First save the flag values.
saved_flag_values = flagsaver.save_flag_values()
@@ -250,7 +143,7 @@ class FlagSaverTest(absltest.TestCase):
flagsaver.restore_flag_values(saved_flag_values)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- def test_save_flag_default(self):
+ def test_set_default(self):
# First save the flag.
saved_flag_values = flagsaver.save_flag_values()
@@ -262,7 +155,7 @@ class FlagSaverTest(absltest.TestCase):
flagsaver.restore_flag_values(saved_flag_values)
self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].default)
- def test_restore_after_parse(self):
+ def test_parse(self):
# First save the flag.
saved_flag_values = flagsaver.save_flag_values()
@@ -278,9 +171,72 @@ class FlagSaverTest(absltest.TestCase):
self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].value)
self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present)
- def test_decorator_with_exception(self):
+ def test_assign_validators(self):
+ # First save the flag.
+ saved_flag_values = flagsaver.save_flag_values()
+
+ # Sanity check that a validator already exists.
+ self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1)
+ original_validators = list(FLAGS['flagsaver_test_flag0'].validators)
+
+ def no_space(value):
+ return ' ' not in value
+
+ # Add a new validator.
+ flags.register_validator('flagsaver_test_flag0', no_space)
+ self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2)
+
+ # Now restore the flag to its original value.
+ flagsaver.restore_flag_values(saved_flag_values)
+ self.assertEqual(
+ original_validators, FLAGS['flagsaver_test_flag0'].validators
+ )
+
+
+@parameterized.named_parameters(
+ dict(
+ testcase_name='flagsaver.flagsaver',
+ flagsaver_method=flagsaver.flagsaver,
+ ),
+ dict(
+ testcase_name='flagsaver.as_parsed',
+ flagsaver_method=flagsaver.as_parsed,
+ ),
+)
+class NoOverridesTest(parameterized.TestCase):
+ """Test flagsaver.flagsaver and flagsaver.as_parsed without overrides."""
+
+ def test_context_manager_with_call(self, flagsaver_method):
+ with flagsaver_method():
+ FLAGS.flagsaver_test_flag0 = 'new value'
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_context_manager_with_exception(self, flagsaver_method):
+ with self.assertRaises(_TestError):
+ with flagsaver_method():
+ FLAGS.flagsaver_test_flag0 = 'new value'
+ # Simulate a failed test.
+ raise _TestError('something happened')
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_decorator_without_call(self, flagsaver_method):
+ @flagsaver_method
+ def mutate_flags():
+ FLAGS.flagsaver_test_flag0 = 'new value'
+
+ mutate_flags()
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_decorator_with_call(self, flagsaver_method):
+ @flagsaver_method()
+ def mutate_flags():
+ FLAGS.flagsaver_test_flag0 = 'new value'
- @flagsaver.flagsaver
+ mutate_flags()
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_decorator_with_exception(self, flagsaver_method):
+ @flagsaver_method()
def raise_exception():
FLAGS.flagsaver_test_flag0 = 'new value'
# Simulate a failed test.
@@ -290,62 +246,262 @@ class FlagSaverTest(absltest.TestCase):
self.assertRaises(_TestError, raise_exception)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- def test_validator_list_is_restored(self):
- self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1)
- original_validators = list(FLAGS['flagsaver_test_flag0'].validators)
+@parameterized.named_parameters(
+ dict(
+ testcase_name='flagsaver.flagsaver',
+ flagsaver_method=flagsaver.flagsaver,
+ ),
+ dict(
+ testcase_name='flagsaver.as_parsed',
+ flagsaver_method=flagsaver.as_parsed,
+ ),
+)
+class TestStringFlagOverrides(parameterized.TestCase):
+ """Test flagsaver.flagsaver and flagsaver.as_parsed with string overrides.
+
+ Note that these tests can be parameterized because both .flagsaver and
+ .as_parsed expect a str input when overriding a string flag. For non-string
+ flags these two flagsaver methods have separate tests elsewhere in this file.
+
+ Each test is one class of overrides, executed twice. Once as a context
+ manager, and once as a decorator on a mutate_flags() method.
+ """
+
+ def test_keyword_overrides(self, flagsaver_method):
+ # Context manager:
+ with flagsaver_method(flagsaver_test_flag0='new value'):
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- @flagsaver.flagsaver
- def modify_validators():
+ # Decorator:
+ @flagsaver_method(flagsaver_test_flag0='new value')
+ def mutate_flags():
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
- def no_space(value):
- return ' ' not in value
+ mutate_flags()
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- flags.register_validator('flagsaver_test_flag0', no_space)
- self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2)
+ def test_flagholder_overrides(self, flagsaver_method):
+ with flagsaver_method((STR_FLAG, 'new value')):
+ self.assertEqual('new value', STR_FLAG.value)
+ self.assertEqual('str default', STR_FLAG.value)
- modify_validators()
- self.assertEqual(original_validators,
- FLAGS['flagsaver_test_flag0'].validators)
+ @flagsaver_method((STR_FLAG, 'new value'))
+ def mutate_flags():
+ self.assertEqual('new value', STR_FLAG.value)
+ mutate_flags()
+ self.assertEqual('str default', STR_FLAG.value)
-class FlagSaverDecoratorUsageTest(absltest.TestCase):
+ def test_keyword_and_flagholder_overrides(self, flagsaver_method):
+ with flagsaver_method(
+ (STR_FLAG, 'another value'), flagsaver_test_flag0='new value'
+ ):
+ self.assertEqual('another value', STR_FLAG.value)
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+ self.assertEqual('str default', STR_FLAG.value)
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- @flagsaver.flagsaver
- def test_mutate1(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
+ @flagsaver_method(
+ (STR_FLAG, 'another value'), flagsaver_test_flag0='new value'
+ )
+ def mutate_flags():
+ self.assertEqual('another value', STR_FLAG.value)
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+
+ mutate_flags()
+ self.assertEqual('str default', STR_FLAG.value)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
- @flagsaver.flagsaver
- def test_mutate2(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
+ def test_cross_validated_overrides_set_together(self, flagsaver_method):
+ # When the flags are set in the same flagsaver call their validators will
+ # be triggered only once the setting is done.
+ with flagsaver_method(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='new_value',
+ ):
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ @flagsaver_method(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='new_value',
+ )
+ def mutate_flags():
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
+
+ mutate_flags()
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_cross_validated_overrides_set_badly(self, flagsaver_method):
+ # Different values should violate the validator.
+ with self.assertRaisesRegex(
+ flags.IllegalFlagValueError, 'Flag validation failed'
+ ):
+ with flagsaver_method(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='other_value',
+ ):
+ pass
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ @flagsaver_method(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='other_value',
+ )
+ def mutate_flags():
+ pass
+
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError, 'Flag validation failed', mutate_flags
+ )
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_cross_validated_overrides_set_separately(self, flagsaver_method):
+ # Setting just one flag will trip the validator as well.
+ with self.assertRaisesRegex(
+ flags.IllegalFlagValueError, 'Flag validation failed'
+ ):
+ with flagsaver_method(flagsaver_test_validated_flag1='new_value'):
+ pass
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ @flagsaver_method(flagsaver_test_validated_flag1='new_value')
+ def mutate_flags():
+ pass
+
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError, 'Flag validation failed', mutate_flags
+ )
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_validation_exception(self, flagsaver_method):
+ with self.assertRaises(flags.IllegalFlagValueError):
+ with flagsaver_method(
+ flagsaver_test_flag0='new value',
+ flagsaver_test_validated_flag='new value',
+ ):
+ pass
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag)
- @flagsaver.flagsaver
- def test_mutate3(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
+ @flagsaver_method(
+ flagsaver_test_flag0='new value',
+ flagsaver_test_validated_flag='new value',
+ )
+ def mutate_flags():
+ pass
+
+ self.assertRaises(flags.IllegalFlagValueError, mutate_flags)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag)
- @flagsaver.flagsaver
- def test_mutate4(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
+ def test_unknown_flag_raises_exception(self, flagsaver_method):
+ self.assertNotIn('this_flag_does_not_exist', FLAGS)
+
+ # Flagsaver raises an error when trying to override a non-existent flag.
+ with self.assertRaises(flags.UnrecognizedFlagError):
+ with flagsaver_method(
+ flagsaver_test_flag0='new value', this_flag_does_not_exist='new value'
+ ):
+ pass
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
+ @flagsaver_method(
+ flagsaver_test_flag0='new value', this_flag_does_not_exist='new value'
+ )
+ def mutate_flags():
+ pass
+
+ self.assertRaises(flags.UnrecognizedFlagError, mutate_flags)
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
-class FlagSaverSetUpTearDownUsageTest(absltest.TestCase):
+ # Make sure flagsaver didn't create the flag at any point.
+ self.assertNotIn('this_flag_does_not_exist', FLAGS)
+
+
+class AsParsedTest(absltest.TestCase):
+
+ def test_parse_context_manager_sets_present_and_using_default(self):
+ self.assertFalse(INT_FLAG.present)
+ self.assertFalse(STR_FLAG.present)
+ # Note that .using_default_value isn't available on the FlagHolder directly.
+ self.assertTrue(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertTrue(FLAGS[STR_FLAG.name].using_default_value)
+
+ with flagsaver.as_parsed((INT_FLAG, '123'),
+ flagsaver_test_str_flag='new value'):
+ self.assertTrue(INT_FLAG.present)
+ self.assertTrue(STR_FLAG.present)
+ self.assertFalse(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertFalse(FLAGS[STR_FLAG.name].using_default_value)
+
+ self.assertFalse(INT_FLAG.present)
+ self.assertFalse(STR_FLAG.present)
+ self.assertTrue(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertTrue(FLAGS[STR_FLAG.name].using_default_value)
+
+ def test_parse_decorator_sets_present_and_using_default(self):
+ self.assertFalse(INT_FLAG.present)
+ self.assertFalse(STR_FLAG.present)
+ # Note that .using_default_value isn't available on the FlagHolder directly.
+ self.assertTrue(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertTrue(FLAGS[STR_FLAG.name].using_default_value)
+
+ @flagsaver.as_parsed((INT_FLAG, '123'), flagsaver_test_str_flag='new value')
+ def some_func():
+ self.assertTrue(INT_FLAG.present)
+ self.assertTrue(STR_FLAG.present)
+ self.assertFalse(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertFalse(FLAGS[STR_FLAG.name].using_default_value)
+
+ some_func()
+ self.assertFalse(INT_FLAG.present)
+ self.assertFalse(STR_FLAG.present)
+ self.assertTrue(FLAGS[INT_FLAG.name].using_default_value)
+ self.assertTrue(FLAGS[STR_FLAG.name].using_default_value)
+
+ def test_parse_decorator_with_multi_int_flag(self):
+ self.assertFalse(MULTI_INT_FLAG.present)
+ self.assertIsNone(MULTI_INT_FLAG.value)
+
+ @flagsaver.as_parsed((MULTI_INT_FLAG, ['123', '456']))
+ def assert_flags_updated():
+ self.assertTrue(MULTI_INT_FLAG.present)
+ self.assertCountEqual([123, 456], MULTI_INT_FLAG.value)
+
+ assert_flags_updated()
+ self.assertFalse(MULTI_INT_FLAG.present)
+ self.assertIsNone(MULTI_INT_FLAG.value)
+
+ def test_parse_raises_type_error(self):
+ with self.assertRaisesRegex(
+ TypeError,
+ r'flagsaver\.as_parsed\(\) cannot parse flagsaver_test_int_flag\. '
+ r'Expected a single string or sequence of strings but .*int.* was '
+ r'provided\.'):
+ manager = flagsaver.as_parsed(flagsaver_test_int_flag=123)
+ del manager
+
+
+class SetUpTearDownTest(absltest.TestCase):
+ """Example using a single flagsaver in setUp."""
def setUp(self):
+ super().setUp()
self.saved_flag_values = flagsaver.save_flag_values()
def tearDown(self):
+ super().tearDown()
flagsaver.restore_flag_values(self.saved_flag_values)
def test_mutate1(self):
@@ -360,28 +516,26 @@ class FlagSaverSetUpTearDownUsageTest(absltest.TestCase):
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
FLAGS.flagsaver_test_flag0 = 'changed0'
- def test_mutate3(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
-
- def test_mutate4(self):
- # Even though other test cases change the flag, it should be
- # restored to 'unchanged0' if the flagsaver is working.
- self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
- FLAGS.flagsaver_test_flag0 = 'changed0'
-
-class FlagSaverBadUsageTest(absltest.TestCase):
- """Tests that certain kinds of improper usages raise errors."""
-
- def test_flag_saver_on_class(self):
+@parameterized.named_parameters(
+ dict(
+ testcase_name='flagsaver.flagsaver',
+ flagsaver_method=flagsaver.flagsaver,
+ ),
+ dict(
+ testcase_name='flagsaver.as_parsed',
+ flagsaver_method=flagsaver.as_parsed,
+ ),
+)
+class BadUsageTest(parameterized.TestCase):
+ """Tests that improper usage (such as decorating a class) raise errors."""
+
+ def test_flag_saver_on_class(self, flagsaver_method):
with self.assertRaises(TypeError):
# WRONG. Don't do this.
# Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
- @flagsaver.flagsaver
+ @flagsaver_method
class FooTest(absltest.TestCase):
def test_tautology(self):
@@ -389,12 +543,12 @@ class FlagSaverBadUsageTest(absltest.TestCase):
del FooTest
- def test_flag_saver_call_on_class(self):
+ def test_flag_saver_call_on_class(self, flagsaver_method):
with self.assertRaises(TypeError):
# WRONG. Don't do this.
# Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
- @flagsaver.flagsaver()
+ @flagsaver_method()
class FooTest(absltest.TestCase):
def test_tautology(self):
@@ -402,12 +556,12 @@ class FlagSaverBadUsageTest(absltest.TestCase):
del FooTest
- def test_flag_saver_with_overrides_on_class(self):
+ def test_flag_saver_with_overrides_on_class(self, flagsaver_method):
with self.assertRaises(TypeError):
# WRONG. Don't do this.
# Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
- @flagsaver.flagsaver(foo='bar')
+ @flagsaver_method(foo='bar')
class FooTest(absltest.TestCase):
def test_tautology(self):
@@ -415,48 +569,57 @@ class FlagSaverBadUsageTest(absltest.TestCase):
del FooTest
- def test_multiple_positional_parameters(self):
+ def test_multiple_positional_parameters(self, flagsaver_method):
with self.assertRaises(ValueError):
func_a = lambda: None
func_b = lambda: None
- flagsaver.flagsaver(func_a, func_b)
+ flagsaver_method(func_a, func_b)
- def test_both_positional_and_keyword_parameters(self):
+ def test_both_positional_and_keyword_parameters(self, flagsaver_method):
with self.assertRaises(ValueError):
func_a = lambda: None
- flagsaver.flagsaver(func_a, flagsaver_test_flag0='new value')
+ flagsaver_method(func_a, flagsaver_test_flag0='new value')
- def test_duplicate_holder_parameters(self):
+ def test_duplicate_holder_parameters(self, flagsaver_method):
with self.assertRaises(ValueError):
- flagsaver.flagsaver((INT_FLAG, 45), (INT_FLAG, 45))
+ flagsaver_method((INT_FLAG, 45), (INT_FLAG, 45))
- def test_duplicate_holder_and_kw_parameter(self):
+ def test_duplicate_holder_and_kw_parameter(self, flagsaver_method):
with self.assertRaises(ValueError):
- flagsaver.flagsaver((INT_FLAG, 45), **{INT_FLAG.name: 45})
+ flagsaver_method((INT_FLAG, 45), **{INT_FLAG.name: 45})
- def test_both_positional_and_holder_parameters(self):
+ def test_both_positional_and_holder_parameters(self, flagsaver_method):
with self.assertRaises(ValueError):
func_a = lambda: None
- flagsaver.flagsaver(func_a, (INT_FLAG, 45))
+ flagsaver_method(func_a, (INT_FLAG, 45))
- def test_holder_parameters_wrong_shape(self):
+ def test_holder_parameters_wrong_shape(self, flagsaver_method):
with self.assertRaises(ValueError):
- flagsaver.flagsaver(INT_FLAG)
+ flagsaver_method(INT_FLAG)
- def test_holder_parameters_tuple_too_long(self):
+ def test_holder_parameters_tuple_too_long(self, flagsaver_method):
with self.assertRaises(ValueError):
# Even if it is a bool flag, it should be a tuple
- flagsaver.flagsaver((INT_FLAG, 4, 5))
+ flagsaver_method((INT_FLAG, 4, 5))
- def test_holder_parameters_tuple_wrong_type(self):
+ def test_holder_parameters_tuple_wrong_type(self, flagsaver_method):
with self.assertRaises(ValueError):
# Even if it is a bool flag, it should be a tuple
- flagsaver.flagsaver((4, INT_FLAG))
+ flagsaver_method((4, INT_FLAG))
- def test_both_wrong_positional_parameters(self):
+ def test_both_wrong_positional_parameters(self, flagsaver_method):
with self.assertRaises(ValueError):
func_a = lambda: None
- flagsaver.flagsaver(func_a, STR_FLAG, '45')
+ flagsaver_method(func_a, STR_FLAG, '45')
+
+ def test_context_manager_no_call(self, flagsaver_method):
+ # The exact exception that's raised appears to be system specific.
+ with self.assertRaises((AttributeError, TypeError)):
+ # Wrong. You must call the flagsaver method before using it as a CM.
+ with flagsaver_method:
+ # We don't expect to get here. A type error should happen when
+ # attempting to enter the context manager.
+ pass
if __name__ == '__main__':
diff --git a/setup.py b/setup.py
index f947fd7..1a119f5 100644
--- a/setup.py
+++ b/setup.py
@@ -43,7 +43,7 @@ with open(_README_PATH, 'rb') as fp:
setuptools.setup(
name='absl-py',
- version='1.3.0',
+ version='1.4.0',
description=(
'Abseil Python Common Libraries, '
'see https://github.com/abseil/abseil-py.'),