diff options
author | Abseil Team <absl-team@google.com> | 2022-09-13 21:15:56 -0700 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2022-09-13 21:16:26 -0700 |
commit | 0dd0210ed8919bca3338aebd1940824ada1a8e47 (patch) | |
tree | 3eac88ac81320eff80866cfd1f7df5a7ea85d581 | |
parent | b825efe5474a3c493b919752ecf630ce70b31b69 (diff) | |
download | absl-py-0dd0210ed8919bca3338aebd1940824ada1a8e47.tar.gz |
Support FlagHolders in flag module-level functions.
PiperOrigin-RevId: 474195462
Change-Id: I2b57d8ea6b8eb66d1f777f026fb19e5401983789
-rw-r--r-- | CHANGELOG.md | 25 | ||||
-rw-r--r-- | absl/flags/_defines.py | 14 | ||||
-rw-r--r-- | absl/flags/_defines.pyi | 2 | ||||
-rw-r--r-- | absl/flags/_flagvalues.py | 32 | ||||
-rw-r--r-- | absl/flags/_validators.py | 51 | ||||
-rw-r--r-- | absl/flags/tests/_validators_test.py | 245 | ||||
-rw-r--r-- | absl/flags/tests/flags_test.py | 34 |
7 files changed, 379 insertions, 24 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 06a8ee0..025f8ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com). ## Unreleased +### Added + +* (flags) The following functions now also accept `FlagHolder` instance(s) + in addition to flag name(s) as their first positional argument: + - `flags.register_validator` + - `flags.validator` + - `flags.register_multi_flags_validator` + - `flags.multi_flags_validator` + - `flags.mark_flag_as_required` + - `flags.mark_flags_as_required` + - `flags.mark_flags_as_mutual_exclusive` + - `flags.mark_bool_flags_as_mutual_exclusive` + - `flags.declare_key_flag` + ### Changed * (testing) Assertions `assertRaisesWithPredicateMatch` and @@ -13,6 +27,17 @@ The format is based on [Keep a Changelog](https://keepachangelog.com). further analysis when used as a context manager. * (testing) TextAndXMLTestRunner now produces time duration values with millisecond precision in XML test result output. +* (flags) Keyword access to `flag_name` arguments in the following functions + is deprecated. This parameter will be renamed in a future 2.0.0 release. + - `flags.register_validator` + - `flags.validator` + - `flags.register_multi_flags_validator` + - `flags.multi_flags_validator` + - `flags.mark_flag_as_required` + - `flags.mark_flags_as_required` + - `flags.mark_flags_as_mutual_exclusive` + - `flags.mark_bool_flags_as_mutual_exclusive` + - `flags.declare_key_flag` ## 1.2.0 (2022-07-18) diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py index 12335e5..a4197c7 100644 --- a/absl/flags/_defines.py +++ b/absl/flags/_defines.py @@ -157,8 +157,7 @@ def _internal_declare_key_flags(flag_names, adopt_module_key_flags instead. Args: - flag_names: [str], a list of strings that are names of already-registered - Flag objects. + flag_names: [str], a list of names of already-registered Flag objects. flag_values: :class:`FlagValues`, the FlagValues instance with which the flags listed in flag_names have registered (the value of the flag_values argument from the ``DEFINE_*`` calls that defined those flags). This @@ -176,8 +175,7 @@ def _internal_declare_key_flags(flag_names, module = _helpers.get_calling_module() for flag_name in flag_names: - flag = flag_values[flag_name] - key_flag_values.register_key_flag_for_module(module, flag) + key_flag_values.register_key_flag_for_module(module, flag_values[flag_name]) def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS): @@ -194,9 +192,10 @@ def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS): flags.declare_key_flag('flag_1') Args: - flag_name: str, the name of an already declared flag. (Redeclaring flags as - key, including flags implicitly key because they were declared in this - module, is a no-op.) + flag_name: str | :class:`FlagHolder`, the name or holder of an already + declared flag. (Redeclaring flags as key, including flags implicitly key + because they were declared in this module, is a no-op.) + Positional-only parameter. flag_values: :class:`FlagValues`, the FlagValues instance in which the flag will be declared as a key flag. This should almost never need to be overridden. @@ -204,6 +203,7 @@ def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS): Raises: ValueError: Raised if flag_name not defined as a Python flag. """ + flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values) if flag_name in _helpers.SPECIAL_FLAGS: # Take care of the special flags, e.g., --flagfile, --undefok. # These flags are defined in SPECIAL_FLAGS, and are treated diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi index 0fbe921..c75a4f5 100644 --- a/absl/flags/_defines.pyi +++ b/absl/flags/_defines.pyi @@ -651,7 +651,7 @@ def DEFINE_alias( -def declare_key_flag(flag_name: Text, +def declare_key_flag(flag_name: Union[Text, _flagvalues.FlagHolder], flag_values: _flagvalues.FlagValues = ...) -> None: ... diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index 48df1a0..937dc6c 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -1384,3 +1384,35 @@ class FlagHolder(Generic[_T]): def present(self): """Returns True if the flag was parsed from command-line flags.""" return bool(self._flagvalues[self._name].present) + + +def resolve_flag_ref(flag_ref, flag_values): + """Helper to validate and resolve a flag reference argument.""" + if isinstance(flag_ref, FlagHolder): + new_flag_values = flag_ref._flagvalues # pylint: disable=protected-access + if flag_values != FLAGS and flag_values != new_flag_values: + raise ValueError( + 'flag_values must not be customized when operating on a FlagHolder') + return flag_ref.name, new_flag_values + return flag_ref, flag_values + + +def resolve_flag_refs(flag_refs, flag_values): + """Helper to validate and resolve flag reference list arguments.""" + fv = None + names = [] + for ref in flag_refs: + if isinstance(ref, FlagHolder): + newfv = ref._flagvalues # pylint: disable=protected-access + name = ref.name + else: + newfv = flag_values + name = ref + if fv and fv != newfv: + raise ValueError( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + fv = newfv + names.append(name) + return names, fv diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py index c4e1139..2161284 100644 --- a/absl/flags/_validators.py +++ b/absl/flags/_validators.py @@ -51,7 +51,8 @@ def register_validator(flag_name, change of the corresponding flag's value. Args: - flag_name: str, name of the flag to be checked. + flag_name: str | FlagHolder, name or holder of the flag to be checked. + Positional-only parameter. checker: callable, a function to validate the flag. * input - A single positional argument: The value of the corresponding @@ -70,7 +71,10 @@ def register_validator(flag_name, Raises: AttributeError: Raised when flag_name is not registered as a valid flag name. + ValueError: Raised when flag_values is non-default and does not match the + FlagValues of the provided FlagHolder instance. """ + flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values) v = _validators_classes.SingleFlagValidator(flag_name, checker, message) _add_validator(flag_values, v) @@ -88,7 +92,8 @@ def validator(flag_name, message='Flag validation failed', See :func:`register_validator` for the specification of checker function. Args: - flag_name: str, name of the flag to be checked. + flag_name: str | FlagHolder, name or holder of the flag to be checked. + Positional-only parameter. message: str, error text to be shown to the user if checker returns False. If checker raises flags.ValidationError, message from the raised error will be shown. @@ -119,7 +124,8 @@ def register_multi_flags_validator(flag_names, change of the corresponding flag's value. Args: - flag_names: [str], a list of the flag names to be checked. + flag_names: [str | FlagHolder], a list of the flag names or holders to be + checked. Positional-only parameter. multi_flags_checker: callable, a function to validate the flag. * input - dict, with keys() being flag_names, and value for each key @@ -136,7 +142,13 @@ def register_multi_flags_validator(flag_names, Raises: AttributeError: Raised when a flag is not registered as a valid flag name. + ValueError: Raised when multiple FlagValues are used in the same + invocation. This can occur when FlagHolders have different `_flagvalues` + or when str-type flag_names entries are present and the `flag_values` + argument does not match that of provided FlagHolder(s). """ + flag_names, flag_values = _flagvalues.resolve_flag_refs( + flag_names, flag_values) v = _validators_classes.MultiFlagsValidator( flag_names, multi_flags_checker, message) _add_validator(flag_values, v) @@ -157,7 +169,8 @@ def multi_flags_validator(flag_names, function. Args: - flag_names: [str], a list of the flag names to be checked. + flag_names: [str | FlagHolder], a list of the flag names or holders to be + checked. Positional-only parameter. message: str, error text to be shown to the user if checker returns False. If checker raises flags.ValidationError, message from the raised error will be shown. @@ -196,13 +209,17 @@ def mark_flag_as_required(flag_name, flag_values=_flagvalues.FLAGS): app.run() Args: - flag_name: str, name of the flag + flag_name: str | FlagHolder, name or holder of the flag. + Positional-only parameter. flag_values: flags.FlagValues, optional :class:`~absl.flags.FlagValues` instance where the flag is defined. Raises: AttributeError: Raised when flag_name is not registered as a valid flag name. + ValueError: Raised when flag_values is non-default and does not match the + FlagValues of the provided FlagHolder instance. """ + flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values) if flag_values[flag_name].default is not None: warnings.warn( 'Flag --%s has a non-None default value; therefore, ' @@ -227,7 +244,7 @@ def mark_flags_as_required(flag_names, flag_values=_flagvalues.FLAGS): app.run() Args: - flag_names: Sequence[str], names of the flags. + flag_names: Sequence[str | FlagHolder], names or holders of the flags. flag_values: flags.FlagValues, optional FlagValues instance where the flags are defined. Raises: @@ -248,13 +265,22 @@ def mark_flags_as_mutual_exclusive(flag_names, required=False, includes multi flags with a default value of ``[]`` instead of None. Args: - flag_names: [str], names of the flags. + flag_names: [str | FlagHolder], names or holders of flags. + Positional-only parameter. required: bool. If true, exactly one of the flags must have a value other than None. Otherwise, at most one of the flags can have a value other than None, and it is valid for all of the flags to be None. flag_values: flags.FlagValues, optional FlagValues instance where the flags are defined. + + Raises: + ValueError: Raised when multiple FlagValues are used in the same + invocation. This can occur when FlagHolders have different `_flagvalues` + or when str-type flag_names entries are present and the `flag_values` + argument does not match that of provided FlagHolder(s). """ + flag_names, flag_values = _flagvalues.resolve_flag_refs( + flag_names, flag_values) for flag_name in flag_names: if flag_values[flag_name].default is not None: warnings.warn( @@ -280,12 +306,21 @@ def mark_bool_flags_as_mutual_exclusive(flag_names, required=False, """Ensures that only one flag among flag_names is True. Args: - flag_names: [str], names of the flags. + flag_names: [str | FlagHolder], names or holders of flags. + Positional-only parameter. required: bool. If true, exactly one flag must be True. Otherwise, at most one flag can be True, and it is valid for all flags to be False. flag_values: flags.FlagValues, optional FlagValues instance where the flags are defined. + + Raises: + ValueError: Raised when multiple FlagValues are used in the same + invocation. This can occur when FlagHolders have different `_flagvalues` + or when str-type flag_names entries are present and the `flag_values` + argument does not match that of provided FlagHolder(s). """ + flag_names, flag_values = _flagvalues.resolve_flag_refs( + flag_names, flag_values) for flag_name in flag_names: if not flag_values[flag_name].boolean: raise _exceptions.ValidationError( diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py index 1cccf53..9aa328e 100644 --- a/absl/flags/tests/_validators_test.py +++ b/absl/flags/tests/_validators_test.py @@ -55,6 +55,45 @@ class SingleFlagValidatorTest(absltest.TestCase): self.assertEqual(2, self.flag_values.test_flag) self.assertEqual([None, 2], self.call_args) + def test_success_holder(self): + def checker(x): + self.call_args.append(x) + return True + + flag_holder = _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + flag_holder, + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program',) + self.flag_values(argv) + self.assertIsNone(self.flag_values.test_flag) + self.flag_values.test_flag = 2 + self.assertEqual(2, self.flag_values.test_flag) + self.assertEqual([None, 2], self.call_args) + + def test_success_holder_infer_flagvalues(self): + def checker(x): + self.call_args.append(x) + return True + + flag_holder = _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + flag_holder, + checker, + message='Errors happen') + + argv = ('./program',) + self.flag_values(argv) + self.assertIsNone(self.flag_values.test_flag) + self.flag_values.test_flag = 2 + self.assertEqual(2, self.flag_values.test_flag) + self.assertEqual([None, 2], self.call_args) + def test_default_value_not_used_success(self): def checker(x): self.call_args.append(x) @@ -218,6 +257,26 @@ class SingleFlagValidatorTest(absltest.TestCase): self.assertTrue(checker(3)) self.assertEqual([None, 2, 3], self.call_args) + def test_mismatching_flagvalues(self): + + def checker(x): + self.call_args.append(x) + return True + + flag_holder = _defines.DEFINE_integer( + 'test_flag', + None, + 'Usual integer flag', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'flag_values must not be customized when operating on a FlagHolder') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + _validators.register_validator( + flag_holder, + checker, + message='Errors happen', + flag_values=self.flag_values) + class MultiFlagsValidatorTest(absltest.TestCase): """Test flags multi-flag validators.""" @@ -226,9 +285,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): super(MultiFlagsValidatorTest, self).setUp() self.flag_values = _flagvalues.FlagValues() self.call_args = [] - _defines.DEFINE_integer( + self.foo_holder = _defines.DEFINE_integer( 'foo', 1, 'Usual integer flag', flag_values=self.flag_values) - _defines.DEFINE_integer( + self.bar_holder = _defines.DEFINE_integer( 'bar', 2, 'Usual integer flag', flag_values=self.flag_values) def test_success(self): @@ -248,6 +307,55 @@ class MultiFlagsValidatorTest(absltest.TestCase): self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}], self.call_args) + def test_success_holder(self): + + def checker(flags_dict): + self.call_args.append(flags_dict) + return True + + _validators.register_multi_flags_validator( + [self.foo_holder, self.bar_holder], + checker, + flag_values=self.flag_values) + + argv = ('./program', '--bar=2') + self.flag_values(argv) + self.assertEqual(1, self.flag_values.foo) + self.assertEqual(2, self.flag_values.bar) + self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args) + self.flag_values.foo = 3 + self.assertEqual(3, self.flag_values.foo) + self.assertEqual([{ + 'foo': 1, + 'bar': 2 + }, { + 'foo': 3, + 'bar': 2 + }], self.call_args) + + def test_success_holder_infer_flagvalues(self): + def checker(flags_dict): + self.call_args.append(flags_dict) + return True + + _validators.register_multi_flags_validator( + [self.foo_holder, self.bar_holder], checker) + + argv = ('./program', '--bar=2') + self.flag_values(argv) + self.assertEqual(1, self.flag_values.foo) + self.assertEqual(2, self.flag_values.bar) + self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args) + self.flag_values.foo = 3 + self.assertEqual(3, self.flag_values.foo) + self.assertEqual([{ + 'foo': 1, + 'bar': 2 + }, { + 'foo': 3, + 'bar': 2 + }], self.call_args) + def test_validator_not_called_when_other_flag_is_changed(self): def checker(flags_dict): self.call_args.append(flags_dict) @@ -322,6 +430,30 @@ class MultiFlagsValidatorTest(absltest.TestCase): self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) + def test_mismatching_flagvalues(self): + + def checker(flags_dict): + self.call_args.append(flags_dict) + values = flags_dict.values() + # Make sure all the flags have different values. + return len(set(values)) == len(values) + + other_holder = _defines.DEFINE_integer( + 'other_flag', + 3, + 'Other integer flag', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + _validators.register_multi_flags_validator( + [self.foo_holder, self.bar_holder, other_holder], + checker, + message='Errors happen', + flag_values=self.flag_values) + class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): @@ -329,9 +461,9 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): super(MarkFlagsAsMutualExclusiveTest, self).setUp() self.flag_values = _flagvalues.FlagValues() - _defines.DEFINE_string( + self.flag_one_holder = _defines.DEFINE_string( 'flag_one', None, 'flag one', flag_values=self.flag_values) - _defines.DEFINE_string( + self.flag_two_holder = _defines.DEFINE_string( 'flag_two', None, 'flag two', flag_values=self.flag_values) _defines.DEFINE_string( 'flag_three', None, 'flag three', flag_values=self.flag_values) @@ -358,6 +490,24 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): self.assertIsNone(self.flag_values.flag_one) self.assertIsNone(self.flag_values.flag_two) + def test_no_flags_present_holder(self): + self._mark_flags_as_mutually_exclusive( + [self.flag_one_holder, self.flag_two_holder], False) + argv = ('./program',) + + self.flag_values(argv) + self.assertIsNone(self.flag_values.flag_one) + self.assertIsNone(self.flag_values.flag_two) + + def test_no_flags_present_mixed(self): + self._mark_flags_as_mutually_exclusive([self.flag_one_holder, 'flag_two'], + False) + argv = ('./program',) + + self.flag_values(argv) + self.assertIsNone(self.flag_values.flag_one) + self.assertIsNone(self.flag_values.flag_two) + def test_no_flags_present_required(self): self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True) argv = ('./program',) @@ -494,6 +644,20 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): self.assertIn('--flag_not_none has a non-None default value', str(caught_warnings[0].message)) + def test_multiple_flagvalues(self): + other_holder = _defines.DEFINE_boolean( + 'other_flagvalues', + False, + 'other ', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + self._mark_flags_as_mutually_exclusive( + [self.flag_one_holder, other_holder], False) + class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): @@ -501,13 +665,13 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): super(MarkBoolFlagsAsMutualExclusiveTest, self).setUp() self.flag_values = _flagvalues.FlagValues() - _defines.DEFINE_boolean( + self.false_1_holder = _defines.DEFINE_boolean( 'false_1', False, 'default false 1', flag_values=self.flag_values) - _defines.DEFINE_boolean( + self.false_2_holder = _defines.DEFINE_boolean( 'false_2', False, 'default false 2', flag_values=self.flag_values) - _defines.DEFINE_boolean( + self.true_1_holder = _defines.DEFINE_boolean( 'true_1', True, 'default true 1', flag_values=self.flag_values) - _defines.DEFINE_integer( + self.non_bool_holder = _defines.DEFINE_integer( 'non_bool', None, 'non bool', flag_values=self.flag_values) def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required): @@ -520,6 +684,20 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): self.assertEqual(False, self.flag_values.false_1) self.assertEqual(False, self.flag_values.false_2) + def test_no_flags_present_holder(self): + self._mark_bool_flags_as_mutually_exclusive( + [self.false_1_holder, self.false_2_holder], False) + self.flag_values(('./program',)) + self.assertEqual(False, self.flag_values.false_1) + self.assertEqual(False, self.flag_values.false_2) + + def test_no_flags_present_mixed(self): + self._mark_bool_flags_as_mutually_exclusive( + [self.false_1_holder, 'false_2'], False) + self.flag_values(('./program',)) + self.assertEqual(False, self.flag_values.false_1) + self.assertEqual(False, self.flag_values.false_2) + def test_no_flags_present_required(self): self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True) argv = ('./program',) @@ -554,6 +732,17 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'], False) + def test_multiple_flagvalues(self): + other_bool_holder = _defines.DEFINE_boolean( + 'other_bool', False, 'other bool', flag_values=_flagvalues.FlagValues()) + expected = ( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + self._mark_bool_flags_as_mutually_exclusive( + [self.false_1_holder, other_bool_holder], False) + class MarkFlagAsRequiredTest(absltest.TestCase): @@ -570,6 +759,22 @@ class MarkFlagAsRequiredTest(absltest.TestCase): self.flag_values(argv) self.assertEqual('value', self.flag_values.string_flag) + def test_success_holder(self): + holder = _defines.DEFINE_string( + 'string_flag', None, 'string flag', flag_values=self.flag_values) + _validators.mark_flag_as_required(holder, flag_values=self.flag_values) + argv = ('./program', '--string_flag=value') + self.flag_values(argv) + self.assertEqual('value', self.flag_values.string_flag) + + def test_success_holder_infer_flagvalues(self): + holder = _defines.DEFINE_string( + 'string_flag', None, 'string flag', flag_values=self.flag_values) + _validators.mark_flag_as_required(holder) + argv = ('./program', '--string_flag=value') + self.flag_values(argv) + self.assertEqual('value', self.flag_values.string_flag) + def test_catch_none_as_default(self): _defines.DEFINE_string( 'string_flag', None, 'string flag', flag_values=self.flag_values) @@ -608,6 +813,18 @@ class MarkFlagAsRequiredTest(absltest.TestCase): self.assertIn('--flag_not_none has a non-None default value', str(caught_warnings[0].message)) + def test_mismatching_flagvalues(self): + flag_holder = _defines.DEFINE_string( + 'string_flag', + 'value', + 'string flag', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'flag_values must not be customized when operating on a FlagHolder') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + _validators.mark_flag_as_required( + flag_holder, flag_values=self.flag_values) + class MarkFlagsAsRequiredTest(absltest.TestCase): @@ -627,6 +844,18 @@ class MarkFlagsAsRequiredTest(absltest.TestCase): self.assertEqual('value_1', self.flag_values.string_flag_1) self.assertEqual('value_2', self.flag_values.string_flag_2) + def test_success_holders(self): + flag_1_holder = _defines.DEFINE_string( + 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values) + flag_2_holder = _defines.DEFINE_string( + 'string_flag_2', None, 'string flag 2', flag_values=self.flag_values) + _validators.mark_flags_as_required([flag_1_holder, flag_2_holder], + flag_values=self.flag_values) + argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2') + self.flag_values(argv) + self.assertEqual('value_1', self.flag_values.string_flag_1) + self.assertEqual('value_2', self.flag_values.string_flag_2) + def test_catch_none_as_default(self): _defines.DEFINE_string( 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values) diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py index 8a42bc9..d001161 100644 --- a/absl/flags/tests/flags_test.py +++ b/absl/flags/tests/flags_test.py @@ -2646,6 +2646,40 @@ class KeyFlagsTest(absltest.TestCase): self._get_names_of_key_flags(main_module, fv), names_of_flags_defined_by_bar + ['flagfile', 'undefok']) + def test_key_flags_with_flagholders(self): + main_module = sys.argv[0] + + self.assertListEqual( + self._get_names_of_key_flags(main_module, self.flag_values), []) + self.assertListEqual( + self._get_names_of_defined_flags(main_module, self.flag_values), []) + + int_holder = flags.DEFINE_integer( + 'main_module_int_fg', + 1, + 'Integer flag in the main module.', + flag_values=self.flag_values) + + flags.declare_key_flag(int_holder, self.flag_values) + + self.assertCountEqual( + self.flag_values.get_flags_for_module(main_module), + self.flag_values.get_key_flags_for_module(main_module)) + + bool_holder = flags.DEFINE_boolean( + 'main_module_bool_fg', + False, + 'Boolean flag in the main module.', + flag_values=self.flag_values) + + flags.declare_key_flag(bool_holder) # omitted flag_values + + self.assertCountEqual( + self.flag_values.get_flags_for_module(main_module), + self.flag_values.get_key_flags_for_module(main_module)) + + self.assertLen(self.flag_values.get_flags_for_module(main_module), 2) + def test_main_module_help_with_key_flags(self): # Similar to test_main_module_help, but this time we make sure to # declare some key flags. |