aboutsummaryrefslogtreecommitdiff
path: root/absl/flags/tests/_validators_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'absl/flags/tests/_validators_test.py')
-rw-r--r--absl/flags/tests/_validators_test.py249
1 files changed, 237 insertions, 12 deletions
diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py
index f724813..9aa328e 100644
--- a/absl/flags/tests/_validators_test.py
+++ b/absl/flags/tests/_validators_test.py
@@ -18,10 +18,6 @@ This file tests that each flag validator called when it should be, and that
failed validator will throw an exception, etc.
"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import warnings
@@ -59,6 +55,45 @@ class SingleFlagValidatorTest(absltest.TestCase):
self.assertEqual(2, self.flag_values.test_flag)
self.assertEqual([None, 2], self.call_args)
+ def test_success_holder(self):
+ def checker(x):
+ self.call_args.append(x)
+ return True
+
+ flag_holder = _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ flag_holder,
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.test_flag)
+ self.flag_values.test_flag = 2
+ self.assertEqual(2, self.flag_values.test_flag)
+ self.assertEqual([None, 2], self.call_args)
+
+ def test_success_holder_infer_flagvalues(self):
+ def checker(x):
+ self.call_args.append(x)
+ return True
+
+ flag_holder = _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ flag_holder,
+ checker,
+ message='Errors happen')
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.test_flag)
+ self.flag_values.test_flag = 2
+ self.assertEqual(2, self.flag_values.test_flag)
+ self.assertEqual([None, 2], self.call_args)
+
def test_default_value_not_used_success(self):
def checker(x):
self.call_args.append(x)
@@ -222,6 +257,26 @@ class SingleFlagValidatorTest(absltest.TestCase):
self.assertTrue(checker(3))
self.assertEqual([None, 2, 3], self.call_args)
+ def test_mismatching_flagvalues(self):
+
+ def checker(x):
+ self.call_args.append(x)
+ return True
+
+ flag_holder = _defines.DEFINE_integer(
+ 'test_flag',
+ None,
+ 'Usual integer flag',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'flag_values must not be customized when operating on a FlagHolder')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ _validators.register_validator(
+ flag_holder,
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
class MultiFlagsValidatorTest(absltest.TestCase):
"""Test flags multi-flag validators."""
@@ -230,9 +285,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
super(MultiFlagsValidatorTest, self).setUp()
self.flag_values = _flagvalues.FlagValues()
self.call_args = []
- _defines.DEFINE_integer(
+ self.foo_holder = _defines.DEFINE_integer(
'foo', 1, 'Usual integer flag', flag_values=self.flag_values)
- _defines.DEFINE_integer(
+ self.bar_holder = _defines.DEFINE_integer(
'bar', 2, 'Usual integer flag', flag_values=self.flag_values)
def test_success(self):
@@ -252,6 +307,55 @@ class MultiFlagsValidatorTest(absltest.TestCase):
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}],
self.call_args)
+ def test_success_holder(self):
+
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ return True
+
+ _validators.register_multi_flags_validator(
+ [self.foo_holder, self.bar_holder],
+ checker,
+ flag_values=self.flag_values)
+
+ argv = ('./program', '--bar=2')
+ self.flag_values(argv)
+ self.assertEqual(1, self.flag_values.foo)
+ self.assertEqual(2, self.flag_values.bar)
+ self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
+ self.flag_values.foo = 3
+ self.assertEqual(3, self.flag_values.foo)
+ self.assertEqual([{
+ 'foo': 1,
+ 'bar': 2
+ }, {
+ 'foo': 3,
+ 'bar': 2
+ }], self.call_args)
+
+ def test_success_holder_infer_flagvalues(self):
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ return True
+
+ _validators.register_multi_flags_validator(
+ [self.foo_holder, self.bar_holder], checker)
+
+ argv = ('./program', '--bar=2')
+ self.flag_values(argv)
+ self.assertEqual(1, self.flag_values.foo)
+ self.assertEqual(2, self.flag_values.bar)
+ self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
+ self.flag_values.foo = 3
+ self.assertEqual(3, self.flag_values.foo)
+ self.assertEqual([{
+ 'foo': 1,
+ 'bar': 2
+ }, {
+ 'foo': 3,
+ 'bar': 2
+ }], self.call_args)
+
def test_validator_not_called_when_other_flag_is_changed(self):
def checker(flags_dict):
self.call_args.append(flags_dict)
@@ -326,6 +430,30 @@ class MultiFlagsValidatorTest(absltest.TestCase):
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
+ def test_mismatching_flagvalues(self):
+
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ values = flags_dict.values()
+ # Make sure all the flags have different values.
+ return len(set(values)) == len(values)
+
+ other_holder = _defines.DEFINE_integer(
+ 'other_flag',
+ 3,
+ 'Other integer flag',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ _validators.register_multi_flags_validator(
+ [self.foo_holder, self.bar_holder, other_holder],
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
@@ -333,9 +461,9 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
super(MarkFlagsAsMutualExclusiveTest, self).setUp()
self.flag_values = _flagvalues.FlagValues()
- _defines.DEFINE_string(
+ self.flag_one_holder = _defines.DEFINE_string(
'flag_one', None, 'flag one', flag_values=self.flag_values)
- _defines.DEFINE_string(
+ self.flag_two_holder = _defines.DEFINE_string(
'flag_two', None, 'flag two', flag_values=self.flag_values)
_defines.DEFINE_string(
'flag_three', None, 'flag three', flag_values=self.flag_values)
@@ -362,6 +490,24 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
self.assertIsNone(self.flag_values.flag_one)
self.assertIsNone(self.flag_values.flag_two)
+ def test_no_flags_present_holder(self):
+ self._mark_flags_as_mutually_exclusive(
+ [self.flag_one_holder, self.flag_two_holder], False)
+ argv = ('./program',)
+
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.flag_one)
+ self.assertIsNone(self.flag_values.flag_two)
+
+ def test_no_flags_present_mixed(self):
+ self._mark_flags_as_mutually_exclusive([self.flag_one_holder, 'flag_two'],
+ False)
+ argv = ('./program',)
+
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.flag_one)
+ self.assertIsNone(self.flag_values.flag_two)
+
def test_no_flags_present_required(self):
self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
argv = ('./program',)
@@ -498,6 +644,20 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
self.assertIn('--flag_not_none has a non-None default value',
str(caught_warnings[0].message))
+ def test_multiple_flagvalues(self):
+ other_holder = _defines.DEFINE_boolean(
+ 'other_flagvalues',
+ False,
+ 'other ',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ self._mark_flags_as_mutually_exclusive(
+ [self.flag_one_holder, other_holder], False)
+
class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
@@ -505,13 +665,13 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
super(MarkBoolFlagsAsMutualExclusiveTest, self).setUp()
self.flag_values = _flagvalues.FlagValues()
- _defines.DEFINE_boolean(
+ self.false_1_holder = _defines.DEFINE_boolean(
'false_1', False, 'default false 1', flag_values=self.flag_values)
- _defines.DEFINE_boolean(
+ self.false_2_holder = _defines.DEFINE_boolean(
'false_2', False, 'default false 2', flag_values=self.flag_values)
- _defines.DEFINE_boolean(
+ self.true_1_holder = _defines.DEFINE_boolean(
'true_1', True, 'default true 1', flag_values=self.flag_values)
- _defines.DEFINE_integer(
+ self.non_bool_holder = _defines.DEFINE_integer(
'non_bool', None, 'non bool', flag_values=self.flag_values)
def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required):
@@ -524,6 +684,20 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
self.assertEqual(False, self.flag_values.false_1)
self.assertEqual(False, self.flag_values.false_2)
+ def test_no_flags_present_holder(self):
+ self._mark_bool_flags_as_mutually_exclusive(
+ [self.false_1_holder, self.false_2_holder], False)
+ self.flag_values(('./program',))
+ self.assertEqual(False, self.flag_values.false_1)
+ self.assertEqual(False, self.flag_values.false_2)
+
+ def test_no_flags_present_mixed(self):
+ self._mark_bool_flags_as_mutually_exclusive(
+ [self.false_1_holder, 'false_2'], False)
+ self.flag_values(('./program',))
+ self.assertEqual(False, self.flag_values.false_1)
+ self.assertEqual(False, self.flag_values.false_2)
+
def test_no_flags_present_required(self):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True)
argv = ('./program',)
@@ -558,6 +732,17 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'],
False)
+ def test_multiple_flagvalues(self):
+ other_bool_holder = _defines.DEFINE_boolean(
+ 'other_bool', False, 'other bool', flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ self._mark_bool_flags_as_mutually_exclusive(
+ [self.false_1_holder, other_bool_holder], False)
+
class MarkFlagAsRequiredTest(absltest.TestCase):
@@ -574,6 +759,22 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
self.flag_values(argv)
self.assertEqual('value', self.flag_values.string_flag)
+ def test_success_holder(self):
+ holder = _defines.DEFINE_string(
+ 'string_flag', None, 'string flag', flag_values=self.flag_values)
+ _validators.mark_flag_as_required(holder, flag_values=self.flag_values)
+ argv = ('./program', '--string_flag=value')
+ self.flag_values(argv)
+ self.assertEqual('value', self.flag_values.string_flag)
+
+ def test_success_holder_infer_flagvalues(self):
+ holder = _defines.DEFINE_string(
+ 'string_flag', None, 'string flag', flag_values=self.flag_values)
+ _validators.mark_flag_as_required(holder)
+ argv = ('./program', '--string_flag=value')
+ self.flag_values(argv)
+ self.assertEqual('value', self.flag_values.string_flag)
+
def test_catch_none_as_default(self):
_defines.DEFINE_string(
'string_flag', None, 'string flag', flag_values=self.flag_values)
@@ -612,6 +813,18 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
self.assertIn('--flag_not_none has a non-None default value',
str(caught_warnings[0].message))
+ def test_mismatching_flagvalues(self):
+ flag_holder = _defines.DEFINE_string(
+ 'string_flag',
+ 'value',
+ 'string flag',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'flag_values must not be customized when operating on a FlagHolder')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ _validators.mark_flag_as_required(
+ flag_holder, flag_values=self.flag_values)
+
class MarkFlagsAsRequiredTest(absltest.TestCase):
@@ -631,6 +844,18 @@ class MarkFlagsAsRequiredTest(absltest.TestCase):
self.assertEqual('value_1', self.flag_values.string_flag_1)
self.assertEqual('value_2', self.flag_values.string_flag_2)
+ def test_success_holders(self):
+ flag_1_holder = _defines.DEFINE_string(
+ 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
+ flag_2_holder = _defines.DEFINE_string(
+ 'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
+ _validators.mark_flags_as_required([flag_1_holder, flag_2_holder],
+ flag_values=self.flag_values)
+ argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
+ self.flag_values(argv)
+ self.assertEqual('value_1', self.flag_values.string_flag_1)
+ self.assertEqual('value_2', self.flag_values.string_flag_2)
+
def test_catch_none_as_default(self):
_defines.DEFINE_string(
'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)