aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2022-09-13 21:15:56 -0700
committerCopybara-Service <copybara-worker@google.com>2022-09-13 21:16:26 -0700
commit0dd0210ed8919bca3338aebd1940824ada1a8e47 (patch)
tree3eac88ac81320eff80866cfd1f7df5a7ea85d581
parentb825efe5474a3c493b919752ecf630ce70b31b69 (diff)
downloadabsl-py-0dd0210ed8919bca3338aebd1940824ada1a8e47.tar.gz
Support FlagHolders in flag module-level functions.
PiperOrigin-RevId: 474195462 Change-Id: I2b57d8ea6b8eb66d1f777f026fb19e5401983789
-rw-r--r--CHANGELOG.md25
-rw-r--r--absl/flags/_defines.py14
-rw-r--r--absl/flags/_defines.pyi2
-rw-r--r--absl/flags/_flagvalues.py32
-rw-r--r--absl/flags/_validators.py51
-rw-r--r--absl/flags/tests/_validators_test.py245
-rw-r--r--absl/flags/tests/flags_test.py34
7 files changed, 379 insertions, 24 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 06a8ee0..025f8ed 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
## Unreleased
+### Added
+
+* (flags) The following functions now also accept `FlagHolder` instance(s)
+ in addition to flag name(s) as their first positional argument:
+ - `flags.register_validator`
+ - `flags.validator`
+ - `flags.register_multi_flags_validator`
+ - `flags.multi_flags_validator`
+ - `flags.mark_flag_as_required`
+ - `flags.mark_flags_as_required`
+ - `flags.mark_flags_as_mutual_exclusive`
+ - `flags.mark_bool_flags_as_mutual_exclusive`
+ - `flags.declare_key_flag`
+
### Changed
* (testing) Assertions `assertRaisesWithPredicateMatch` and
@@ -13,6 +27,17 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
further analysis when used as a context manager.
* (testing) TextAndXMLTestRunner now produces time duration values with
millisecond precision in XML test result output.
+* (flags) Keyword access to `flag_name` arguments in the following functions
+ is deprecated. This parameter will be renamed in a future 2.0.0 release.
+ - `flags.register_validator`
+ - `flags.validator`
+ - `flags.register_multi_flags_validator`
+ - `flags.multi_flags_validator`
+ - `flags.mark_flag_as_required`
+ - `flags.mark_flags_as_required`
+ - `flags.mark_flags_as_mutual_exclusive`
+ - `flags.mark_bool_flags_as_mutual_exclusive`
+ - `flags.declare_key_flag`
## 1.2.0 (2022-07-18)
diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py
index 12335e5..a4197c7 100644
--- a/absl/flags/_defines.py
+++ b/absl/flags/_defines.py
@@ -157,8 +157,7 @@ def _internal_declare_key_flags(flag_names,
adopt_module_key_flags instead.
Args:
- flag_names: [str], a list of strings that are names of already-registered
- Flag objects.
+ flag_names: [str], a list of names of already-registered Flag objects.
flag_values: :class:`FlagValues`, the FlagValues instance with which the
flags listed in flag_names have registered (the value of the flag_values
argument from the ``DEFINE_*`` calls that defined those flags). This
@@ -176,8 +175,7 @@ def _internal_declare_key_flags(flag_names,
module = _helpers.get_calling_module()
for flag_name in flag_names:
- flag = flag_values[flag_name]
- key_flag_values.register_key_flag_for_module(module, flag)
+ key_flag_values.register_key_flag_for_module(module, flag_values[flag_name])
def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS):
@@ -194,9 +192,10 @@ def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS):
flags.declare_key_flag('flag_1')
Args:
- flag_name: str, the name of an already declared flag. (Redeclaring flags as
- key, including flags implicitly key because they were declared in this
- module, is a no-op.)
+ flag_name: str | :class:`FlagHolder`, the name or holder of an already
+ declared flag. (Redeclaring flags as key, including flags implicitly key
+ because they were declared in this module, is a no-op.)
+ Positional-only parameter.
flag_values: :class:`FlagValues`, the FlagValues instance in which the
flag will be declared as a key flag. This should almost never need to be
overridden.
@@ -204,6 +203,7 @@ def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS):
Raises:
ValueError: Raised if flag_name not defined as a Python flag.
"""
+ flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values)
if flag_name in _helpers.SPECIAL_FLAGS:
# Take care of the special flags, e.g., --flagfile, --undefok.
# These flags are defined in SPECIAL_FLAGS, and are treated
diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi
index 0fbe921..c75a4f5 100644
--- a/absl/flags/_defines.pyi
+++ b/absl/flags/_defines.pyi
@@ -651,7 +651,7 @@ def DEFINE_alias(
-def declare_key_flag(flag_name: Text,
+def declare_key_flag(flag_name: Union[Text, _flagvalues.FlagHolder],
flag_values: _flagvalues.FlagValues = ...) -> None:
...
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index 48df1a0..937dc6c 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -1384,3 +1384,35 @@ class FlagHolder(Generic[_T]):
def present(self):
"""Returns True if the flag was parsed from command-line flags."""
return bool(self._flagvalues[self._name].present)
+
+
+def resolve_flag_ref(flag_ref, flag_values):
+ """Helper to validate and resolve a flag reference argument."""
+ if isinstance(flag_ref, FlagHolder):
+ new_flag_values = flag_ref._flagvalues # pylint: disable=protected-access
+ if flag_values != FLAGS and flag_values != new_flag_values:
+ raise ValueError(
+ 'flag_values must not be customized when operating on a FlagHolder')
+ return flag_ref.name, new_flag_values
+ return flag_ref, flag_values
+
+
+def resolve_flag_refs(flag_refs, flag_values):
+ """Helper to validate and resolve flag reference list arguments."""
+ fv = None
+ names = []
+ for ref in flag_refs:
+ if isinstance(ref, FlagHolder):
+ newfv = ref._flagvalues # pylint: disable=protected-access
+ name = ref.name
+ else:
+ newfv = flag_values
+ name = ref
+ if fv and fv != newfv:
+ raise ValueError(
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ fv = newfv
+ names.append(name)
+ return names, fv
diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py
index c4e1139..2161284 100644
--- a/absl/flags/_validators.py
+++ b/absl/flags/_validators.py
@@ -51,7 +51,8 @@ def register_validator(flag_name,
change of the corresponding flag's value.
Args:
- flag_name: str, name of the flag to be checked.
+ flag_name: str | FlagHolder, name or holder of the flag to be checked.
+ Positional-only parameter.
checker: callable, a function to validate the flag.
* input - A single positional argument: The value of the corresponding
@@ -70,7 +71,10 @@ def register_validator(flag_name,
Raises:
AttributeError: Raised when flag_name is not registered as a valid flag
name.
+ ValueError: Raised when flag_values is non-default and does not match the
+ FlagValues of the provided FlagHolder instance.
"""
+ flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values)
v = _validators_classes.SingleFlagValidator(flag_name, checker, message)
_add_validator(flag_values, v)
@@ -88,7 +92,8 @@ def validator(flag_name, message='Flag validation failed',
See :func:`register_validator` for the specification of checker function.
Args:
- flag_name: str, name of the flag to be checked.
+ flag_name: str | FlagHolder, name or holder of the flag to be checked.
+ Positional-only parameter.
message: str, error text to be shown to the user if checker returns False.
If checker raises flags.ValidationError, message from the raised
error will be shown.
@@ -119,7 +124,8 @@ def register_multi_flags_validator(flag_names,
change of the corresponding flag's value.
Args:
- flag_names: [str], a list of the flag names to be checked.
+ flag_names: [str | FlagHolder], a list of the flag names or holders to be
+ checked. Positional-only parameter.
multi_flags_checker: callable, a function to validate the flag.
* input - dict, with keys() being flag_names, and value for each key
@@ -136,7 +142,13 @@ def register_multi_flags_validator(flag_names,
Raises:
AttributeError: Raised when a flag is not registered as a valid flag name.
+ ValueError: Raised when multiple FlagValues are used in the same
+ invocation. This can occur when FlagHolders have different `_flagvalues`
+ or when str-type flag_names entries are present and the `flag_values`
+ argument does not match that of provided FlagHolder(s).
"""
+ flag_names, flag_values = _flagvalues.resolve_flag_refs(
+ flag_names, flag_values)
v = _validators_classes.MultiFlagsValidator(
flag_names, multi_flags_checker, message)
_add_validator(flag_values, v)
@@ -157,7 +169,8 @@ def multi_flags_validator(flag_names,
function.
Args:
- flag_names: [str], a list of the flag names to be checked.
+ flag_names: [str | FlagHolder], a list of the flag names or holders to be
+ checked. Positional-only parameter.
message: str, error text to be shown to the user if checker returns False.
If checker raises flags.ValidationError, message from the raised
error will be shown.
@@ -196,13 +209,17 @@ def mark_flag_as_required(flag_name, flag_values=_flagvalues.FLAGS):
app.run()
Args:
- flag_name: str, name of the flag
+ flag_name: str | FlagHolder, name or holder of the flag.
+ Positional-only parameter.
flag_values: flags.FlagValues, optional :class:`~absl.flags.FlagValues`
instance where the flag is defined.
Raises:
AttributeError: Raised when flag_name is not registered as a valid flag
name.
+ ValueError: Raised when flag_values is non-default and does not match the
+ FlagValues of the provided FlagHolder instance.
"""
+ flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values)
if flag_values[flag_name].default is not None:
warnings.warn(
'Flag --%s has a non-None default value; therefore, '
@@ -227,7 +244,7 @@ def mark_flags_as_required(flag_names, flag_values=_flagvalues.FLAGS):
app.run()
Args:
- flag_names: Sequence[str], names of the flags.
+ flag_names: Sequence[str | FlagHolder], names or holders of the flags.
flag_values: flags.FlagValues, optional FlagValues instance where the flags
are defined.
Raises:
@@ -248,13 +265,22 @@ def mark_flags_as_mutual_exclusive(flag_names, required=False,
includes multi flags with a default value of ``[]`` instead of None.
Args:
- flag_names: [str], names of the flags.
+ flag_names: [str | FlagHolder], names or holders of flags.
+ Positional-only parameter.
required: bool. If true, exactly one of the flags must have a value other
than None. Otherwise, at most one of the flags can have a value other
than None, and it is valid for all of the flags to be None.
flag_values: flags.FlagValues, optional FlagValues instance where the flags
are defined.
+
+ Raises:
+ ValueError: Raised when multiple FlagValues are used in the same
+ invocation. This can occur when FlagHolders have different `_flagvalues`
+ or when str-type flag_names entries are present and the `flag_values`
+ argument does not match that of provided FlagHolder(s).
"""
+ flag_names, flag_values = _flagvalues.resolve_flag_refs(
+ flag_names, flag_values)
for flag_name in flag_names:
if flag_values[flag_name].default is not None:
warnings.warn(
@@ -280,12 +306,21 @@ def mark_bool_flags_as_mutual_exclusive(flag_names, required=False,
"""Ensures that only one flag among flag_names is True.
Args:
- flag_names: [str], names of the flags.
+ flag_names: [str | FlagHolder], names or holders of flags.
+ Positional-only parameter.
required: bool. If true, exactly one flag must be True. Otherwise, at most
one flag can be True, and it is valid for all flags to be False.
flag_values: flags.FlagValues, optional FlagValues instance where the flags
are defined.
+
+ Raises:
+ ValueError: Raised when multiple FlagValues are used in the same
+ invocation. This can occur when FlagHolders have different `_flagvalues`
+ or when str-type flag_names entries are present and the `flag_values`
+ argument does not match that of provided FlagHolder(s).
"""
+ flag_names, flag_values = _flagvalues.resolve_flag_refs(
+ flag_names, flag_values)
for flag_name in flag_names:
if not flag_values[flag_name].boolean:
raise _exceptions.ValidationError(
diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py
index 1cccf53..9aa328e 100644
--- a/absl/flags/tests/_validators_test.py
+++ b/absl/flags/tests/_validators_test.py
@@ -55,6 +55,45 @@ class SingleFlagValidatorTest(absltest.TestCase):
self.assertEqual(2, self.flag_values.test_flag)
self.assertEqual([None, 2], self.call_args)
+ def test_success_holder(self):
+ def checker(x):
+ self.call_args.append(x)
+ return True
+
+ flag_holder = _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ flag_holder,
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.test_flag)
+ self.flag_values.test_flag = 2
+ self.assertEqual(2, self.flag_values.test_flag)
+ self.assertEqual([None, 2], self.call_args)
+
+ def test_success_holder_infer_flagvalues(self):
+ def checker(x):
+ self.call_args.append(x)
+ return True
+
+ flag_holder = _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ flag_holder,
+ checker,
+ message='Errors happen')
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.test_flag)
+ self.flag_values.test_flag = 2
+ self.assertEqual(2, self.flag_values.test_flag)
+ self.assertEqual([None, 2], self.call_args)
+
def test_default_value_not_used_success(self):
def checker(x):
self.call_args.append(x)
@@ -218,6 +257,26 @@ class SingleFlagValidatorTest(absltest.TestCase):
self.assertTrue(checker(3))
self.assertEqual([None, 2, 3], self.call_args)
+ def test_mismatching_flagvalues(self):
+
+ def checker(x):
+ self.call_args.append(x)
+ return True
+
+ flag_holder = _defines.DEFINE_integer(
+ 'test_flag',
+ None,
+ 'Usual integer flag',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'flag_values must not be customized when operating on a FlagHolder')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ _validators.register_validator(
+ flag_holder,
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
class MultiFlagsValidatorTest(absltest.TestCase):
"""Test flags multi-flag validators."""
@@ -226,9 +285,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
super(MultiFlagsValidatorTest, self).setUp()
self.flag_values = _flagvalues.FlagValues()
self.call_args = []
- _defines.DEFINE_integer(
+ self.foo_holder = _defines.DEFINE_integer(
'foo', 1, 'Usual integer flag', flag_values=self.flag_values)
- _defines.DEFINE_integer(
+ self.bar_holder = _defines.DEFINE_integer(
'bar', 2, 'Usual integer flag', flag_values=self.flag_values)
def test_success(self):
@@ -248,6 +307,55 @@ class MultiFlagsValidatorTest(absltest.TestCase):
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}],
self.call_args)
+ def test_success_holder(self):
+
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ return True
+
+ _validators.register_multi_flags_validator(
+ [self.foo_holder, self.bar_holder],
+ checker,
+ flag_values=self.flag_values)
+
+ argv = ('./program', '--bar=2')
+ self.flag_values(argv)
+ self.assertEqual(1, self.flag_values.foo)
+ self.assertEqual(2, self.flag_values.bar)
+ self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
+ self.flag_values.foo = 3
+ self.assertEqual(3, self.flag_values.foo)
+ self.assertEqual([{
+ 'foo': 1,
+ 'bar': 2
+ }, {
+ 'foo': 3,
+ 'bar': 2
+ }], self.call_args)
+
+ def test_success_holder_infer_flagvalues(self):
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ return True
+
+ _validators.register_multi_flags_validator(
+ [self.foo_holder, self.bar_holder], checker)
+
+ argv = ('./program', '--bar=2')
+ self.flag_values(argv)
+ self.assertEqual(1, self.flag_values.foo)
+ self.assertEqual(2, self.flag_values.bar)
+ self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
+ self.flag_values.foo = 3
+ self.assertEqual(3, self.flag_values.foo)
+ self.assertEqual([{
+ 'foo': 1,
+ 'bar': 2
+ }, {
+ 'foo': 3,
+ 'bar': 2
+ }], self.call_args)
+
def test_validator_not_called_when_other_flag_is_changed(self):
def checker(flags_dict):
self.call_args.append(flags_dict)
@@ -322,6 +430,30 @@ class MultiFlagsValidatorTest(absltest.TestCase):
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
+ def test_mismatching_flagvalues(self):
+
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ values = flags_dict.values()
+ # Make sure all the flags have different values.
+ return len(set(values)) == len(values)
+
+ other_holder = _defines.DEFINE_integer(
+ 'other_flag',
+ 3,
+ 'Other integer flag',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ _validators.register_multi_flags_validator(
+ [self.foo_holder, self.bar_holder, other_holder],
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
@@ -329,9 +461,9 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
super(MarkFlagsAsMutualExclusiveTest, self).setUp()
self.flag_values = _flagvalues.FlagValues()
- _defines.DEFINE_string(
+ self.flag_one_holder = _defines.DEFINE_string(
'flag_one', None, 'flag one', flag_values=self.flag_values)
- _defines.DEFINE_string(
+ self.flag_two_holder = _defines.DEFINE_string(
'flag_two', None, 'flag two', flag_values=self.flag_values)
_defines.DEFINE_string(
'flag_three', None, 'flag three', flag_values=self.flag_values)
@@ -358,6 +490,24 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
self.assertIsNone(self.flag_values.flag_one)
self.assertIsNone(self.flag_values.flag_two)
+ def test_no_flags_present_holder(self):
+ self._mark_flags_as_mutually_exclusive(
+ [self.flag_one_holder, self.flag_two_holder], False)
+ argv = ('./program',)
+
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.flag_one)
+ self.assertIsNone(self.flag_values.flag_two)
+
+ def test_no_flags_present_mixed(self):
+ self._mark_flags_as_mutually_exclusive([self.flag_one_holder, 'flag_two'],
+ False)
+ argv = ('./program',)
+
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.flag_one)
+ self.assertIsNone(self.flag_values.flag_two)
+
def test_no_flags_present_required(self):
self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
argv = ('./program',)
@@ -494,6 +644,20 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
self.assertIn('--flag_not_none has a non-None default value',
str(caught_warnings[0].message))
+ def test_multiple_flagvalues(self):
+ other_holder = _defines.DEFINE_boolean(
+ 'other_flagvalues',
+ False,
+ 'other ',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ self._mark_flags_as_mutually_exclusive(
+ [self.flag_one_holder, other_holder], False)
+
class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
@@ -501,13 +665,13 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
super(MarkBoolFlagsAsMutualExclusiveTest, self).setUp()
self.flag_values = _flagvalues.FlagValues()
- _defines.DEFINE_boolean(
+ self.false_1_holder = _defines.DEFINE_boolean(
'false_1', False, 'default false 1', flag_values=self.flag_values)
- _defines.DEFINE_boolean(
+ self.false_2_holder = _defines.DEFINE_boolean(
'false_2', False, 'default false 2', flag_values=self.flag_values)
- _defines.DEFINE_boolean(
+ self.true_1_holder = _defines.DEFINE_boolean(
'true_1', True, 'default true 1', flag_values=self.flag_values)
- _defines.DEFINE_integer(
+ self.non_bool_holder = _defines.DEFINE_integer(
'non_bool', None, 'non bool', flag_values=self.flag_values)
def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required):
@@ -520,6 +684,20 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
self.assertEqual(False, self.flag_values.false_1)
self.assertEqual(False, self.flag_values.false_2)
+ def test_no_flags_present_holder(self):
+ self._mark_bool_flags_as_mutually_exclusive(
+ [self.false_1_holder, self.false_2_holder], False)
+ self.flag_values(('./program',))
+ self.assertEqual(False, self.flag_values.false_1)
+ self.assertEqual(False, self.flag_values.false_2)
+
+ def test_no_flags_present_mixed(self):
+ self._mark_bool_flags_as_mutually_exclusive(
+ [self.false_1_holder, 'false_2'], False)
+ self.flag_values(('./program',))
+ self.assertEqual(False, self.flag_values.false_1)
+ self.assertEqual(False, self.flag_values.false_2)
+
def test_no_flags_present_required(self):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True)
argv = ('./program',)
@@ -554,6 +732,17 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'],
False)
+ def test_multiple_flagvalues(self):
+ other_bool_holder = _defines.DEFINE_boolean(
+ 'other_bool', False, 'other bool', flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ self._mark_bool_flags_as_mutually_exclusive(
+ [self.false_1_holder, other_bool_holder], False)
+
class MarkFlagAsRequiredTest(absltest.TestCase):
@@ -570,6 +759,22 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
self.flag_values(argv)
self.assertEqual('value', self.flag_values.string_flag)
+ def test_success_holder(self):
+ holder = _defines.DEFINE_string(
+ 'string_flag', None, 'string flag', flag_values=self.flag_values)
+ _validators.mark_flag_as_required(holder, flag_values=self.flag_values)
+ argv = ('./program', '--string_flag=value')
+ self.flag_values(argv)
+ self.assertEqual('value', self.flag_values.string_flag)
+
+ def test_success_holder_infer_flagvalues(self):
+ holder = _defines.DEFINE_string(
+ 'string_flag', None, 'string flag', flag_values=self.flag_values)
+ _validators.mark_flag_as_required(holder)
+ argv = ('./program', '--string_flag=value')
+ self.flag_values(argv)
+ self.assertEqual('value', self.flag_values.string_flag)
+
def test_catch_none_as_default(self):
_defines.DEFINE_string(
'string_flag', None, 'string flag', flag_values=self.flag_values)
@@ -608,6 +813,18 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
self.assertIn('--flag_not_none has a non-None default value',
str(caught_warnings[0].message))
+ def test_mismatching_flagvalues(self):
+ flag_holder = _defines.DEFINE_string(
+ 'string_flag',
+ 'value',
+ 'string flag',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'flag_values must not be customized when operating on a FlagHolder')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ _validators.mark_flag_as_required(
+ flag_holder, flag_values=self.flag_values)
+
class MarkFlagsAsRequiredTest(absltest.TestCase):
@@ -627,6 +844,18 @@ class MarkFlagsAsRequiredTest(absltest.TestCase):
self.assertEqual('value_1', self.flag_values.string_flag_1)
self.assertEqual('value_2', self.flag_values.string_flag_2)
+ def test_success_holders(self):
+ flag_1_holder = _defines.DEFINE_string(
+ 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
+ flag_2_holder = _defines.DEFINE_string(
+ 'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
+ _validators.mark_flags_as_required([flag_1_holder, flag_2_holder],
+ flag_values=self.flag_values)
+ argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
+ self.flag_values(argv)
+ self.assertEqual('value_1', self.flag_values.string_flag_1)
+ self.assertEqual('value_2', self.flag_values.string_flag_2)
+
def test_catch_none_as_default(self):
_defines.DEFINE_string(
'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py
index 8a42bc9..d001161 100644
--- a/absl/flags/tests/flags_test.py
+++ b/absl/flags/tests/flags_test.py
@@ -2646,6 +2646,40 @@ class KeyFlagsTest(absltest.TestCase):
self._get_names_of_key_flags(main_module, fv),
names_of_flags_defined_by_bar + ['flagfile', 'undefok'])
+ def test_key_flags_with_flagholders(self):
+ main_module = sys.argv[0]
+
+ self.assertListEqual(
+ self._get_names_of_key_flags(main_module, self.flag_values), [])
+ self.assertListEqual(
+ self._get_names_of_defined_flags(main_module, self.flag_values), [])
+
+ int_holder = flags.DEFINE_integer(
+ 'main_module_int_fg',
+ 1,
+ 'Integer flag in the main module.',
+ flag_values=self.flag_values)
+
+ flags.declare_key_flag(int_holder, self.flag_values)
+
+ self.assertCountEqual(
+ self.flag_values.get_flags_for_module(main_module),
+ self.flag_values.get_key_flags_for_module(main_module))
+
+ bool_holder = flags.DEFINE_boolean(
+ 'main_module_bool_fg',
+ False,
+ 'Boolean flag in the main module.',
+ flag_values=self.flag_values)
+
+ flags.declare_key_flag(bool_holder) # omitted flag_values
+
+ self.assertCountEqual(
+ self.flag_values.get_flags_for_module(main_module),
+ self.flag_values.get_key_flags_for_module(main_module))
+
+ self.assertLen(self.flag_values.get_flags_for_module(main_module), 2)
+
def test_main_module_help_with_key_flags(self):
# Similar to test_main_module_help, but this time we make sure to
# declare some key flags.