diff options
author | Abseil Team <absl-team@google.com> | 2022-09-20 11:41:53 -0700 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2022-09-20 11:42:41 -0700 |
commit | 06561a05a5b055adc243860a330bc1f20adb71e5 (patch) | |
tree | 0d1213cf27a4a5d0d558e131d0b70cb743508e89 | |
parent | 0dd0210ed8919bca3338aebd1940824ada1a8e47 (diff) | |
download | absl-py-06561a05a5b055adc243860a330bc1f20adb71e5.tar.gz |
Add a flags.set_default function.
PiperOrigin-RevId: 475611494
Change-Id: I9facdfcaece0b3c0bd2b0b38364ecd41cb06d79c
-rw-r--r-- | CHANGELOG.md | 7 | ||||
-rw-r--r-- | absl/flags/__init__.py | 5 | ||||
-rw-r--r-- | absl/flags/_defines.py | 17 | ||||
-rw-r--r-- | absl/flags/_defines.pyi | 3 | ||||
-rw-r--r-- | absl/flags/tests/flags_test.py | 65 |
5 files changed, 95 insertions, 2 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 025f8ed..17c5ce1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com). ### Added -* (flags) The following functions now also accept `FlagHolder` instance(s) - in addition to flag name(s) as their first positional argument: +* (flags) Added a new `absl.flags.set_default` function that updates the flag + default for a provided `FlagHolder`. This parallels the + `absl.flags.FlagValues.set_default` interface which takes a flag name. +* (flags) The following functions now also accept `FlagHolder` instance(s) in + addition to flag name(s) as their first positional argument: - `flags.register_validator` - `flags.validator` - `flags.register_multi_flags_validator` diff --git a/absl/flags/__init__.py b/absl/flags/__init__.py index 45e64f3..6d8ba03 100644 --- a/absl/flags/__init__.py +++ b/absl/flags/__init__.py @@ -68,6 +68,8 @@ __all__ = ( 'mark_flags_as_required', 'mark_flags_as_mutual_exclusive', 'mark_bool_flags_as_mutual_exclusive', + # Flag modifiers. + 'set_default', # Key flag related functions. 'declare_key_flag', 'adopt_module_key_flags', @@ -152,6 +154,9 @@ mark_flags_as_required = _validators.mark_flags_as_required mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive +# Flag modifiers. +set_default = _defines.set_default + # Key flag related functions. declare_key_flag = _defines.declare_key_flag adopt_module_key_flags = _defines.adopt_module_key_flags diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py index a4197c7..dce53ea 100644 --- a/absl/flags/_defines.py +++ b/absl/flags/_defines.py @@ -148,6 +148,23 @@ def DEFINE_flag( # pylint: disable=invalid-name fv, flag, ensure_non_none_value=ensure_non_none_value) +def set_default(flag_holder, value): + """Changes the default value of the provided flag object. + + The flag's current value is also updated if the flag is currently using + the default value, i.e. not specified in the command line, and not set + by FLAGS.name = value. + + Args: + flag_holder: FlagHolder, the flag to modify. + value: The new default value. + + Raises: + IllegalFlagValueError: Raised when value is not valid. + """ + flag_holder._flagvalues.set_default(flag_holder.name, value) # pylint: disable=protected-access + + def _internal_declare_key_flags(flag_names, flag_values=_flagvalues.FLAGS, key_flag_values=None): diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi index c75a4f5..9bc8067 100644 --- a/absl/flags/_defines.pyi +++ b/absl/flags/_defines.pyi @@ -650,6 +650,9 @@ def DEFINE_alias( ... +def set_default(flag_holder: _flagvalues.FlagHolder[_T], value: _T) -> None: + ... + def declare_key_flag(flag_name: Union[Text, _flagvalues.FlagHolder], flag_values: _flagvalues.FlagValues = ...) -> None: diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py index d001161..77ed307 100644 --- a/absl/flags/tests/flags_test.py +++ b/absl/flags/tests/flags_test.py @@ -2483,6 +2483,71 @@ class NonGlobalFlagsTest(absltest.TestCase): flag_values['flag_name'] = 'flag_value' +class SetDefaultTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.flag_values = flags.FlagValues() + + def test_success(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + flags.set_default(int_holder, 2) + self.flag_values.mark_as_parsed() + + self.assertEqual(int_holder.value, 2) + + def test_update_after_parse(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + flags.set_default(int_holder, 2) + + self.assertEqual(int_holder.value, 2) + + def test_overridden_by_explicit_assignment(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + self.flag_values.an_int = 3 + flags.set_default(int_holder, 2) + + self.assertEqual(int_holder.value, 3) + + def test_restores_back_to_none(self): + int_holder = flags.DEFINE_integer( + 'an_int', None, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + flags.set_default(int_holder, 3) + flags.set_default(int_holder, None) + + self.assertIsNone(int_holder.value) + + def test_failure_on_invalid_type(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + + with self.assertRaises(flags.IllegalFlagValueError): + flags.set_default(int_holder, 'a') + + def test_failure_on_type_protected_none_default(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + + flags.set_default(int_holder, None) # NOTE: should be a type failure + + with self.assertRaises(flags.IllegalFlagValueError): + _ = int_holder.value # Will also fail on later access. + + class KeyFlagsTest(absltest.TestCase): def setUp(self): |