aboutsummaryrefslogtreecommitdiff
path: root/absl/flags/tests/flags_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'absl/flags/tests/flags_test.py')
-rw-r--r--absl/flags/tests/flags_test.py99
1 files changed, 99 insertions, 0 deletions
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py
index 8a42bc9..77ed307 100644
--- a/absl/flags/tests/flags_test.py
+++ b/absl/flags/tests/flags_test.py
@@ -2483,6 +2483,71 @@ class NonGlobalFlagsTest(absltest.TestCase):
flag_values['flag_name'] = 'flag_value'
+class SetDefaultTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.flag_values = flags.FlagValues()
+
+ def test_success(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ flags.set_default(int_holder, 2)
+ self.flag_values.mark_as_parsed()
+
+ self.assertEqual(int_holder.value, 2)
+
+ def test_update_after_parse(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ flags.set_default(int_holder, 2)
+
+ self.assertEqual(int_holder.value, 2)
+
+ def test_overridden_by_explicit_assignment(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ self.flag_values.an_int = 3
+ flags.set_default(int_holder, 2)
+
+ self.assertEqual(int_holder.value, 3)
+
+ def test_restores_back_to_none(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', None, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ flags.set_default(int_holder, 3)
+ flags.set_default(int_holder, None)
+
+ self.assertIsNone(int_holder.value)
+
+ def test_failure_on_invalid_type(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flags.set_default(int_holder, 'a')
+
+ def test_failure_on_type_protected_none_default(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+
+ flags.set_default(int_holder, None) # NOTE: should be a type failure
+
+ with self.assertRaises(flags.IllegalFlagValueError):
+ _ = int_holder.value # Will also fail on later access.
+
+
class KeyFlagsTest(absltest.TestCase):
def setUp(self):
@@ -2646,6 +2711,40 @@ class KeyFlagsTest(absltest.TestCase):
self._get_names_of_key_flags(main_module, fv),
names_of_flags_defined_by_bar + ['flagfile', 'undefok'])
+ def test_key_flags_with_flagholders(self):
+ main_module = sys.argv[0]
+
+ self.assertListEqual(
+ self._get_names_of_key_flags(main_module, self.flag_values), [])
+ self.assertListEqual(
+ self._get_names_of_defined_flags(main_module, self.flag_values), [])
+
+ int_holder = flags.DEFINE_integer(
+ 'main_module_int_fg',
+ 1,
+ 'Integer flag in the main module.',
+ flag_values=self.flag_values)
+
+ flags.declare_key_flag(int_holder, self.flag_values)
+
+ self.assertCountEqual(
+ self.flag_values.get_flags_for_module(main_module),
+ self.flag_values.get_key_flags_for_module(main_module))
+
+ bool_holder = flags.DEFINE_boolean(
+ 'main_module_bool_fg',
+ False,
+ 'Boolean flag in the main module.',
+ flag_values=self.flag_values)
+
+ flags.declare_key_flag(bool_holder) # omitted flag_values
+
+ self.assertCountEqual(
+ self.flag_values.get_flags_for_module(main_module),
+ self.flag_values.get_key_flags_for_module(main_module))
+
+ self.assertLen(self.flag_values.get_flags_for_module(main_module), 2)
+
def test_main_module_help_with_key_flags(self):
# Similar to test_main_module_help, but this time we make sure to
# declare some key flags.