aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2022-09-20 11:41:53 -0700
committerCopybara-Service <copybara-worker@google.com>2022-09-20 11:42:41 -0700
commit06561a05a5b055adc243860a330bc1f20adb71e5 (patch)
tree0d1213cf27a4a5d0d558e131d0b70cb743508e89
parent0dd0210ed8919bca3338aebd1940824ada1a8e47 (diff)
downloadabsl-py-06561a05a5b055adc243860a330bc1f20adb71e5.tar.gz
Add a flags.set_default function.
PiperOrigin-RevId: 475611494 Change-Id: I9facdfcaece0b3c0bd2b0b38364ecd41cb06d79c
-rw-r--r--CHANGELOG.md7
-rw-r--r--absl/flags/__init__.py5
-rw-r--r--absl/flags/_defines.py17
-rw-r--r--absl/flags/_defines.pyi3
-rw-r--r--absl/flags/tests/flags_test.py65
5 files changed, 95 insertions, 2 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 025f8ed..17c5ce1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,8 +8,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
### Added
-* (flags) The following functions now also accept `FlagHolder` instance(s)
- in addition to flag name(s) as their first positional argument:
+* (flags) Added a new `absl.flags.set_default` function that updates the flag
+ default for a provided `FlagHolder`. This parallels the
+ `absl.flags.FlagValues.set_default` interface which takes a flag name.
+* (flags) The following functions now also accept `FlagHolder` instance(s) in
+ addition to flag name(s) as their first positional argument:
- `flags.register_validator`
- `flags.validator`
- `flags.register_multi_flags_validator`
diff --git a/absl/flags/__init__.py b/absl/flags/__init__.py
index 45e64f3..6d8ba03 100644
--- a/absl/flags/__init__.py
+++ b/absl/flags/__init__.py
@@ -68,6 +68,8 @@ __all__ = (
'mark_flags_as_required',
'mark_flags_as_mutual_exclusive',
'mark_bool_flags_as_mutual_exclusive',
+ # Flag modifiers.
+ 'set_default',
# Key flag related functions.
'declare_key_flag',
'adopt_module_key_flags',
@@ -152,6 +154,9 @@ mark_flags_as_required = _validators.mark_flags_as_required
mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive
mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive
+# Flag modifiers.
+set_default = _defines.set_default
+
# Key flag related functions.
declare_key_flag = _defines.declare_key_flag
adopt_module_key_flags = _defines.adopt_module_key_flags
diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py
index a4197c7..dce53ea 100644
--- a/absl/flags/_defines.py
+++ b/absl/flags/_defines.py
@@ -148,6 +148,23 @@ def DEFINE_flag( # pylint: disable=invalid-name
fv, flag, ensure_non_none_value=ensure_non_none_value)
+def set_default(flag_holder, value):
+ """Changes the default value of the provided flag object.
+
+ The flag's current value is also updated if the flag is currently using
+ the default value, i.e. not specified in the command line, and not set
+ by FLAGS.name = value.
+
+ Args:
+ flag_holder: FlagHolder, the flag to modify.
+ value: The new default value.
+
+ Raises:
+ IllegalFlagValueError: Raised when value is not valid.
+ """
+ flag_holder._flagvalues.set_default(flag_holder.name, value) # pylint: disable=protected-access
+
+
def _internal_declare_key_flags(flag_names,
flag_values=_flagvalues.FLAGS,
key_flag_values=None):
diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi
index c75a4f5..9bc8067 100644
--- a/absl/flags/_defines.pyi
+++ b/absl/flags/_defines.pyi
@@ -650,6 +650,9 @@ def DEFINE_alias(
...
+def set_default(flag_holder: _flagvalues.FlagHolder[_T], value: _T) -> None:
+ ...
+
def declare_key_flag(flag_name: Union[Text, _flagvalues.FlagHolder],
flag_values: _flagvalues.FlagValues = ...) -> None:
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py
index d001161..77ed307 100644
--- a/absl/flags/tests/flags_test.py
+++ b/absl/flags/tests/flags_test.py
@@ -2483,6 +2483,71 @@ class NonGlobalFlagsTest(absltest.TestCase):
flag_values['flag_name'] = 'flag_value'
+class SetDefaultTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.flag_values = flags.FlagValues()
+
+ def test_success(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ flags.set_default(int_holder, 2)
+ self.flag_values.mark_as_parsed()
+
+ self.assertEqual(int_holder.value, 2)
+
+ def test_update_after_parse(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ flags.set_default(int_holder, 2)
+
+ self.assertEqual(int_holder.value, 2)
+
+ def test_overridden_by_explicit_assignment(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ self.flag_values.an_int = 3
+ flags.set_default(int_holder, 2)
+
+ self.assertEqual(int_holder.value, 3)
+
+ def test_restores_back_to_none(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', None, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ flags.set_default(int_holder, 3)
+ flags.set_default(int_holder, None)
+
+ self.assertIsNone(int_holder.value)
+
+ def test_failure_on_invalid_type(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flags.set_default(int_holder, 'a')
+
+ def test_failure_on_type_protected_none_default(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+
+ flags.set_default(int_holder, None) # NOTE: should be a type failure
+
+ with self.assertRaises(flags.IllegalFlagValueError):
+ _ = int_holder.value # Will also fail on later access.
+
+
class KeyFlagsTest(absltest.TestCase):
def setUp(self):