diff options
author | Yilei "Dolee" Yang <yileiyang@google.com> | 2021-06-14 11:55:03 -1000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-14 11:55:03 -1000 |
commit | 4d8ea27b446c943a10a3e9ec60c9a7a2980508ba (patch) | |
tree | ac44eb1e2afa7ec7fc47c0673e7185fb2239cf64 /absl/flags | |
parent | 9954557f9df0b346a57ff82688438c55202d2188 (diff) | |
parent | dd66afb1011d54814d7b527a28592fb09fa8566e (diff) | |
download | absl-py-4d8ea27b446c943a10a3e9ec60c9a7a2980508ba.tar.gz |
Merge pull request #168 from yilei/push_up_to_379349640upstream/pypi-v0.13.0
Push up to 379349640
Diffstat (limited to 'absl/flags')
-rw-r--r-- | absl/flags/BUILD | 13 | ||||
-rw-r--r-- | absl/flags/_argument_parser.pyi | 3 | ||||
-rw-r--r-- | absl/flags/_flag.pyi | 2 | ||||
-rw-r--r-- | absl/flags/_flagvalues.py | 17 | ||||
-rw-r--r-- | absl/flags/_validators.py | 156 | ||||
-rw-r--r-- | absl/flags/_validators_classes.py | 176 | ||||
-rw-r--r-- | absl/flags/argparse_flags.py | 1 | ||||
-rw-r--r-- | absl/flags/tests/_validators_test.py | 142 |
8 files changed, 311 insertions, 199 deletions
diff --git a/absl/flags/BUILD b/absl/flags/BUILD index 50bdb00..340ae93 100644 --- a/absl/flags/BUILD +++ b/absl/flags/BUILD @@ -83,6 +83,7 @@ py_library( ":_exceptions", ":_flag", ":_helpers", + ":_validators_classes", "@six_archive//:six", ], ) @@ -103,6 +104,18 @@ py_library( deps = [ ":_exceptions", ":_flagvalues", + ":_validators_classes", + ], +) + +py_library( + name = "_validators_classes", + srcs = [ + "_validators_classes.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":_exceptions", ], ) diff --git a/absl/flags/_argument_parser.pyi b/absl/flags/_argument_parser.pyi index db5f41b..7e78d7d 100644 --- a/absl/flags/_argument_parser.pyi +++ b/absl/flags/_argument_parser.pyi @@ -30,6 +30,9 @@ class ArgumentSerializer(Generic[_T]): # The metaclass of ArgumentParser is not reflected here, because it does not # affect the provided API. class ArgumentParser(Generic[_T]): + + syntactic_help: Text + def parse(self, argument: Text) -> Optional[_T]: ... def flag_type(self) -> Text: ... diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi index 2e7400b..f28bbf3 100644 --- a/absl/flags/_flag.pyi +++ b/absl/flags/_flag.pyi @@ -32,6 +32,8 @@ class Flag(Generic[_T]): name = ... # type: Text default = ... # type: Any + default_unparsed = ... # type: Any + default_as_str = ... # type: Optional[Text] help = ... # type: Text short_name = ... # type: Text boolean = ... # type: bool diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index 2a13927..8b9d662 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -31,6 +31,7 @@ from xml.dom import minidom from absl.flags import _exceptions from absl.flags import _flag from absl.flags import _helpers +from absl.flags import _validators_classes import six # pylint: disable=unused-import @@ -544,13 +545,27 @@ class FlagValues(object): IllegalFlagValueError: Raised if validation fails for at least one validator. """ + messages = [] + bad_flags = set() for validator in sorted( validators, key=lambda validator: validator.insertion_index): try: + if isinstance(validator, _validators_classes.SingleFlagValidator): + if validator.flag_name in bad_flags: + continue + elif isinstance(validator, _validators_classes.MultiFlagsValidator): + if bad_flags & set(validator.flag_names): + continue validator.verify(self) except _exceptions.ValidationError as e: + if isinstance(validator, _validators_classes.SingleFlagValidator): + bad_flags.add(validator.flag_name) + elif isinstance(validator, _validators_classes.MultiFlagsValidator): + bad_flags.update(set(validator.flag_names)) message = validator.print_flags_with_values(self) - raise _exceptions.IllegalFlagValueError('%s: %s' % (message, str(e))) + messages.append('%s: %s' % (message, str(e))) + if messages: + raise _exceptions.IllegalFlagValueError('\n'.join(messages)) def __delattr__(self, flag_name): """Deletes a previously-defined flag from a flag object. diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py index fdb6ea2..47cc430 100644 --- a/absl/flags/_validators.py +++ b/absl/flags/_validators.py @@ -40,157 +40,7 @@ import warnings from absl.flags import _exceptions from absl.flags import _flagvalues - - -class Validator(object): - """Base class for flags validators. - - Users should NOT overload these classes, and use flags.Register... - methods instead. - """ - - # Used to assign each validator an unique insertion_index - validators_count = 0 - - def __init__(self, checker, message): - """Constructor to create all validators. - - Args: - checker: function to verify the constraint. - Input of this method varies, see SingleFlagValidator and - multi_flags_validator for a detailed description. - message: str, error message to be shown to the user. - """ - self.checker = checker - self.message = message - Validator.validators_count += 1 - # Used to assert validators in the order they were registered. - self.insertion_index = Validator.validators_count - - def verify(self, flag_values): - """Verifies that constraint is satisfied. - - flags library calls this method to verify Validator's constraint. - - Args: - flag_values: flags.FlagValues, the FlagValues instance to get flags from. - Raises: - Error: Raised if constraint is not satisfied. - """ - param = self._get_input_to_checker_function(flag_values) - if not self.checker(param): - raise _exceptions.ValidationError(self.message) - - def get_flags_names(self): - """Returns the names of the flags checked by this validator. - - Returns: - [string], names of the flags. - """ - raise NotImplementedError('This method should be overloaded') - - def print_flags_with_values(self, flag_values): - raise NotImplementedError('This method should be overloaded') - - def _get_input_to_checker_function(self, flag_values): - """Given flag values, returns the input to be given to checker. - - Args: - flag_values: flags.FlagValues, containing all flags. - Returns: - The input to be given to checker. The return type depends on the specific - validator. - """ - raise NotImplementedError('This method should be overloaded') - - -class SingleFlagValidator(Validator): - """Validator behind register_validator() method. - - Validates that a single flag passes its checker function. The checker function - takes the flag value and returns True (if value looks fine) or, if flag value - is not valid, either returns False or raises an Exception. - """ - - def __init__(self, flag_name, checker, message): - """Constructor. - - Args: - flag_name: string, name of the flag. - checker: function to verify the validator. - input - value of the corresponding flag (string, boolean, etc). - output - bool, True if validator constraint is satisfied. - If constraint is not satisfied, it should either return False or - raise flags.ValidationError(desired_error_message). - message: str, error message to be shown to the user if validator's - condition is not satisfied. - """ - super(SingleFlagValidator, self).__init__(checker, message) - self.flag_name = flag_name - - def get_flags_names(self): - return [self.flag_name] - - def print_flags_with_values(self, flag_values): - return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value) - - def _get_input_to_checker_function(self, flag_values): - """Given flag values, returns the input to be given to checker. - - Args: - flag_values: flags.FlagValues, the FlagValues instance to get flags from. - Returns: - object, the input to be given to checker. - """ - return flag_values[self.flag_name].value - - -class MultiFlagsValidator(Validator): - """Validator behind register_multi_flags_validator method. - - Validates that flag values pass their common checker function. The checker - function takes flag values and returns True (if values look fine) or, - if values are not valid, either returns False or raises an Exception. - """ - - def __init__(self, flag_names, checker, message): - """Constructor. - - Args: - flag_names: [str], containing names of the flags used by checker. - checker: function to verify the validator. - input - dict, with keys() being flag_names, and value for each - key being the value of the corresponding flag (string, boolean, - etc). - output - bool, True if validator constraint is satisfied. - If constraint is not satisfied, it should either return False or - raise flags.ValidationError(desired_error_message). - message: str, error message to be shown to the user if validator's - condition is not satisfied - """ - super(MultiFlagsValidator, self).__init__(checker, message) - self.flag_names = flag_names - - def _get_input_to_checker_function(self, flag_values): - """Given flag values, returns the input to be given to checker. - - Args: - flag_values: flags.FlagValues, the FlagValues instance to get flags from. - Returns: - dict, with keys() being self.lag_names, and value for each key - being the value of the corresponding flag (string, boolean, etc). - """ - return dict([key, flag_values[key].value] for key in self.flag_names) - - def print_flags_with_values(self, flag_values): - prefix = 'flags ' - flags_with_values = [] - for key in self.flag_names: - flags_with_values.append('%s=%s' % (key, flag_values[key].value)) - return prefix + ', '.join(flags_with_values) - - def get_flags_names(self): - return self.flag_names +from absl.flags import _validators_classes def register_validator(flag_name, @@ -219,7 +69,7 @@ def register_validator(flag_name, AttributeError: Raised when flag_name is not registered as a valid flag name. """ - v = SingleFlagValidator(flag_name, checker, message) + v = _validators_classes.SingleFlagValidator(flag_name, checker, message) _add_validator(flag_values, v) @@ -283,7 +133,7 @@ def register_multi_flags_validator(flag_names, Raises: AttributeError: Raised when a flag is not registered as a valid flag name. """ - v = MultiFlagsValidator( + v = _validators_classes.MultiFlagsValidator( flag_names, multi_flags_checker, message) _add_validator(flag_values, v) diff --git a/absl/flags/_validators_classes.py b/absl/flags/_validators_classes.py new file mode 100644 index 0000000..d8996e0 --- /dev/null +++ b/absl/flags/_validators_classes.py @@ -0,0 +1,176 @@ +# Copyright 2021 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines *private* classes used for flag validators. + +Do NOT import this module. DO NOT use anything from this module. They are +private APIs. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.flags import _exceptions + + +class Validator(object): + """Base class for flags validators. + + Users should NOT overload these classes, and use flags.Register... + methods instead. + """ + + # Used to assign each validator an unique insertion_index + validators_count = 0 + + def __init__(self, checker, message): + """Constructor to create all validators. + + Args: + checker: function to verify the constraint. + Input of this method varies, see SingleFlagValidator and + multi_flags_validator for a detailed description. + message: str, error message to be shown to the user. + """ + self.checker = checker + self.message = message + Validator.validators_count += 1 + # Used to assert validators in the order they were registered. + self.insertion_index = Validator.validators_count + + def verify(self, flag_values): + """Verifies that constraint is satisfied. + + flags library calls this method to verify Validator's constraint. + + Args: + flag_values: flags.FlagValues, the FlagValues instance to get flags from. + Raises: + Error: Raised if constraint is not satisfied. + """ + param = self._get_input_to_checker_function(flag_values) + if not self.checker(param): + raise _exceptions.ValidationError(self.message) + + def get_flags_names(self): + """Returns the names of the flags checked by this validator. + + Returns: + [string], names of the flags. + """ + raise NotImplementedError('This method should be overloaded') + + def print_flags_with_values(self, flag_values): + raise NotImplementedError('This method should be overloaded') + + def _get_input_to_checker_function(self, flag_values): + """Given flag values, returns the input to be given to checker. + + Args: + flag_values: flags.FlagValues, containing all flags. + Returns: + The input to be given to checker. The return type depends on the specific + validator. + """ + raise NotImplementedError('This method should be overloaded') + + +class SingleFlagValidator(Validator): + """Validator behind register_validator() method. + + Validates that a single flag passes its checker function. The checker function + takes the flag value and returns True (if value looks fine) or, if flag value + is not valid, either returns False or raises an Exception. + """ + + def __init__(self, flag_name, checker, message): + """Constructor. + + Args: + flag_name: string, name of the flag. + checker: function to verify the validator. + input - value of the corresponding flag (string, boolean, etc). + output - bool, True if validator constraint is satisfied. + If constraint is not satisfied, it should either return False or + raise flags.ValidationError(desired_error_message). + message: str, error message to be shown to the user if validator's + condition is not satisfied. + """ + super(SingleFlagValidator, self).__init__(checker, message) + self.flag_name = flag_name + + def get_flags_names(self): + return [self.flag_name] + + def print_flags_with_values(self, flag_values): + return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value) + + def _get_input_to_checker_function(self, flag_values): + """Given flag values, returns the input to be given to checker. + + Args: + flag_values: flags.FlagValues, the FlagValues instance to get flags from. + Returns: + object, the input to be given to checker. + """ + return flag_values[self.flag_name].value + + +class MultiFlagsValidator(Validator): + """Validator behind register_multi_flags_validator method. + + Validates that flag values pass their common checker function. The checker + function takes flag values and returns True (if values look fine) or, + if values are not valid, either returns False or raises an Exception. + """ + + def __init__(self, flag_names, checker, message): + """Constructor. + + Args: + flag_names: [str], containing names of the flags used by checker. + checker: function to verify the validator. + input - dict, with keys() being flag_names, and value for each + key being the value of the corresponding flag (string, boolean, + etc). + output - bool, True if validator constraint is satisfied. + If constraint is not satisfied, it should either return False or + raise flags.ValidationError(desired_error_message). + message: str, error message to be shown to the user if validator's + condition is not satisfied + """ + super(MultiFlagsValidator, self).__init__(checker, message) + self.flag_names = flag_names + + def _get_input_to_checker_function(self, flag_values): + """Given flag values, returns the input to be given to checker. + + Args: + flag_values: flags.FlagValues, the FlagValues instance to get flags from. + Returns: + dict, with keys() being self.lag_names, and value for each key + being the value of the corresponding flag (string, boolean, etc). + """ + return dict([key, flag_values[key].value] for key in self.flag_names) + + def print_flags_with_values(self, flag_values): + prefix = 'flags ' + flags_with_values = [] + for key in self.flag_names: + flags_with_values.append('%s=%s' % (key, flag_values[key].value)) + return prefix + ', '.join(flags_with_values) + + def get_flags_names(self): + return self.flag_names diff --git a/absl/flags/argparse_flags.py b/absl/flags/argparse_flags.py index bcd40f5..ea4cdfa 100644 --- a/absl/flags/argparse_flags.py +++ b/absl/flags/argparse_flags.py @@ -18,7 +18,6 @@ argparse_flags.ArgumentParser is a drop-in replacement for argparse.ArgumentParser. It takes care of collecting and defining absl flags in argparse. - Here is a simple example: # Assume the following absl.flags is defined in another module: diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py index a5dec45..f724813 100644 --- a/absl/flags/tests/_validators_test.py +++ b/absl/flags/tests/_validators_test.py @@ -54,7 +54,7 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.test_flag) + self.assertIsNone(self.flag_values.test_flag) self.flag_values.test_flag = 2 self.assertEqual(2, self.flag_values.test_flag) self.assertEqual([None, 2], self.call_args) @@ -110,11 +110,9 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program', '--test_flag=1') self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.test_flag = 2 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=2: Errors happen', str(e)) + self.assertEqual('flag --test_flag=2: Errors happen', str(cm.exception)) self.assertEqual([1, 2], self.call_args) def test_exception_raised_if_checker_raises_exception(self): @@ -134,11 +132,9 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program', '--test_flag=1') self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.test_flag = 2 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=2: Specific message', str(e)) + self.assertEqual('flag --test_flag=2: Specific message', str(cm.exception)) self.assertEqual([1, 2], self.call_args) def test_error_message_when_checker_returns_false_on_start(self): @@ -154,11 +150,9 @@ class SingleFlagValidatorTest(absltest.TestCase): flag_values=self.flag_values) argv = ('./program', '--test_flag=1') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values(argv) - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=1: Errors happen', str(e)) + self.assertEqual('flag --test_flag=1: Errors happen', str(cm.exception)) self.assertEqual([1], self.call_args) def test_error_message_when_checker_raises_exception_on_start(self): @@ -175,11 +169,9 @@ class SingleFlagValidatorTest(absltest.TestCase): flag_values=self.flag_values) argv = ('./program', '--test_flag=1') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values(argv) - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=1: Specific message', str(e)) + self.assertEqual('flag --test_flag=1: Specific message', str(cm.exception)) self.assertEqual([1], self.call_args) def test_validators_checked_in_order(self): @@ -222,7 +214,7 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.test_flag) + self.assertIsNone(self.flag_values.test_flag) self.flag_values.test_flag = 2 self.assertEqual(2, self.flag_values.test_flag) self.assertEqual([None, 2], self.call_args) @@ -288,11 +280,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.bar = 1 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flags foo=1, bar=1: Errors happen', str(e)) + self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception)) self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) @@ -313,11 +303,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.bar = 1 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flags foo=1, bar=1: Specific message', str(e)) + self.assertEqual('flags foo=1, bar=1: Specific message', str(cm.exception)) self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) @@ -332,11 +320,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.bar = 1 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flags foo=1, bar=1: Errors happen', str(e)) + self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception)) self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) @@ -373,8 +359,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.flag_one) - self.assertEqual(None, self.flag_values.flag_two) + self.assertIsNone(self.flag_values.flag_one) + self.assertIsNone(self.flag_values.flag_two) def test_no_flags_present_required(self): self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True) @@ -452,8 +438,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): ['multi_flag_one', 'multi_flag_two'], False) argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.multi_flag_one) - self.assertEqual(None, self.flag_values.multi_flag_two) + self.assertIsNone(self.flag_values.multi_flag_one) + self.assertIsNone(self.flag_values.multi_flag_two) def test_no_multistring_flags_present_required(self): self._mark_flags_as_mutually_exclusive( @@ -558,8 +544,8 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False) argv = ('./program', '--false_1', '--false_2') expected = ( - 'flags false_1=True, false_2=True: At most one of (false_1, false_2) ' - 'must be True.') + 'flags false_1=True, false_2=True: At most one of (false_1, ' + 'false_2) must be True.') self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, expected, self.flag_values, argv) @@ -610,11 +596,9 @@ class MarkFlagAsRequiredTest(absltest.TestCase): self.assertEqual('value', self.flag_values.string_flag) expected = ('flag --string_flag=None: Flag --string_flag must have a value ' 'other than None.') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.string_flag = None - raise AssertionError('Failed to detect non-set required flag.') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual(expected, str(e)) + self.assertEqual(expected, str(cm.exception)) def test_flag_default_not_none_warning(self): _defines.DEFINE_string( @@ -680,11 +664,81 @@ class MarkFlagsAsRequiredTest(absltest.TestCase): expected = ( 'flag --string_flag_1=None: Flag --string_flag_1 must have a value ' 'other than None.') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.string_flag_1 = None - raise AssertionError('Failed to detect non-set required flag.') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual(expected, str(e)) + self.assertEqual(expected, str(cm.exception)) + + def test_catch_multiple_flags_as_none_at_program_start(self): + _defines.DEFINE_float( + 'float_flag_1', + None, + 'string flag 1', + flag_values=self.flag_values) + _defines.DEFINE_float( + 'float_flag_2', + None, + 'string flag 2', + flag_values=self.flag_values) + _validators.mark_flags_as_required( + ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values) + argv = ('./program', '') + expected = ( + 'flag --float_flag_1=None: Flag --float_flag_1 must have a value ' + 'other than None.\n' + 'flag --float_flag_2=None: Flag --float_flag_2 must have a value ' + 'other than None.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + + def test_fail_fast_single_flag_and_skip_remaining_validators(self): + def raise_unexpected_error(x): + del x + raise _exceptions.ValidationError('Should not be raised.') + _defines.DEFINE_float( + 'flag_1', None, 'flag 1', flag_values=self.flag_values) + _defines.DEFINE_float( + 'flag_2', 4.2, 'flag 2', flag_values=self.flag_values) + _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values) + _validators.register_validator( + 'flag_1', raise_unexpected_error, flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_2', 'flag_1'], + raise_unexpected_error, + flag_values=self.flag_values) + argv = ('./program', '') + expected = ( + 'flag --flag_1=None: Flag --flag_1 must have a value other than None.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + + def test_fail_fast_multi_flag_and_skip_remaining_validators(self): + def raise_expected_error(x): + del x + raise _exceptions.ValidationError('Expected error.') + def raise_unexpected_error(x): + del x + raise _exceptions.ValidationError('Got unexpected error.') + _defines.DEFINE_float( + 'flag_1', 5.1, 'flag 1', flag_values=self.flag_values) + _defines.DEFINE_float( + 'flag_2', 10.0, 'flag 2', flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_1', 'flag_2'], + raise_expected_error, + flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_2', 'flag_1'], + raise_unexpected_error, + flag_values=self.flag_values) + _validators.register_validator( + 'flag_1', raise_unexpected_error, flag_values=self.flag_values) + _validators.register_validator( + 'flag_2', raise_unexpected_error, flag_values=self.flag_values) + argv = ('./program', '') + expected = ('flags flag_1=5.1, flag_2=10.0: Expected error.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + if __name__ == '__main__': absltest.main() |