aboutsummaryrefslogtreecommitdiff
path: root/absl/flags
diff options
context:
space:
mode:
authorYilei "Dolee" Yang <yileiyang@google.com>2021-06-14 11:55:03 -1000
committerGitHub <noreply@github.com>2021-06-14 11:55:03 -1000
commit4d8ea27b446c943a10a3e9ec60c9a7a2980508ba (patch)
treeac44eb1e2afa7ec7fc47c0673e7185fb2239cf64 /absl/flags
parent9954557f9df0b346a57ff82688438c55202d2188 (diff)
parentdd66afb1011d54814d7b527a28592fb09fa8566e (diff)
downloadabsl-py-4d8ea27b446c943a10a3e9ec60c9a7a2980508ba.tar.gz
Merge pull request #168 from yilei/push_up_to_379349640upstream/pypi-v0.13.0
Push up to 379349640
Diffstat (limited to 'absl/flags')
-rw-r--r--absl/flags/BUILD13
-rw-r--r--absl/flags/_argument_parser.pyi3
-rw-r--r--absl/flags/_flag.pyi2
-rw-r--r--absl/flags/_flagvalues.py17
-rw-r--r--absl/flags/_validators.py156
-rw-r--r--absl/flags/_validators_classes.py176
-rw-r--r--absl/flags/argparse_flags.py1
-rw-r--r--absl/flags/tests/_validators_test.py142
8 files changed, 311 insertions, 199 deletions
diff --git a/absl/flags/BUILD b/absl/flags/BUILD
index 50bdb00..340ae93 100644
--- a/absl/flags/BUILD
+++ b/absl/flags/BUILD
@@ -83,6 +83,7 @@ py_library(
":_exceptions",
":_flag",
":_helpers",
+ ":_validators_classes",
"@six_archive//:six",
],
)
@@ -103,6 +104,18 @@ py_library(
deps = [
":_exceptions",
":_flagvalues",
+ ":_validators_classes",
+ ],
+)
+
+py_library(
+ name = "_validators_classes",
+ srcs = [
+ "_validators_classes.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_exceptions",
],
)
diff --git a/absl/flags/_argument_parser.pyi b/absl/flags/_argument_parser.pyi
index db5f41b..7e78d7d 100644
--- a/absl/flags/_argument_parser.pyi
+++ b/absl/flags/_argument_parser.pyi
@@ -30,6 +30,9 @@ class ArgumentSerializer(Generic[_T]):
# The metaclass of ArgumentParser is not reflected here, because it does not
# affect the provided API.
class ArgumentParser(Generic[_T]):
+
+ syntactic_help: Text
+
def parse(self, argument: Text) -> Optional[_T]: ...
def flag_type(self) -> Text: ...
diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi
index 2e7400b..f28bbf3 100644
--- a/absl/flags/_flag.pyi
+++ b/absl/flags/_flag.pyi
@@ -32,6 +32,8 @@ class Flag(Generic[_T]):
name = ... # type: Text
default = ... # type: Any
+ default_unparsed = ... # type: Any
+ default_as_str = ... # type: Optional[Text]
help = ... # type: Text
short_name = ... # type: Text
boolean = ... # type: bool
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index 2a13927..8b9d662 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -31,6 +31,7 @@ from xml.dom import minidom
from absl.flags import _exceptions
from absl.flags import _flag
from absl.flags import _helpers
+from absl.flags import _validators_classes
import six
# pylint: disable=unused-import
@@ -544,13 +545,27 @@ class FlagValues(object):
IllegalFlagValueError: Raised if validation fails for at least one
validator.
"""
+ messages = []
+ bad_flags = set()
for validator in sorted(
validators, key=lambda validator: validator.insertion_index):
try:
+ if isinstance(validator, _validators_classes.SingleFlagValidator):
+ if validator.flag_name in bad_flags:
+ continue
+ elif isinstance(validator, _validators_classes.MultiFlagsValidator):
+ if bad_flags & set(validator.flag_names):
+ continue
validator.verify(self)
except _exceptions.ValidationError as e:
+ if isinstance(validator, _validators_classes.SingleFlagValidator):
+ bad_flags.add(validator.flag_name)
+ elif isinstance(validator, _validators_classes.MultiFlagsValidator):
+ bad_flags.update(set(validator.flag_names))
message = validator.print_flags_with_values(self)
- raise _exceptions.IllegalFlagValueError('%s: %s' % (message, str(e)))
+ messages.append('%s: %s' % (message, str(e)))
+ if messages:
+ raise _exceptions.IllegalFlagValueError('\n'.join(messages))
def __delattr__(self, flag_name):
"""Deletes a previously-defined flag from a flag object.
diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py
index fdb6ea2..47cc430 100644
--- a/absl/flags/_validators.py
+++ b/absl/flags/_validators.py
@@ -40,157 +40,7 @@ import warnings
from absl.flags import _exceptions
from absl.flags import _flagvalues
-
-
-class Validator(object):
- """Base class for flags validators.
-
- Users should NOT overload these classes, and use flags.Register...
- methods instead.
- """
-
- # Used to assign each validator an unique insertion_index
- validators_count = 0
-
- def __init__(self, checker, message):
- """Constructor to create all validators.
-
- Args:
- checker: function to verify the constraint.
- Input of this method varies, see SingleFlagValidator and
- multi_flags_validator for a detailed description.
- message: str, error message to be shown to the user.
- """
- self.checker = checker
- self.message = message
- Validator.validators_count += 1
- # Used to assert validators in the order they were registered.
- self.insertion_index = Validator.validators_count
-
- def verify(self, flag_values):
- """Verifies that constraint is satisfied.
-
- flags library calls this method to verify Validator's constraint.
-
- Args:
- flag_values: flags.FlagValues, the FlagValues instance to get flags from.
- Raises:
- Error: Raised if constraint is not satisfied.
- """
- param = self._get_input_to_checker_function(flag_values)
- if not self.checker(param):
- raise _exceptions.ValidationError(self.message)
-
- def get_flags_names(self):
- """Returns the names of the flags checked by this validator.
-
- Returns:
- [string], names of the flags.
- """
- raise NotImplementedError('This method should be overloaded')
-
- def print_flags_with_values(self, flag_values):
- raise NotImplementedError('This method should be overloaded')
-
- def _get_input_to_checker_function(self, flag_values):
- """Given flag values, returns the input to be given to checker.
-
- Args:
- flag_values: flags.FlagValues, containing all flags.
- Returns:
- The input to be given to checker. The return type depends on the specific
- validator.
- """
- raise NotImplementedError('This method should be overloaded')
-
-
-class SingleFlagValidator(Validator):
- """Validator behind register_validator() method.
-
- Validates that a single flag passes its checker function. The checker function
- takes the flag value and returns True (if value looks fine) or, if flag value
- is not valid, either returns False or raises an Exception.
- """
-
- def __init__(self, flag_name, checker, message):
- """Constructor.
-
- Args:
- flag_name: string, name of the flag.
- checker: function to verify the validator.
- input - value of the corresponding flag (string, boolean, etc).
- output - bool, True if validator constraint is satisfied.
- If constraint is not satisfied, it should either return False or
- raise flags.ValidationError(desired_error_message).
- message: str, error message to be shown to the user if validator's
- condition is not satisfied.
- """
- super(SingleFlagValidator, self).__init__(checker, message)
- self.flag_name = flag_name
-
- def get_flags_names(self):
- return [self.flag_name]
-
- def print_flags_with_values(self, flag_values):
- return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value)
-
- def _get_input_to_checker_function(self, flag_values):
- """Given flag values, returns the input to be given to checker.
-
- Args:
- flag_values: flags.FlagValues, the FlagValues instance to get flags from.
- Returns:
- object, the input to be given to checker.
- """
- return flag_values[self.flag_name].value
-
-
-class MultiFlagsValidator(Validator):
- """Validator behind register_multi_flags_validator method.
-
- Validates that flag values pass their common checker function. The checker
- function takes flag values and returns True (if values look fine) or,
- if values are not valid, either returns False or raises an Exception.
- """
-
- def __init__(self, flag_names, checker, message):
- """Constructor.
-
- Args:
- flag_names: [str], containing names of the flags used by checker.
- checker: function to verify the validator.
- input - dict, with keys() being flag_names, and value for each
- key being the value of the corresponding flag (string, boolean,
- etc).
- output - bool, True if validator constraint is satisfied.
- If constraint is not satisfied, it should either return False or
- raise flags.ValidationError(desired_error_message).
- message: str, error message to be shown to the user if validator's
- condition is not satisfied
- """
- super(MultiFlagsValidator, self).__init__(checker, message)
- self.flag_names = flag_names
-
- def _get_input_to_checker_function(self, flag_values):
- """Given flag values, returns the input to be given to checker.
-
- Args:
- flag_values: flags.FlagValues, the FlagValues instance to get flags from.
- Returns:
- dict, with keys() being self.lag_names, and value for each key
- being the value of the corresponding flag (string, boolean, etc).
- """
- return dict([key, flag_values[key].value] for key in self.flag_names)
-
- def print_flags_with_values(self, flag_values):
- prefix = 'flags '
- flags_with_values = []
- for key in self.flag_names:
- flags_with_values.append('%s=%s' % (key, flag_values[key].value))
- return prefix + ', '.join(flags_with_values)
-
- def get_flags_names(self):
- return self.flag_names
+from absl.flags import _validators_classes
def register_validator(flag_name,
@@ -219,7 +69,7 @@ def register_validator(flag_name,
AttributeError: Raised when flag_name is not registered as a valid flag
name.
"""
- v = SingleFlagValidator(flag_name, checker, message)
+ v = _validators_classes.SingleFlagValidator(flag_name, checker, message)
_add_validator(flag_values, v)
@@ -283,7 +133,7 @@ def register_multi_flags_validator(flag_names,
Raises:
AttributeError: Raised when a flag is not registered as a valid flag name.
"""
- v = MultiFlagsValidator(
+ v = _validators_classes.MultiFlagsValidator(
flag_names, multi_flags_checker, message)
_add_validator(flag_values, v)
diff --git a/absl/flags/_validators_classes.py b/absl/flags/_validators_classes.py
new file mode 100644
index 0000000..d8996e0
--- /dev/null
+++ b/absl/flags/_validators_classes.py
@@ -0,0 +1,176 @@
+# Copyright 2021 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines *private* classes used for flag validators.
+
+Do NOT import this module. DO NOT use anything from this module. They are
+private APIs.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.flags import _exceptions
+
+
+class Validator(object):
+ """Base class for flags validators.
+
+ Users should NOT overload these classes, and use flags.Register...
+ methods instead.
+ """
+
+ # Used to assign each validator an unique insertion_index
+ validators_count = 0
+
+ def __init__(self, checker, message):
+ """Constructor to create all validators.
+
+ Args:
+ checker: function to verify the constraint.
+ Input of this method varies, see SingleFlagValidator and
+ multi_flags_validator for a detailed description.
+ message: str, error message to be shown to the user.
+ """
+ self.checker = checker
+ self.message = message
+ Validator.validators_count += 1
+ # Used to assert validators in the order they were registered.
+ self.insertion_index = Validator.validators_count
+
+ def verify(self, flag_values):
+ """Verifies that constraint is satisfied.
+
+ flags library calls this method to verify Validator's constraint.
+
+ Args:
+ flag_values: flags.FlagValues, the FlagValues instance to get flags from.
+ Raises:
+ Error: Raised if constraint is not satisfied.
+ """
+ param = self._get_input_to_checker_function(flag_values)
+ if not self.checker(param):
+ raise _exceptions.ValidationError(self.message)
+
+ def get_flags_names(self):
+ """Returns the names of the flags checked by this validator.
+
+ Returns:
+ [string], names of the flags.
+ """
+ raise NotImplementedError('This method should be overloaded')
+
+ def print_flags_with_values(self, flag_values):
+ raise NotImplementedError('This method should be overloaded')
+
+ def _get_input_to_checker_function(self, flag_values):
+ """Given flag values, returns the input to be given to checker.
+
+ Args:
+ flag_values: flags.FlagValues, containing all flags.
+ Returns:
+ The input to be given to checker. The return type depends on the specific
+ validator.
+ """
+ raise NotImplementedError('This method should be overloaded')
+
+
+class SingleFlagValidator(Validator):
+ """Validator behind register_validator() method.
+
+ Validates that a single flag passes its checker function. The checker function
+ takes the flag value and returns True (if value looks fine) or, if flag value
+ is not valid, either returns False or raises an Exception.
+ """
+
+ def __init__(self, flag_name, checker, message):
+ """Constructor.
+
+ Args:
+ flag_name: string, name of the flag.
+ checker: function to verify the validator.
+ input - value of the corresponding flag (string, boolean, etc).
+ output - bool, True if validator constraint is satisfied.
+ If constraint is not satisfied, it should either return False or
+ raise flags.ValidationError(desired_error_message).
+ message: str, error message to be shown to the user if validator's
+ condition is not satisfied.
+ """
+ super(SingleFlagValidator, self).__init__(checker, message)
+ self.flag_name = flag_name
+
+ def get_flags_names(self):
+ return [self.flag_name]
+
+ def print_flags_with_values(self, flag_values):
+ return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value)
+
+ def _get_input_to_checker_function(self, flag_values):
+ """Given flag values, returns the input to be given to checker.
+
+ Args:
+ flag_values: flags.FlagValues, the FlagValues instance to get flags from.
+ Returns:
+ object, the input to be given to checker.
+ """
+ return flag_values[self.flag_name].value
+
+
+class MultiFlagsValidator(Validator):
+ """Validator behind register_multi_flags_validator method.
+
+ Validates that flag values pass their common checker function. The checker
+ function takes flag values and returns True (if values look fine) or,
+ if values are not valid, either returns False or raises an Exception.
+ """
+
+ def __init__(self, flag_names, checker, message):
+ """Constructor.
+
+ Args:
+ flag_names: [str], containing names of the flags used by checker.
+ checker: function to verify the validator.
+ input - dict, with keys() being flag_names, and value for each
+ key being the value of the corresponding flag (string, boolean,
+ etc).
+ output - bool, True if validator constraint is satisfied.
+ If constraint is not satisfied, it should either return False or
+ raise flags.ValidationError(desired_error_message).
+ message: str, error message to be shown to the user if validator's
+ condition is not satisfied
+ """
+ super(MultiFlagsValidator, self).__init__(checker, message)
+ self.flag_names = flag_names
+
+ def _get_input_to_checker_function(self, flag_values):
+ """Given flag values, returns the input to be given to checker.
+
+ Args:
+ flag_values: flags.FlagValues, the FlagValues instance to get flags from.
+ Returns:
+ dict, with keys() being self.lag_names, and value for each key
+ being the value of the corresponding flag (string, boolean, etc).
+ """
+ return dict([key, flag_values[key].value] for key in self.flag_names)
+
+ def print_flags_with_values(self, flag_values):
+ prefix = 'flags '
+ flags_with_values = []
+ for key in self.flag_names:
+ flags_with_values.append('%s=%s' % (key, flag_values[key].value))
+ return prefix + ', '.join(flags_with_values)
+
+ def get_flags_names(self):
+ return self.flag_names
diff --git a/absl/flags/argparse_flags.py b/absl/flags/argparse_flags.py
index bcd40f5..ea4cdfa 100644
--- a/absl/flags/argparse_flags.py
+++ b/absl/flags/argparse_flags.py
@@ -18,7 +18,6 @@ argparse_flags.ArgumentParser is a drop-in replacement for
argparse.ArgumentParser. It takes care of collecting and defining absl flags
in argparse.
-
Here is a simple example:
# Assume the following absl.flags is defined in another module:
diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py
index a5dec45..f724813 100644
--- a/absl/flags/tests/_validators_test.py
+++ b/absl/flags/tests/_validators_test.py
@@ -54,7 +54,7 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.test_flag)
+ self.assertIsNone(self.flag_values.test_flag)
self.flag_values.test_flag = 2
self.assertEqual(2, self.flag_values.test_flag)
self.assertEqual([None, 2], self.call_args)
@@ -110,11 +110,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program', '--test_flag=1')
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.test_flag = 2
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=2: Errors happen', str(e))
+ self.assertEqual('flag --test_flag=2: Errors happen', str(cm.exception))
self.assertEqual([1, 2], self.call_args)
def test_exception_raised_if_checker_raises_exception(self):
@@ -134,11 +132,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program', '--test_flag=1')
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.test_flag = 2
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=2: Specific message', str(e))
+ self.assertEqual('flag --test_flag=2: Specific message', str(cm.exception))
self.assertEqual([1, 2], self.call_args)
def test_error_message_when_checker_returns_false_on_start(self):
@@ -154,11 +150,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
flag_values=self.flag_values)
argv = ('./program', '--test_flag=1')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values(argv)
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=1: Errors happen', str(e))
+ self.assertEqual('flag --test_flag=1: Errors happen', str(cm.exception))
self.assertEqual([1], self.call_args)
def test_error_message_when_checker_raises_exception_on_start(self):
@@ -175,11 +169,9 @@ class SingleFlagValidatorTest(absltest.TestCase):
flag_values=self.flag_values)
argv = ('./program', '--test_flag=1')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values(argv)
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flag --test_flag=1: Specific message', str(e))
+ self.assertEqual('flag --test_flag=1: Specific message', str(cm.exception))
self.assertEqual([1], self.call_args)
def test_validators_checked_in_order(self):
@@ -222,7 +214,7 @@ class SingleFlagValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.test_flag)
+ self.assertIsNone(self.flag_values.test_flag)
self.flag_values.test_flag = 2
self.assertEqual(2, self.flag_values.test_flag)
self.assertEqual([None, 2], self.call_args)
@@ -288,11 +280,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.bar = 1
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flags foo=1, bar=1: Errors happen', str(e))
+ self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception))
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
@@ -313,11 +303,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.bar = 1
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flags foo=1, bar=1: Specific message', str(e))
+ self.assertEqual('flags foo=1, bar=1: Specific message', str(cm.exception))
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
@@ -332,11 +320,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.bar = 1
- raise AssertionError('IllegalFlagValueError expected')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual('flags foo=1, bar=1: Errors happen', str(e))
+ self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception))
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
@@ -373,8 +359,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.flag_one)
- self.assertEqual(None, self.flag_values.flag_two)
+ self.assertIsNone(self.flag_values.flag_one)
+ self.assertIsNone(self.flag_values.flag_two)
def test_no_flags_present_required(self):
self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
@@ -452,8 +438,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
['multi_flag_one', 'multi_flag_two'], False)
argv = ('./program',)
self.flag_values(argv)
- self.assertEqual(None, self.flag_values.multi_flag_one)
- self.assertEqual(None, self.flag_values.multi_flag_two)
+ self.assertIsNone(self.flag_values.multi_flag_one)
+ self.assertIsNone(self.flag_values.multi_flag_two)
def test_no_multistring_flags_present_required(self):
self._mark_flags_as_mutually_exclusive(
@@ -558,8 +544,8 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False)
argv = ('./program', '--false_1', '--false_2')
expected = (
- 'flags false_1=True, false_2=True: At most one of (false_1, false_2) '
- 'must be True.')
+ 'flags false_1=True, false_2=True: At most one of (false_1, '
+ 'false_2) must be True.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
@@ -610,11 +596,9 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
self.assertEqual('value', self.flag_values.string_flag)
expected = ('flag --string_flag=None: Flag --string_flag must have a value '
'other than None.')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.string_flag = None
- raise AssertionError('Failed to detect non-set required flag.')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual(expected, str(e))
+ self.assertEqual(expected, str(cm.exception))
def test_flag_default_not_none_warning(self):
_defines.DEFINE_string(
@@ -680,11 +664,81 @@ class MarkFlagsAsRequiredTest(absltest.TestCase):
expected = (
'flag --string_flag_1=None: Flag --string_flag_1 must have a value '
'other than None.')
- try:
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
self.flag_values.string_flag_1 = None
- raise AssertionError('Failed to detect non-set required flag.')
- except _exceptions.IllegalFlagValueError as e:
- self.assertEqual(expected, str(e))
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_catch_multiple_flags_as_none_at_program_start(self):
+ _defines.DEFINE_float(
+ 'float_flag_1',
+ None,
+ 'string flag 1',
+ flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'float_flag_2',
+ None,
+ 'string flag 2',
+ flag_values=self.flag_values)
+ _validators.mark_flags_as_required(
+ ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = (
+ 'flag --float_flag_1=None: Flag --float_flag_1 must have a value '
+ 'other than None.\n'
+ 'flag --float_flag_2=None: Flag --float_flag_2 must have a value '
+ 'other than None.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_fail_fast_single_flag_and_skip_remaining_validators(self):
+ def raise_unexpected_error(x):
+ del x
+ raise _exceptions.ValidationError('Should not be raised.')
+ _defines.DEFINE_float(
+ 'flag_1', None, 'flag 1', flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'flag_2', 4.2, 'flag 2', flag_values=self.flag_values)
+ _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_1', raise_unexpected_error, flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
+ raise_unexpected_error,
+ flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = (
+ 'flag --flag_1=None: Flag --flag_1 must have a value other than None.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_fail_fast_multi_flag_and_skip_remaining_validators(self):
+ def raise_expected_error(x):
+ del x
+ raise _exceptions.ValidationError('Expected error.')
+ def raise_unexpected_error(x):
+ del x
+ raise _exceptions.ValidationError('Got unexpected error.')
+ _defines.DEFINE_float(
+ 'flag_1', 5.1, 'flag 1', flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'flag_2', 10.0, 'flag 2', flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_1', 'flag_2'],
+ raise_expected_error,
+ flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
+ raise_unexpected_error,
+ flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_1', raise_unexpected_error, flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_2', raise_unexpected_error, flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = ('flags flag_1=5.1, flag_2=10.0: Expected error.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
if __name__ == '__main__':
absltest.main()