aboutsummaryrefslogtreecommitdiff
path: root/absl/flags/tests/flags_test.py
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2022-09-20 11:41:53 -0700
committerCopybara-Service <copybara-worker@google.com>2022-09-20 11:42:41 -0700
commit06561a05a5b055adc243860a330bc1f20adb71e5 (patch)
tree0d1213cf27a4a5d0d558e131d0b70cb743508e89 /absl/flags/tests/flags_test.py
parent0dd0210ed8919bca3338aebd1940824ada1a8e47 (diff)
downloadabsl-py-06561a05a5b055adc243860a330bc1f20adb71e5.tar.gz
Add a flags.set_default function.
PiperOrigin-RevId: 475611494 Change-Id: I9facdfcaece0b3c0bd2b0b38364ecd41cb06d79c
Diffstat (limited to 'absl/flags/tests/flags_test.py')
-rw-r--r--absl/flags/tests/flags_test.py65
1 files changed, 65 insertions, 0 deletions
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py
index d001161..77ed307 100644
--- a/absl/flags/tests/flags_test.py
+++ b/absl/flags/tests/flags_test.py
@@ -2483,6 +2483,71 @@ class NonGlobalFlagsTest(absltest.TestCase):
flag_values['flag_name'] = 'flag_value'
+class SetDefaultTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.flag_values = flags.FlagValues()
+
+ def test_success(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ flags.set_default(int_holder, 2)
+ self.flag_values.mark_as_parsed()
+
+ self.assertEqual(int_holder.value, 2)
+
+ def test_update_after_parse(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ flags.set_default(int_holder, 2)
+
+ self.assertEqual(int_holder.value, 2)
+
+ def test_overridden_by_explicit_assignment(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ self.flag_values.an_int = 3
+ flags.set_default(int_holder, 2)
+
+ self.assertEqual(int_holder.value, 3)
+
+ def test_restores_back_to_none(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', None, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ flags.set_default(int_holder, 3)
+ flags.set_default(int_holder, None)
+
+ self.assertIsNone(int_holder.value)
+
+ def test_failure_on_invalid_type(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flags.set_default(int_holder, 'a')
+
+ def test_failure_on_type_protected_none_default(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+
+ flags.set_default(int_holder, None) # NOTE: should be a type failure
+
+ with self.assertRaises(flags.IllegalFlagValueError):
+ _ = int_holder.value # Will also fail on later access.
+
+
class KeyFlagsTest(absltest.TestCase):
def setUp(self):