aboutsummaryrefslogtreecommitdiff
path: root/absl/flags/tests/_validators_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'absl/flags/tests/_validators_test.py')
-rw-r--r--absl/flags/tests/_validators_test.py142
1 files changed, 98 insertions, 44 deletions
diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py
index a5dec45..f724813 100644
--- a/absl/flags/tests/_validators_test.py
+++ b/absl/flags/tests/_validators_test.py
@@ -54,7 +54,7 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.test_flag)
+ self.assertIsNone(self.flag_values.test_flag)
self.flag_values.test_flag = 2
self.assertEqual(2, self.flag_values.test_flag)
self.assertEqual([None, 2], self.call_args)
@@ -110,11 +110,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program', '--test_flag=1')
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.test_flag = 2
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=2: Errors happen', str(e))
+ self.assertEqual('flag --test_flag=2: Errors happen', str(cm.exception))
self.assertEqual([1, 2], self.call_args)
def test_exception_raised_if_checker_raises_exception(self):
@@ -134,11 +132,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program', '--test_flag=1')
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.test_flag = 2
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=2: Specific message', str(e))
+ self.assertEqual('flag --test_flag=2: Specific message', str(cm.exception))
self.assertEqual([1, 2], self.call_args)
def test_error_message_when_checker_returns_false_on_start(self):
@@ -154,11 +150,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
flag_values=self.flag_values)
argv = ('./program', '--test_flag=1')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values(argv)
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=1: Errors happen', str(e))
+ self.assertEqual('flag --test_flag=1: Errors happen', str(cm.exception))
self.assertEqual([1], self.call_args)
def test_error_message_when_checker_raises_exception_on_start(self):
@@ -175,11 +169,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
flag_values=self.flag_values)
argv = ('./program', '--test_flag=1')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values(argv)
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=1: Specific message', str(e))
+ self.assertEqual('flag --test_flag=1: Specific message', str(cm.exception))
self.assertEqual([1], self.call_args)
def test_validators_checked_in_order(self):
@@ -222,7 +214,7 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.test_flag)
+ self.assertIsNone(self.flag_values.test_flag)
self.flag_values.test_flag = 2
self.assertEqual(2, self.flag_values.test_flag)
self.assertEqual([None, 2], self.call_args)
@@ -288,11 +280,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.bar = 1
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flags foo=1, bar=1: Errors happen', str(e))
+ self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception))
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
@@ -313,11 +303,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.bar = 1
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flags foo=1, bar=1: Specific message', str(e))
+ self.assertEqual('flags foo=1, bar=1: Specific message', str(cm.exception))
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
@@ -332,11 +320,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.bar = 1
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flags foo=1, bar=1: Errors happen', str(e))
+ self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception))
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
@@ -373,8 +359,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.flag_one)
- self.assertEqual(None, self.flag_values.flag_two)
+ self.assertIsNone(self.flag_values.flag_one)
+ self.assertIsNone(self.flag_values.flag_two)
def test_no_flags_present_required(self):
self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
@@ -452,8 +438,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
['multi_flag_one', 'multi_flag_two'], False)
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.multi_flag_one)
- self.assertEqual(None, self.flag_values.multi_flag_two)
+ self.assertIsNone(self.flag_values.multi_flag_one)
+ self.assertIsNone(self.flag_values.multi_flag_two)
def test_no_multistring_flags_present_required(self):
self._mark_flags_as_mutually_exclusive(
@@ -558,8 +544,8 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False)
argv = ('./program', '--false_1', '--false_2')
expected = (
- 'flags false_1=True, false_2=True: At most one of (false_1, false_2) '
- 'must be True.')
+ 'flags false_1=True, false_2=True: At most one of (false_1, '
+ 'false_2) must be True.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
@@ -610,11 +596,9 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
self.assertEqual('value', self.flag_values.string_flag)
expected = ('flag --string_flag=None: Flag --string_flag must have a value '
'other than None.')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.string_flag = None
- raise AssertionError('Failed to detect non-set required flag.')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual(expected, str(e))
+ self.assertEqual(expected, str(cm.exception))
def test_flag_default_not_none_warning(self):
_defines.DEFINE_string(
@@ -680,11 +664,81 @@ class MarkFlagsAsRequiredTest(absltest.TestCase):
expected = (
'flag --string_flag_1=None: Flag --string_flag_1 must have a value '
'other than None.')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.string_flag_1 = None
- raise AssertionError('Failed to detect non-set required flag.')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual(expected, str(e))
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_catch_multiple_flags_as_none_at_program_start(self):
+ _defines.DEFINE_float(
+ 'float_flag_1',
+ None,
+ 'string flag 1',
+ flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'float_flag_2',
+ None,
+ 'string flag 2',
+ flag_values=self.flag_values)
+ _validators.mark_flags_as_required(
+ ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = (
+ 'flag --float_flag_1=None: Flag --float_flag_1 must have a value '
+ 'other than None.\n'
+ 'flag --float_flag_2=None: Flag --float_flag_2 must have a value '
+ 'other than None.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_fail_fast_single_flag_and_skip_remaining_validators(self):
+ def raise_unexpected_error(x):
+ del x
+ raise _exceptions.ValidationError('Should not be raised.')
+ _defines.DEFINE_float(
+ 'flag_1', None, 'flag 1', flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'flag_2', 4.2, 'flag 2', flag_values=self.flag_values)
+ _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_1', raise_unexpected_error, flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
+ raise_unexpected_error,
+ flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = (
+ 'flag --flag_1=None: Flag --flag_1 must have a value other than None.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_fail_fast_multi_flag_and_skip_remaining_validators(self):
+ def raise_expected_error(x):
+ del x
+ raise _exceptions.ValidationError('Expected error.')
+ def raise_unexpected_error(x):
+ del x
+ raise _exceptions.ValidationError('Got unexpected error.')
+ _defines.DEFINE_float(
+ 'flag_1', 5.1, 'flag 1', flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'flag_2', 10.0, 'flag 2', flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_1', 'flag_2'],
+ raise_expected_error,
+ flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
+ raise_unexpected_error,
+ flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_1', raise_unexpected_error, flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_2', raise_unexpected_error, flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = ('flags flag_1=5.1, flag_2=10.0: Expected error.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
if __name__ == '__main__':
absltest.main()