diff options
Diffstat (limited to 'absl/flags/tests/flags_test.py')
-rw-r--r-- | absl/flags/tests/flags_test.py | 99 |
1 files changed, 99 insertions, 0 deletions
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py index 8a42bc9..77ed307 100644 --- a/absl/flags/tests/flags_test.py +++ b/absl/flags/tests/flags_test.py @@ -2483,6 +2483,71 @@ class NonGlobalFlagsTest(absltest.TestCase): flag_values['flag_name'] = 'flag_value' +class SetDefaultTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.flag_values = flags.FlagValues() + + def test_success(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + flags.set_default(int_holder, 2) + self.flag_values.mark_as_parsed() + + self.assertEqual(int_holder.value, 2) + + def test_update_after_parse(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + flags.set_default(int_holder, 2) + + self.assertEqual(int_holder.value, 2) + + def test_overridden_by_explicit_assignment(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + self.flag_values.an_int = 3 + flags.set_default(int_holder, 2) + + self.assertEqual(int_holder.value, 3) + + def test_restores_back_to_none(self): + int_holder = flags.DEFINE_integer( + 'an_int', None, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + flags.set_default(int_holder, 3) + flags.set_default(int_holder, None) + + self.assertIsNone(int_holder.value) + + def test_failure_on_invalid_type(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + + with self.assertRaises(flags.IllegalFlagValueError): + flags.set_default(int_holder, 'a') + + def test_failure_on_type_protected_none_default(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + + flags.set_default(int_holder, None) # NOTE: should be a type failure + + with self.assertRaises(flags.IllegalFlagValueError): + _ = int_holder.value # Will also fail on later access. + + class KeyFlagsTest(absltest.TestCase): def setUp(self): @@ -2646,6 +2711,40 @@ class KeyFlagsTest(absltest.TestCase): self._get_names_of_key_flags(main_module, fv), names_of_flags_defined_by_bar + ['flagfile', 'undefok']) + def test_key_flags_with_flagholders(self): + main_module = sys.argv[0] + + self.assertListEqual( + self._get_names_of_key_flags(main_module, self.flag_values), []) + self.assertListEqual( + self._get_names_of_defined_flags(main_module, self.flag_values), []) + + int_holder = flags.DEFINE_integer( + 'main_module_int_fg', + 1, + 'Integer flag in the main module.', + flag_values=self.flag_values) + + flags.declare_key_flag(int_holder, self.flag_values) + + self.assertCountEqual( + self.flag_values.get_flags_for_module(main_module), + self.flag_values.get_key_flags_for_module(main_module)) + + bool_holder = flags.DEFINE_boolean( + 'main_module_bool_fg', + False, + 'Boolean flag in the main module.', + flag_values=self.flag_values) + + flags.declare_key_flag(bool_holder) # omitted flag_values + + self.assertCountEqual( + self.flag_values.get_flags_for_module(main_module), + self.flag_values.get_key_flags_for_module(main_module)) + + self.assertLen(self.flag_values.get_flags_for_module(main_module), 2) + def test_main_module_help_with_key_flags(self): # Similar to test_main_module_help, but this time we make sure to # declare some key flags. |