aboutsummaryrefslogtreecommitdiff
path: root/absl/flags
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2021-04-30 15:12:40 -0700
committerCopybara-Service <copybara-worker@google.com>2021-04-30 15:13:07 -0700
commit1b71112b8aa319e2a09dc90d51eefe26e921d0ce (patch)
treef64f29c63ff9d6a9a604b170dc0edf9702856548 /absl/flags
parent5206702470497dd371ae75427ab5e8091652ecbb (diff)
downloadabsl-py-1b71112b8aa319e2a09dc90d51eefe26e921d0ce.tar.gz
Shows all missing required flags in error message, instead of just the first missing required flag.
Given these required flags of float values with default value of None: flags.mark_flags_as_required(["x", "y", "z"]) If none of the required flags are passed as arguments, the resulting error will only show the first missing required flag as an issue: FATAL Flags parsing error: flag --x=None: Flag --x must have a value other than None. The user will then add this required flag "x" and try again, but then get the error for "y" as a missing requirement. This loops until the user finally passes all the required flags as arguments. Since we already know which flags are required, this changes the error message shows all missing required flags at once, so the user can make the necessary changes in a single pass with the error message showing this: FATAL Flags parsing error: flag --x=None: Flag --x must have a value other than None. flag --y=None: Flag --y must have a value other than None. flag --z=None: Flag --z must have a value other than None. To achieve the formatting change, `app.parse_flags_with_usage` now puts the error message on a new line and adds indentation if the message is multi-line string. PiperOrigin-RevId: 371414012 Change-Id: Id7456e9c293fb95d7b4551fd28441610af9b3030
Diffstat (limited to 'absl/flags')
-rw-r--r--absl/flags/BUILD13
-rw-r--r--absl/flags/_flagvalues.py17
-rw-r--r--absl/flags/_validators.py156
-rw-r--r--absl/flags/_validators_classes.py176
-rw-r--r--absl/flags/tests/_validators_test.py142
5 files changed, 306 insertions, 198 deletions
diff --git a/absl/flags/BUILD b/absl/flags/BUILD
index 50bdb00..340ae93 100644
--- a/absl/flags/BUILD
+++ b/absl/flags/BUILD
@@ -83,6 +83,7 @@ py_library(
":_exceptions",
":_flag",
":_helpers",
+ ":_validators_classes",
"@six_archive//:six",
],
)
@@ -103,6 +104,18 @@ py_library(
deps = [
":_exceptions",
":_flagvalues",
+ ":_validators_classes",
+ ],
+)
+
+py_library(
+ name = "_validators_classes",
+ srcs = [
+ "_validators_classes.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_exceptions",
],
)
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index 2a13927..8b9d662 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -31,6 +31,7 @@ from xml.dom import minidom
from absl.flags import _exceptions
from absl.flags import _flag
from absl.flags import _helpers
+from absl.flags import _validators_classes
import six
# pylint: disable=unused-import
@@ -544,13 +545,27 @@ class FlagValues(object):
IllegalFlagValueError: Raised if validation fails for at least one
validator.
"""
+ messages = []
+ bad_flags = set()
for validator in sorted(
validators, key=lambda validator: validator.insertion_index):
try:
+ if isinstance(validator, _validators_classes.SingleFlagValidator):
+ if validator.flag_name in bad_flags:
+ continue
+ elif isinstance(validator, _validators_classes.MultiFlagsValidator):
+ if bad_flags & set(validator.flag_names):
+ continue
validator.verify(self)
except _exceptions.ValidationError as e:
+ if isinstance(validator, _validators_classes.SingleFlagValidator):
+ bad_flags.add(validator.flag_name)
+ elif isinstance(validator, _validators_classes.MultiFlagsValidator):
+ bad_flags.update(set(validator.flag_names))
message = validator.print_flags_with_values(self)
- raise _exceptions.IllegalFlagValueError('%s: %s' % (message, str(e)))
+ messages.append('%s: %s' % (message, str(e)))
+ if messages:
+ raise _exceptions.IllegalFlagValueError('\n'.join(messages))
def __delattr__(self, flag_name):
"""Deletes a previously-defined flag from a flag object.
diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py
index fdb6ea2..47cc430 100644
--- a/absl/flags/_validators.py
+++ b/absl/flags/_validators.py
@@ -40,157 +40,7 @@ import warnings
from absl.flags import _exceptions
from absl.flags import _flagvalues
-
-
-class Validator(object):
- """Base class for flags validators.
-
- Users should NOT overload these classes, and use flags.Register...
- methods instead.
- """
-
- # Used to assign each validator an unique insertion_index
- validators_count = 0
-
- def __init__(self, checker, message):
- """Constructor to create all validators.
-
- Args:
- checker: function to verify the constraint.
- Input of this method varies, see SingleFlagValidator and
- multi_flags_validator for a detailed description.
- message: str, error message to be shown to the user.
- """
- self.checker = checker
- self.message = message
- Validator.validators_count += 1
- # Used to assert validators in the order they were registered.
- self.insertion_index = Validator.validators_count
-
- def verify(self, flag_values):
- """Verifies that constraint is satisfied.
-
- flags library calls this method to verify Validator's constraint.
-
- Args:
- flag_values: flags.FlagValues, the FlagValues instance to get flags from.
- Raises:
- Error: Raised if constraint is not satisfied.
- """
- param = self._get_input_to_checker_function(flag_values)
- if not self.checker(param):
- raise _exceptions.ValidationError(self.message)
-
- def get_flags_names(self):
- """Returns the names of the flags checked by this validator.
-
- Returns:
- [string], names of the flags.
- """
- raise NotImplementedError('This method should be overloaded')
-
- def print_flags_with_values(self, flag_values):
- raise NotImplementedError('This method should be overloaded')
-
- def _get_input_to_checker_function(self, flag_values):
- """Given flag values, returns the input to be given to checker.
-
- Args:
- flag_values: flags.FlagValues, containing all flags.
- Returns:
- The input to be given to checker. The return type depends on the specific
- validator.
- """
- raise NotImplementedError('This method should be overloaded')
-
-
-class SingleFlagValidator(Validator):
- """Validator behind register_validator() method.
-
- Validates that a single flag passes its checker function. The checker function
- takes the flag value and returns True (if value looks fine) or, if flag value
- is not valid, either returns False or raises an Exception.
- """
-
- def __init__(self, flag_name, checker, message):
- """Constructor.
-
- Args:
- flag_name: string, name of the flag.
- checker: function to verify the validator.
- input - value of the corresponding flag (string, boolean, etc).
- output - bool, True if validator constraint is satisfied.
- If constraint is not satisfied, it should either return False or
- raise flags.ValidationError(desired_error_message).
- message: str, error message to be shown to the user if validator's
- condition is not satisfied.
- """
- super(SingleFlagValidator, self).__init__(checker, message)
- self.flag_name = flag_name
-
- def get_flags_names(self):
- return [self.flag_name]
-
- def print_flags_with_values(self, flag_values):
- return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value)
-
- def _get_input_to_checker_function(self, flag_values):
- """Given flag values, returns the input to be given to checker.
-
- Args:
- flag_values: flags.FlagValues, the FlagValues instance to get flags from.
- Returns:
- object, the input to be given to checker.
- """
- return flag_values[self.flag_name].value
-
-
-class MultiFlagsValidator(Validator):
- """Validator behind register_multi_flags_validator method.
-
- Validates that flag values pass their common checker function. The checker
- function takes flag values and returns True (if values look fine) or,
- if values are not valid, either returns False or raises an Exception.
- """
-
- def __init__(self, flag_names, checker, message):
- """Constructor.
-
- Args:
- flag_names: [str], containing names of the flags used by checker.
- checker: function to verify the validator.
- input - dict, with keys() being flag_names, and value for each
- key being the value of the corresponding flag (string, boolean,
- etc).
- output - bool, True if validator constraint is satisfied.
- If constraint is not satisfied, it should either return False or
- raise flags.ValidationError(desired_error_message).
- message: str, error message to be shown to the user if validator's
- condition is not satisfied
- """
- super(MultiFlagsValidator, self).__init__(checker, message)
- self.flag_names = flag_names
-
- def _get_input_to_checker_function(self, flag_values):
- """Given flag values, returns the input to be given to checker.
-
- Args:
- flag_values: flags.FlagValues, the FlagValues instance to get flags from.
- Returns:
- dict, with keys() being self.lag_names, and value for each key
- being the value of the corresponding flag (string, boolean, etc).
- """
- return dict([key, flag_values[key].value] for key in self.flag_names)
-
- def print_flags_with_values(self, flag_values):
- prefix = 'flags '
- flags_with_values = []
- for key in self.flag_names:
- flags_with_values.append('%s=%s' % (key, flag_values[key].value))
- return prefix + ', '.join(flags_with_values)
-
- def get_flags_names(self):
- return self.flag_names
+from absl.flags import _validators_classes
def register_validator(flag_name,
@@ -219,7 +69,7 @@ def register_validator(flag_name,
AttributeError: Raised when flag_name is not registered as a valid flag
name.
"""
- v = SingleFlagValidator(flag_name, checker, message)
+ v = _validators_classes.SingleFlagValidator(flag_name, checker, message)
_add_validator(flag_values, v)
@@ -283,7 +133,7 @@ def register_multi_flags_validator(flag_names,
Raises:
AttributeError: Raised when a flag is not registered as a valid flag name.
"""
- v = MultiFlagsValidator(
+ v = _validators_classes.MultiFlagsValidator(
flag_names, multi_flags_checker, message)
_add_validator(flag_values, v)
diff --git a/absl/flags/_validators_classes.py b/absl/flags/_validators_classes.py
new file mode 100644
index 0000000..d8996e0
--- /dev/null
+++ b/absl/flags/_validators_classes.py
@@ -0,0 +1,176 @@
+# Copyright 2021 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines *private* classes used for flag validators.
+
+Do NOT import this module. DO NOT use anything from this module. They are
+private APIs.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.flags import _exceptions
+
+
+class Validator(object):
+ """Base class for flags validators.
+
+ Users should NOT overload these classes, and use flags.Register...
+ methods instead.
+ """
+
+ # Used to assign each validator an unique insertion_index
+ validators_count = 0
+
+ def __init__(self, checker, message):
+ """Constructor to create all validators.
+
+ Args:
+ checker: function to verify the constraint.
+ Input of this method varies, see SingleFlagValidator and
+ multi_flags_validator for a detailed description.
+ message: str, error message to be shown to the user.
+ """
+ self.checker = checker
+ self.message = message
+ Validator.validators_count += 1
+ # Used to assert validators in the order they were registered.
+ self.insertion_index = Validator.validators_count
+
+ def verify(self, flag_values):
+ """Verifies that constraint is satisfied.
+
+ flags library calls this method to verify Validator's constraint.
+
+ Args:
+ flag_values: flags.FlagValues, the FlagValues instance to get flags from.
+ Raises:
+ Error: Raised if constraint is not satisfied.
+ """
+ param = self._get_input_to_checker_function(flag_values)
+ if not self.checker(param):
+ raise _exceptions.ValidationError(self.message)
+
+ def get_flags_names(self):
+ """Returns the names of the flags checked by this validator.
+
+ Returns:
+ [string], names of the flags.
+ """
+ raise NotImplementedError('This method should be overloaded')
+
+ def print_flags_with_values(self, flag_values):
+ raise NotImplementedError('This method should be overloaded')
+
+ def _get_input_to_checker_function(self, flag_values):
+ """Given flag values, returns the input to be given to checker.
+
+ Args:
+ flag_values: flags.FlagValues, containing all flags.
+ Returns:
+ The input to be given to checker. The return type depends on the specific
+ validator.
+ """
+ raise NotImplementedError('This method should be overloaded')
+
+
+class SingleFlagValidator(Validator):
+ """Validator behind register_validator() method.
+
+ Validates that a single flag passes its checker function. The checker function
+ takes the flag value and returns True (if value looks fine) or, if flag value
+ is not valid, either returns False or raises an Exception.
+ """
+
+ def __init__(self, flag_name, checker, message):
+ """Constructor.
+
+ Args:
+ flag_name: string, name of the flag.
+ checker: function to verify the validator.
+ input - value of the corresponding flag (string, boolean, etc).
+ output - bool, True if validator constraint is satisfied.
+ If constraint is not satisfied, it should either return False or
+ raise flags.ValidationError(desired_error_message).
+ message: str, error message to be shown to the user if validator's
+ condition is not satisfied.
+ """
+ super(SingleFlagValidator, self).__init__(checker, message)
+ self.flag_name = flag_name
+
+ def get_flags_names(self):
+ return [self.flag_name]
+
+ def print_flags_with_values(self, flag_values):
+ return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value)
+
+ def _get_input_to_checker_function(self, flag_values):
+ """Given flag values, returns the input to be given to checker.
+
+ Args:
+ flag_values: flags.FlagValues, the FlagValues instance to get flags from.
+ Returns:
+ object, the input to be given to checker.
+ """
+ return flag_values[self.flag_name].value
+
+
+class MultiFlagsValidator(Validator):
+ """Validator behind register_multi_flags_validator method.
+
+ Validates that flag values pass their common checker function. The checker
+ function takes flag values and returns True (if values look fine) or,
+ if values are not valid, either returns False or raises an Exception.
+ """
+
+ def __init__(self, flag_names, checker, message):
+ """Constructor.
+
+ Args:
+ flag_names: [str], containing names of the flags used by checker.
+ checker: function to verify the validator.
+ input - dict, with keys() being flag_names, and value for each
+ key being the value of the corresponding flag (string, boolean,
+ etc).
+ output - bool, True if validator constraint is satisfied.
+ If constraint is not satisfied, it should either return False or
+ raise flags.ValidationError(desired_error_message).
+ message: str, error message to be shown to the user if validator's
+ condition is not satisfied
+ """
+ super(MultiFlagsValidator, self).__init__(checker, message)
+ self.flag_names = flag_names
+
+ def _get_input_to_checker_function(self, flag_values):
+ """Given flag values, returns the input to be given to checker.
+
+ Args:
+ flag_values: flags.FlagValues, the FlagValues instance to get flags from.
+ Returns:
+ dict, with keys() being self.lag_names, and value for each key
+ being the value of the corresponding flag (string, boolean, etc).
+ """
+ return dict([key, flag_values[key].value] for key in self.flag_names)
+
+ def print_flags_with_values(self, flag_values):
+ prefix = 'flags '
+ flags_with_values = []
+ for key in self.flag_names:
+ flags_with_values.append('%s=%s' % (key, flag_values[key].value))
+ return prefix + ', '.join(flags_with_values)
+
+ def get_flags_names(self):
+ return self.flag_names
diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py
index a5dec45..f724813 100644
--- a/absl/flags/tests/_validators_test.py
+++ b/absl/flags/tests/_validators_test.py
@@ -54,7 +54,7 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.test_flag)
+ self.assertIsNone(self.flag_values.test_flag)
self.flag_values.test_flag = 2
self.assertEqual(2, self.flag_values.test_flag)
self.assertEqual([None, 2], self.call_args)
@@ -110,11 +110,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program', '--test_flag=1')
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.test_flag = 2
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=2: Errors happen', str(e))
+ self.assertEqual('flag --test_flag=2: Errors happen', str(cm.exception))
self.assertEqual([1, 2], self.call_args)
def test_exception_raised_if_checker_raises_exception(self):
@@ -134,11 +132,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program', '--test_flag=1')
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.test_flag = 2
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=2: Specific message', str(e))
+ self.assertEqual('flag --test_flag=2: Specific message', str(cm.exception))
self.assertEqual([1, 2], self.call_args)
def test_error_message_when_checker_returns_false_on_start(self):
@@ -154,11 +150,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
flag_values=self.flag_values)
argv = ('./program', '--test_flag=1')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values(argv)
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=1: Errors happen', str(e))
+ self.assertEqual('flag --test_flag=1: Errors happen', str(cm.exception))
self.assertEqual([1], self.call_args)
def test_error_message_when_checker_raises_exception_on_start(self):
@@ -175,11 +169,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
flag_values=self.flag_values)
argv = ('./program', '--test_flag=1')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values(argv)
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=1: Specific message', str(e))
+ self.assertEqual('flag --test_flag=1: Specific message', str(cm.exception))
self.assertEqual([1], self.call_args)
def test_validators_checked_in_order(self):
@@ -222,7 +214,7 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.test_flag)
+ self.assertIsNone(self.flag_values.test_flag)
self.flag_values.test_flag = 2
self.assertEqual(2, self.flag_values.test_flag)
self.assertEqual([None, 2], self.call_args)
@@ -288,11 +280,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.bar = 1
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flags foo=1, bar=1: Errors happen', str(e))
+ self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception))
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
@@ -313,11 +303,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.bar = 1
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flags foo=1, bar=1: Specific message', str(e))
+ self.assertEqual('flags foo=1, bar=1: Specific message', str(cm.exception))
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
@@ -332,11 +320,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.bar = 1
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flags foo=1, bar=1: Errors happen', str(e))
+ self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception))
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
@@ -373,8 +359,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.flag_one)
- self.assertEqual(None, self.flag_values.flag_two)
+ self.assertIsNone(self.flag_values.flag_one)
+ self.assertIsNone(self.flag_values.flag_two)
def test_no_flags_present_required(self):
self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
@@ -452,8 +438,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
['multi_flag_one', 'multi_flag_two'], False)
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.multi_flag_one)
- self.assertEqual(None, self.flag_values.multi_flag_two)
+ self.assertIsNone(self.flag_values.multi_flag_one)
+ self.assertIsNone(self.flag_values.multi_flag_two)
def test_no_multistring_flags_present_required(self):
self._mark_flags_as_mutually_exclusive(
@@ -558,8 +544,8 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False)
argv = ('./program', '--false_1', '--false_2')
expected = (
- 'flags false_1=True, false_2=True: At most one of (false_1, false_2) '
- 'must be True.')
+ 'flags false_1=True, false_2=True: At most one of (false_1, '
+ 'false_2) must be True.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
@@ -610,11 +596,9 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
self.assertEqual('value', self.flag_values.string_flag)
expected = ('flag --string_flag=None: Flag --string_flag must have a value '
'other than None.')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.string_flag = None
- raise AssertionError('Failed to detect non-set required flag.')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual(expected, str(e))
+ self.assertEqual(expected, str(cm.exception))
def test_flag_default_not_none_warning(self):
_defines.DEFINE_string(
@@ -680,11 +664,81 @@ class MarkFlagsAsRequiredTest(absltest.TestCase):
expected = (
'flag --string_flag_1=None: Flag --string_flag_1 must have a value '
'other than None.')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.string_flag_1 = None
- raise AssertionError('Failed to detect non-set required flag.')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual(expected, str(e))
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_catch_multiple_flags_as_none_at_program_start(self):
+ _defines.DEFINE_float(
+ 'float_flag_1',
+ None,
+ 'string flag 1',
+ flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'float_flag_2',
+ None,
+ 'string flag 2',
+ flag_values=self.flag_values)
+ _validators.mark_flags_as_required(
+ ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = (
+ 'flag --float_flag_1=None: Flag --float_flag_1 must have a value '
+ 'other than None.\n'
+ 'flag --float_flag_2=None: Flag --float_flag_2 must have a value '
+ 'other than None.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_fail_fast_single_flag_and_skip_remaining_validators(self):
+ def raise_unexpected_error(x):
+ del x
+ raise _exceptions.ValidationError('Should not be raised.')
+ _defines.DEFINE_float(
+ 'flag_1', None, 'flag 1', flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'flag_2', 4.2, 'flag 2', flag_values=self.flag_values)
+ _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_1', raise_unexpected_error, flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
+ raise_unexpected_error,
+ flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = (
+ 'flag --flag_1=None: Flag --flag_1 must have a value other than None.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_fail_fast_multi_flag_and_skip_remaining_validators(self):
+ def raise_expected_error(x):
+ del x
+ raise _exceptions.ValidationError('Expected error.')
+ def raise_unexpected_error(x):
+ del x
+ raise _exceptions.ValidationError('Got unexpected error.')
+ _defines.DEFINE_float(
+ 'flag_1', 5.1, 'flag 1', flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'flag_2', 10.0, 'flag 2', flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_1', 'flag_2'],
+ raise_expected_error,
+ flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
+ raise_unexpected_error,
+ flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_1', raise_unexpected_error, flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_2', raise_unexpected_error, flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = ('flags flag_1=5.1, flag_2=10.0: Expected error.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
if __name__ == '__main__':
absltest.main()