diff options
Diffstat (limited to 'absl/flags/tests/_validators_test.py')
-rw-r--r-- | absl/flags/tests/_validators_test.py | 142 |
1 files changed, 98 insertions, 44 deletions
diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py index a5dec45..f724813 100644 --- a/absl/flags/tests/_validators_test.py +++ b/absl/flags/tests/_validators_test.py @@ -54,7 +54,7 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.test_flag) + self.assertIsNone(self.flag_values.test_flag) self.flag_values.test_flag = 2 self.assertEqual(2, self.flag_values.test_flag) self.assertEqual([None, 2], self.call_args) @@ -110,11 +110,9 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program', '--test_flag=1') self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.test_flag = 2 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=2: Errors happen', str(e)) + self.assertEqual('flag --test_flag=2: Errors happen', str(cm.exception)) self.assertEqual([1, 2], self.call_args) def test_exception_raised_if_checker_raises_exception(self): @@ -134,11 +132,9 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program', '--test_flag=1') self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.test_flag = 2 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=2: Specific message', str(e)) + self.assertEqual('flag --test_flag=2: Specific message', str(cm.exception)) self.assertEqual([1, 2], self.call_args) def test_error_message_when_checker_returns_false_on_start(self): @@ -154,11 +150,9 @@ class SingleFlagValidatorTest(absltest.TestCase): flag_values=self.flag_values) argv = ('./program', '--test_flag=1') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values(argv) - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=1: Errors happen', str(e)) + self.assertEqual('flag --test_flag=1: Errors happen', str(cm.exception)) self.assertEqual([1], self.call_args) def test_error_message_when_checker_raises_exception_on_start(self): @@ -175,11 +169,9 @@ class SingleFlagValidatorTest(absltest.TestCase): flag_values=self.flag_values) argv = ('./program', '--test_flag=1') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values(argv) - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=1: Specific message', str(e)) + self.assertEqual('flag --test_flag=1: Specific message', str(cm.exception)) self.assertEqual([1], self.call_args) def test_validators_checked_in_order(self): @@ -222,7 +214,7 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.test_flag) + self.assertIsNone(self.flag_values.test_flag) self.flag_values.test_flag = 2 self.assertEqual(2, self.flag_values.test_flag) self.assertEqual([None, 2], self.call_args) @@ -288,11 +280,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.bar = 1 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flags foo=1, bar=1: Errors happen', str(e)) + self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception)) self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) @@ -313,11 +303,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.bar = 1 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flags foo=1, bar=1: Specific message', str(e)) + self.assertEqual('flags foo=1, bar=1: Specific message', str(cm.exception)) self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) @@ -332,11 +320,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.bar = 1 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flags foo=1, bar=1: Errors happen', str(e)) + self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception)) self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) @@ -373,8 +359,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.flag_one) - self.assertEqual(None, self.flag_values.flag_two) + self.assertIsNone(self.flag_values.flag_one) + self.assertIsNone(self.flag_values.flag_two) def test_no_flags_present_required(self): self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True) @@ -452,8 +438,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): ['multi_flag_one', 'multi_flag_two'], False) argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.multi_flag_one) - self.assertEqual(None, self.flag_values.multi_flag_two) + self.assertIsNone(self.flag_values.multi_flag_one) + self.assertIsNone(self.flag_values.multi_flag_two) def test_no_multistring_flags_present_required(self): self._mark_flags_as_mutually_exclusive( @@ -558,8 +544,8 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False) argv = ('./program', '--false_1', '--false_2') expected = ( - 'flags false_1=True, false_2=True: At most one of (false_1, false_2) ' - 'must be True.') + 'flags false_1=True, false_2=True: At most one of (false_1, ' + 'false_2) must be True.') self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, expected, self.flag_values, argv) @@ -610,11 +596,9 @@ class MarkFlagAsRequiredTest(absltest.TestCase): self.assertEqual('value', self.flag_values.string_flag) expected = ('flag --string_flag=None: Flag --string_flag must have a value ' 'other than None.') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.string_flag = None - raise AssertionError('Failed to detect non-set required flag.') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual(expected, str(e)) + self.assertEqual(expected, str(cm.exception)) def test_flag_default_not_none_warning(self): _defines.DEFINE_string( @@ -680,11 +664,81 @@ class MarkFlagsAsRequiredTest(absltest.TestCase): expected = ( 'flag --string_flag_1=None: Flag --string_flag_1 must have a value ' 'other than None.') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.string_flag_1 = None - raise AssertionError('Failed to detect non-set required flag.') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual(expected, str(e)) + self.assertEqual(expected, str(cm.exception)) + + def test_catch_multiple_flags_as_none_at_program_start(self): + _defines.DEFINE_float( + 'float_flag_1', + None, + 'string flag 1', + flag_values=self.flag_values) + _defines.DEFINE_float( + 'float_flag_2', + None, + 'string flag 2', + flag_values=self.flag_values) + _validators.mark_flags_as_required( + ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values) + argv = ('./program', '') + expected = ( + 'flag --float_flag_1=None: Flag --float_flag_1 must have a value ' + 'other than None.\n' + 'flag --float_flag_2=None: Flag --float_flag_2 must have a value ' + 'other than None.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + + def test_fail_fast_single_flag_and_skip_remaining_validators(self): + def raise_unexpected_error(x): + del x + raise _exceptions.ValidationError('Should not be raised.') + _defines.DEFINE_float( + 'flag_1', None, 'flag 1', flag_values=self.flag_values) + _defines.DEFINE_float( + 'flag_2', 4.2, 'flag 2', flag_values=self.flag_values) + _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values) + _validators.register_validator( + 'flag_1', raise_unexpected_error, flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_2', 'flag_1'], + raise_unexpected_error, + flag_values=self.flag_values) + argv = ('./program', '') + expected = ( + 'flag --flag_1=None: Flag --flag_1 must have a value other than None.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + + def test_fail_fast_multi_flag_and_skip_remaining_validators(self): + def raise_expected_error(x): + del x + raise _exceptions.ValidationError('Expected error.') + def raise_unexpected_error(x): + del x + raise _exceptions.ValidationError('Got unexpected error.') + _defines.DEFINE_float( + 'flag_1', 5.1, 'flag 1', flag_values=self.flag_values) + _defines.DEFINE_float( + 'flag_2', 10.0, 'flag 2', flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_1', 'flag_2'], + raise_expected_error, + flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_2', 'flag_1'], + raise_unexpected_error, + flag_values=self.flag_values) + _validators.register_validator( + 'flag_1', raise_unexpected_error, flag_values=self.flag_values) + _validators.register_validator( + 'flag_2', raise_unexpected_error, flag_values=self.flag_values) + argv = ('./program', '') + expected = ('flags flag_1=5.1, flag_2=10.0: Expected error.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + if __name__ == '__main__': absltest.main() |