diff options
Diffstat (limited to 'absl/flags/tests/_validators_test.py')
-rw-r--r-- | absl/flags/tests/_validators_test.py | 249 |
1 files changed, 237 insertions, 12 deletions
diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py index f724813..9aa328e 100644 --- a/absl/flags/tests/_validators_test.py +++ b/absl/flags/tests/_validators_test.py @@ -18,10 +18,6 @@ This file tests that each flag validator called when it should be, and that failed validator will throw an exception, etc. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import warnings @@ -59,6 +55,45 @@ class SingleFlagValidatorTest(absltest.TestCase): self.assertEqual(2, self.flag_values.test_flag) self.assertEqual([None, 2], self.call_args) + def test_success_holder(self): + def checker(x): + self.call_args.append(x) + return True + + flag_holder = _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + flag_holder, + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program',) + self.flag_values(argv) + self.assertIsNone(self.flag_values.test_flag) + self.flag_values.test_flag = 2 + self.assertEqual(2, self.flag_values.test_flag) + self.assertEqual([None, 2], self.call_args) + + def test_success_holder_infer_flagvalues(self): + def checker(x): + self.call_args.append(x) + return True + + flag_holder = _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + flag_holder, + checker, + message='Errors happen') + + argv = ('./program',) + self.flag_values(argv) + self.assertIsNone(self.flag_values.test_flag) + self.flag_values.test_flag = 2 + self.assertEqual(2, self.flag_values.test_flag) + self.assertEqual([None, 2], self.call_args) + def test_default_value_not_used_success(self): def checker(x): self.call_args.append(x) @@ -222,6 +257,26 @@ class SingleFlagValidatorTest(absltest.TestCase): self.assertTrue(checker(3)) self.assertEqual([None, 2, 3], self.call_args) + def test_mismatching_flagvalues(self): + + def checker(x): + self.call_args.append(x) + return True + + flag_holder = _defines.DEFINE_integer( + 'test_flag', + None, + 'Usual integer flag', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'flag_values must not be customized when operating on a FlagHolder') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + _validators.register_validator( + flag_holder, + checker, + message='Errors happen', + flag_values=self.flag_values) + class MultiFlagsValidatorTest(absltest.TestCase): """Test flags multi-flag validators.""" @@ -230,9 +285,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): super(MultiFlagsValidatorTest, self).setUp() self.flag_values = _flagvalues.FlagValues() self.call_args = [] - _defines.DEFINE_integer( + self.foo_holder = _defines.DEFINE_integer( 'foo', 1, 'Usual integer flag', flag_values=self.flag_values) - _defines.DEFINE_integer( + self.bar_holder = _defines.DEFINE_integer( 'bar', 2, 'Usual integer flag', flag_values=self.flag_values) def test_success(self): @@ -252,6 +307,55 @@ class MultiFlagsValidatorTest(absltest.TestCase): self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}], self.call_args) + def test_success_holder(self): + + def checker(flags_dict): + self.call_args.append(flags_dict) + return True + + _validators.register_multi_flags_validator( + [self.foo_holder, self.bar_holder], + checker, + flag_values=self.flag_values) + + argv = ('./program', '--bar=2') + self.flag_values(argv) + self.assertEqual(1, self.flag_values.foo) + self.assertEqual(2, self.flag_values.bar) + self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args) + self.flag_values.foo = 3 + self.assertEqual(3, self.flag_values.foo) + self.assertEqual([{ + 'foo': 1, + 'bar': 2 + }, { + 'foo': 3, + 'bar': 2 + }], self.call_args) + + def test_success_holder_infer_flagvalues(self): + def checker(flags_dict): + self.call_args.append(flags_dict) + return True + + _validators.register_multi_flags_validator( + [self.foo_holder, self.bar_holder], checker) + + argv = ('./program', '--bar=2') + self.flag_values(argv) + self.assertEqual(1, self.flag_values.foo) + self.assertEqual(2, self.flag_values.bar) + self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args) + self.flag_values.foo = 3 + self.assertEqual(3, self.flag_values.foo) + self.assertEqual([{ + 'foo': 1, + 'bar': 2 + }, { + 'foo': 3, + 'bar': 2 + }], self.call_args) + def test_validator_not_called_when_other_flag_is_changed(self): def checker(flags_dict): self.call_args.append(flags_dict) @@ -326,6 +430,30 @@ class MultiFlagsValidatorTest(absltest.TestCase): self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) + def test_mismatching_flagvalues(self): + + def checker(flags_dict): + self.call_args.append(flags_dict) + values = flags_dict.values() + # Make sure all the flags have different values. + return len(set(values)) == len(values) + + other_holder = _defines.DEFINE_integer( + 'other_flag', + 3, + 'Other integer flag', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + _validators.register_multi_flags_validator( + [self.foo_holder, self.bar_holder, other_holder], + checker, + message='Errors happen', + flag_values=self.flag_values) + class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): @@ -333,9 +461,9 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): super(MarkFlagsAsMutualExclusiveTest, self).setUp() self.flag_values = _flagvalues.FlagValues() - _defines.DEFINE_string( + self.flag_one_holder = _defines.DEFINE_string( 'flag_one', None, 'flag one', flag_values=self.flag_values) - _defines.DEFINE_string( + self.flag_two_holder = _defines.DEFINE_string( 'flag_two', None, 'flag two', flag_values=self.flag_values) _defines.DEFINE_string( 'flag_three', None, 'flag three', flag_values=self.flag_values) @@ -362,6 +490,24 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): self.assertIsNone(self.flag_values.flag_one) self.assertIsNone(self.flag_values.flag_two) + def test_no_flags_present_holder(self): + self._mark_flags_as_mutually_exclusive( + [self.flag_one_holder, self.flag_two_holder], False) + argv = ('./program',) + + self.flag_values(argv) + self.assertIsNone(self.flag_values.flag_one) + self.assertIsNone(self.flag_values.flag_two) + + def test_no_flags_present_mixed(self): + self._mark_flags_as_mutually_exclusive([self.flag_one_holder, 'flag_two'], + False) + argv = ('./program',) + + self.flag_values(argv) + self.assertIsNone(self.flag_values.flag_one) + self.assertIsNone(self.flag_values.flag_two) + def test_no_flags_present_required(self): self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True) argv = ('./program',) @@ -498,6 +644,20 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): self.assertIn('--flag_not_none has a non-None default value', str(caught_warnings[0].message)) + def test_multiple_flagvalues(self): + other_holder = _defines.DEFINE_boolean( + 'other_flagvalues', + False, + 'other ', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + self._mark_flags_as_mutually_exclusive( + [self.flag_one_holder, other_holder], False) + class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): @@ -505,13 +665,13 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): super(MarkBoolFlagsAsMutualExclusiveTest, self).setUp() self.flag_values = _flagvalues.FlagValues() - _defines.DEFINE_boolean( + self.false_1_holder = _defines.DEFINE_boolean( 'false_1', False, 'default false 1', flag_values=self.flag_values) - _defines.DEFINE_boolean( + self.false_2_holder = _defines.DEFINE_boolean( 'false_2', False, 'default false 2', flag_values=self.flag_values) - _defines.DEFINE_boolean( + self.true_1_holder = _defines.DEFINE_boolean( 'true_1', True, 'default true 1', flag_values=self.flag_values) - _defines.DEFINE_integer( + self.non_bool_holder = _defines.DEFINE_integer( 'non_bool', None, 'non bool', flag_values=self.flag_values) def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required): @@ -524,6 +684,20 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): self.assertEqual(False, self.flag_values.false_1) self.assertEqual(False, self.flag_values.false_2) + def test_no_flags_present_holder(self): + self._mark_bool_flags_as_mutually_exclusive( + [self.false_1_holder, self.false_2_holder], False) + self.flag_values(('./program',)) + self.assertEqual(False, self.flag_values.false_1) + self.assertEqual(False, self.flag_values.false_2) + + def test_no_flags_present_mixed(self): + self._mark_bool_flags_as_mutually_exclusive( + [self.false_1_holder, 'false_2'], False) + self.flag_values(('./program',)) + self.assertEqual(False, self.flag_values.false_1) + self.assertEqual(False, self.flag_values.false_2) + def test_no_flags_present_required(self): self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True) argv = ('./program',) @@ -558,6 +732,17 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'], False) + def test_multiple_flagvalues(self): + other_bool_holder = _defines.DEFINE_boolean( + 'other_bool', False, 'other bool', flag_values=_flagvalues.FlagValues()) + expected = ( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + self._mark_bool_flags_as_mutually_exclusive( + [self.false_1_holder, other_bool_holder], False) + class MarkFlagAsRequiredTest(absltest.TestCase): @@ -574,6 +759,22 @@ class MarkFlagAsRequiredTest(absltest.TestCase): self.flag_values(argv) self.assertEqual('value', self.flag_values.string_flag) + def test_success_holder(self): + holder = _defines.DEFINE_string( + 'string_flag', None, 'string flag', flag_values=self.flag_values) + _validators.mark_flag_as_required(holder, flag_values=self.flag_values) + argv = ('./program', '--string_flag=value') + self.flag_values(argv) + self.assertEqual('value', self.flag_values.string_flag) + + def test_success_holder_infer_flagvalues(self): + holder = _defines.DEFINE_string( + 'string_flag', None, 'string flag', flag_values=self.flag_values) + _validators.mark_flag_as_required(holder) + argv = ('./program', '--string_flag=value') + self.flag_values(argv) + self.assertEqual('value', self.flag_values.string_flag) + def test_catch_none_as_default(self): _defines.DEFINE_string( 'string_flag', None, 'string flag', flag_values=self.flag_values) @@ -612,6 +813,18 @@ class MarkFlagAsRequiredTest(absltest.TestCase): self.assertIn('--flag_not_none has a non-None default value', str(caught_warnings[0].message)) + def test_mismatching_flagvalues(self): + flag_holder = _defines.DEFINE_string( + 'string_flag', + 'value', + 'string flag', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'flag_values must not be customized when operating on a FlagHolder') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + _validators.mark_flag_as_required( + flag_holder, flag_values=self.flag_values) + class MarkFlagsAsRequiredTest(absltest.TestCase): @@ -631,6 +844,18 @@ class MarkFlagsAsRequiredTest(absltest.TestCase): self.assertEqual('value_1', self.flag_values.string_flag_1) self.assertEqual('value_2', self.flag_values.string_flag_2) + def test_success_holders(self): + flag_1_holder = _defines.DEFINE_string( + 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values) + flag_2_holder = _defines.DEFINE_string( + 'string_flag_2', None, 'string flag 2', flag_values=self.flag_values) + _validators.mark_flags_as_required([flag_1_holder, flag_2_holder], + flag_values=self.flag_values) + argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2') + self.flag_values(argv) + self.assertEqual('value_1', self.flag_values.string_flag_1) + self.assertEqual('value_2', self.flag_values.string_flag_2) + def test_catch_none_as_default(self): _defines.DEFINE_string( 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values) |