aboutsummaryrefslogtreecommitdiff
path: root/absl/flags
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2019-01-09 06:58:34 -0800
committerCopybara-Service <copybara-piper@google.com>2019-01-09 06:58:48 -0800
commit252d140a8d76b4d79548e3f6d617f33444c86935 (patch)
tree063d44062acd87253c27812a94d71c72dc47c899 /absl/flags
parente7fbca22a32ce1da8df6a774505c6358099a2edb (diff)
downloadabsl-py-252d140a8d76b4d79548e3f6d617f33444c86935.tar.gz
Change "should be specified" errors to "value other than None"
Previously, the error message for mark_flags_as_mutual_exclusive said "specified", but "specified" sounds like "included in the command-line invocation. That's not necessarily equivalent, in particular because flags can have default values other than None. mark_flags_required used a similar wording in the error message, but generated a warning when used with a flag with a default value other than None. Instead, make both generate a warning and clarify the error message for both to "should have a value other than None", since neither validator checks that a flag value was specified in the command-line invocation. The error message should still be clear in cases where the warning is suppressed or overlooked. Also adds a test for the existing warning behavior of mark_flag_as_required. PiperOrigin-RevId: 228510310
Diffstat (limited to 'absl/flags')
-rw-r--r--absl/flags/_validators.py21
-rw-r--r--absl/flags/tests/_validators_test.py59
2 files changed, 61 insertions, 19 deletions
diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py
index 8ec4ef6..f52e74b 100644
--- a/absl/flags/_validators.py
+++ b/absl/flags/_validators.py
@@ -356,10 +356,11 @@ def mark_flag_as_required(flag_name, flag_values=_flagvalues.FLAGS):
'Flag --%s has a non-None default value; therefore, '
'mark_flag_as_required will pass even if flag is not specified in the '
'command line!' % flag_name)
- register_validator(flag_name,
- lambda value: value is not None,
- message='Flag --%s must be specified.' % flag_name,
- flag_values=flag_values)
+ register_validator(
+ flag_name,
+ lambda value: value is not None,
+ message='Flag --{} must have a value other than None.'.format(flag_name),
+ flag_values=flag_values)
def mark_flags_as_required(flag_names, flag_values=_flagvalues.FLAGS):
@@ -400,14 +401,20 @@ def mark_flags_as_mutual_exclusive(flag_names, required=False,
flag_values: flags.FlagValues, optional FlagValues instance where the flags
are defined.
"""
+ for flag_name in flag_names:
+ if flag_values[flag_name].default is not None:
+ warnings.warn(
+ 'Flag --{} has a non-None default value. That does not make sense '
+ 'with mark_flags_as_mutual_exclusive, which checks whether the '
+ 'listed flags have a value other than None.'.format(flag_name))
def validate_mutual_exclusion(flags_dict):
flag_count = sum(1 for val in flags_dict.values() if val is not None)
if flag_count == 1 or (not required and flag_count == 0):
return True
- message = ('%s one of (%s) must be specified.' %
- ('Exactly' if required else 'At most', ', '.join(flag_names)))
- raise _exceptions.ValidationError(message)
+ raise _exceptions.ValidationError(
+ '{} one of ({}) must have a value other than None.'.format(
+ 'Exactly' if required else 'At most', ', '.join(flag_names)))
register_multi_flags_validator(
flag_names, validate_mutual_exclusion, flag_values=flag_values)
diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py
index f01d257..e2fb9fa 100644
--- a/absl/flags/tests/_validators_test.py
+++ b/absl/flags/tests/_validators_test.py
@@ -22,6 +22,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import warnings
+
+
from absl.flags import _defines
from absl.flags import _exceptions
from absl.flags import _flagvalues
@@ -358,6 +361,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
'multi_flag_one', None, 'multi flag one', flag_values=self.flag_values)
_defines.DEFINE_multi_string(
'multi_flag_two', None, 'multi flag two', flag_values=self.flag_values)
+ _defines.DEFINE_boolean(
+ 'flag_not_none', False, 'false default', flag_values=self.flag_values)
def _mark_flags_as_mutually_exclusive(self, flag_names, required):
_validators.mark_flags_as_mutual_exclusive(
@@ -376,7 +381,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
argv = ('./program',)
expected = (
'flags flag_one=None, flag_two=None: '
- 'Exactly one of (flag_one, flag_two) must be specified.')
+ 'Exactly one of (flag_one, flag_two) must have a value other than '
+ 'None.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
@@ -411,7 +417,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
argv = ('./program', '--int_flag_one=0', '--int_flag_two=0')
expected = (
'flags int_flag_one=0, int_flag_two=0: '
- 'At most one of (int_flag_one, int_flag_two) must be specified.')
+ 'At most one of (int_flag_one, int_flag_two) must have a value other '
+ 'than None.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
@@ -422,7 +429,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
argv = ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3')
expected = (
'flags flag_one=1, flag_two=2, flag_three=3: '
- 'At most one of (flag_one, flag_two, flag_three) must be specified.')
+ 'At most one of (flag_one, flag_two, flag_three) must have a value '
+ 'other than None.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
@@ -433,7 +441,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
argv = ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3')
expected = (
'flags flag_one=1, flag_two=2, flag_three=3: '
- 'Exactly one of (flag_one, flag_two, flag_three) must be specified.')
+ 'Exactly one of (flag_one, flag_two, flag_three) must have a value '
+ 'other than None.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
@@ -452,7 +461,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
argv = ('./program',)
expected = (
'flags multi_flag_one=None, multi_flag_two=None: '
- 'Exactly one of (multi_flag_one, multi_flag_two) must be specified.')
+ 'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
+ 'other than None.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
@@ -475,7 +485,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
argv = ('./program', '--multi_flag_one=1', '--multi_flag_two=2')
expected = (
"flags multi_flag_one=['1'], multi_flag_two=['2']: "
- "At most one of (multi_flag_one, multi_flag_two) must be specified.")
+ 'At most one of (multi_flag_one, multi_flag_two) must have a value '
+ 'other than None.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
@@ -486,11 +497,21 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
argv = ('./program', '--multi_flag_one=1', '--multi_flag_two=2')
expected = (
"flags multi_flag_one=['1'], multi_flag_two=['2']: "
- "Exactly one of (multi_flag_one, multi_flag_two) must be specified.")
+ 'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
+ 'other than None.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
+ def test_flag_default_not_none_warning(self):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_not_none'],
+ False)
+ self.assertLen(w, 1)
+ self.assertIn('--flag_not_none has a non-None default value',
+ str(w[0].message))
+
class MarkFlagAsRequiredTest(absltest.TestCase):
@@ -514,7 +535,8 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
'string_flag', flag_values=self.flag_values)
argv = ('./program',)
expected = (
- r'flag --string_flag=None: Flag --string_flag must be specified\.')
+ r'flag --string_flag=None: Flag --string_flag must have a value other '
+ r'than None\.')
with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, expected):
self.flag_values(argv)
@@ -526,13 +548,25 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
argv = ('./program',)
self.flag_values(argv)
self.assertEqual('value', self.flag_values.string_flag)
- expected = 'flag --string_flag=None: Flag --string_flag must be specified.'
+ expected = ('flag --string_flag=None: Flag --string_flag must have a value '
+ 'other than None.')
try:
self.flag_values.string_flag = None
raise AssertionError('Failed to detect non-set required flag.')
except _exceptions.IllegalFlagValueError as e:
self.assertEqual(expected, str(e))
+ def test_flag_default_not_none_warning(self):
+ _defines.DEFINE_string(
+ 'flag_not_none', '', 'empty default', flag_values=self.flag_values)
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ _validators.mark_flag_as_required(
+ 'flag_not_none', flag_values=self.flag_values)
+ self.assertLen(w, 1)
+ self.assertIn('--flag_not_none has a non-None default value',
+ str(w[0].message))
+
class MarkFlagsAsRequiredTest(absltest.TestCase):
@@ -561,7 +595,8 @@ class MarkFlagsAsRequiredTest(absltest.TestCase):
['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
argv = ('./program', '--string_flag_1=value_1')
expected = (
- r'flag --string_flag_2=None: Flag --string_flag_2 must be specified\.')
+ r'flag --string_flag_2=None: Flag --string_flag_2 must have a value '
+ r'other than None\.')
with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, expected):
self.flag_values(argv)
@@ -582,13 +617,13 @@ class MarkFlagsAsRequiredTest(absltest.TestCase):
self.flag_values(argv)
self.assertEqual('value_1', self.flag_values.string_flag_1)
expected = (
- 'flag --string_flag_1=None: Flag --string_flag_1 must be specified.')
+ 'flag --string_flag_1=None: Flag --string_flag_1 must have a value '
+ 'other than None.')
try:
self.flag_values.string_flag_1 = None
raise AssertionError('Failed to detect non-set required flag.')
except _exceptions.IllegalFlagValueError as e:
self.assertEqual(expected, str(e))
-
if __name__ == '__main__':
absltest.main()