aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2020-08-03 06:46:32 -0700
committerCopybara-Service <copybara-worker@google.com>2020-08-03 06:46:51 -0700
commiteb94d9587c6f2eade9617237fb6bba1364226a3b (patch)
tree9f3d6257449e6f72c51cdd8c8607376723db7ada
parent0a1d25251489871c782d92208ca0c224fb31cddb (diff)
downloadabsl-py-eb94d9587c6f2eade9617237fb6bba1364226a3b.tar.gz
Make `DEFINE_enum_class` values case-insensitive by default.
This is to bridge the gap between lower case being idiomatic for command line values, and upper case for values defined in an `Enum`. It also makes it more consistent with `DEFINE_enum` usage, which usually uses lower case values. Case sensitivity can be restored by passing `case_sensitive=True`. When `case_sensitive=False` (the default), the associated flag serializers will lowercase enum member names. This seemed reasonable since help text that SCREAMS at you is inconsistent with the fact that member names will be provided most often in lowercase format. Internally, EnumParser is re-used for simplicity. PiperOrigin-RevId: 324594299 Change-Id: I2e15448ad00a095212756c5277b08219a9e84d55
-rw-r--r--absl/CHANGELOG.md6
-rw-r--r--absl/flags/BUILD1
-rw-r--r--absl/flags/_argument_parser.py74
-rw-r--r--absl/flags/_argument_parser.pyi9
-rw-r--r--absl/flags/_defines.py20
-rw-r--r--absl/flags/_defines.pyi2
-rw-r--r--absl/flags/_flag.py35
-rw-r--r--absl/flags/tests/_argument_parser_test.py41
-rw-r--r--absl/flags/tests/_flag_test.py50
-rw-r--r--absl/flags/tests/flags_helpxml_test.py35
-rw-r--r--absl/flags/tests/flags_test.py35
11 files changed, 235 insertions, 73 deletions
diff --git a/absl/CHANGELOG.md b/absl/CHANGELOG.md
index 46ed882..c708e6f 100644
--- a/absl/CHANGELOG.md
+++ b/absl/CHANGELOG.md
@@ -28,6 +28,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
`abc.Mapping` is treated as if only its keys are passed as a list of test
cases. If you were relying on the old inconsistent behavior, explicitly
convert the `abc.Mapping` to a `list`.
+* (flags) `DEFINE_enum_class` and `DEFINE_mutlti_enum_class` accept a
+ `case_sensitive` argument. When `False` (the default), strings are mapped to
+ enum member names without case sensitivity, and member names are serialized
+ in lowercase form. Flag definitions for enums whose members include
+ duplicates when case is ignored must now explicitly pass
+ `case_sensitive=True`.
### Fixed
diff --git a/absl/flags/BUILD b/absl/flags/BUILD
index 0cd93c4..50bdb00 100644
--- a/absl/flags/BUILD
+++ b/absl/flags/BUILD
@@ -114,6 +114,7 @@ py2and3_test(
":_argument_parser",
"//absl:_enum_module",
"//absl/testing:absltest",
+ "//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/absl/flags/_argument_parser.py b/absl/flags/_argument_parser.py
index a1f0daf..a706191 100644
--- a/absl/flags/_argument_parser.py
+++ b/absl/flags/_argument_parser.py
@@ -22,6 +22,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import csv
import io
import string
@@ -372,11 +373,13 @@ class EnumParser(ArgumentParser):
class EnumClassParser(ArgumentParser):
"""Parser of an Enum class member."""
- def __init__(self, enum_class):
+ def __init__(self, enum_class, case_sensitive=True):
"""Initializes EnumParser.
Args:
enum_class: class, the Enum class with all possible flag values.
+ case_sensitive: bool, whether or not the enum is to be case-sensitive. If
+ False, all member names must be unique when case is ignored.
Raises:
TypeError: When enum_class is not a subclass of Enum.
@@ -391,9 +394,30 @@ class EnumClassParser(ArgumentParser):
if not enum_class.__members__:
raise ValueError('enum_class cannot be empty, but "{}" is empty.'
.format(enum_class))
+ if not case_sensitive:
+ members = collections.Counter(
+ name.lower() for name in enum_class.__members__)
+ duplicate_keys = {
+ member for member, count in members.items() if count > 1
+ }
+ if duplicate_keys:
+ raise ValueError(
+ 'Duplicate enum values for {} using case_sensitive=False'.format(
+ duplicate_keys))
super(EnumClassParser, self).__init__()
self.enum_class = enum_class
+ self._case_sensitive = case_sensitive
+ if case_sensitive:
+ self._member_names = tuple(enum_class.__members__)
+ else:
+ self._member_names = tuple(
+ name.lower() for name in enum_class.__members__)
+
+ @property
+ def member_names(self):
+ """The accepted enum names, in lowercase if not case sensitive."""
+ return self._member_names
def parse(self, argument):
"""Determines validity of argument and returns the correct element of enum.
@@ -409,11 +433,19 @@ class EnumClassParser(ArgumentParser):
"""
if isinstance(argument, self.enum_class):
return argument
- if argument not in self.enum_class.__members__:
- raise ValueError('value should be one of <%s>' %
- '|'.join(self.enum_class.__members__.keys()))
+ elif not isinstance(argument, six.string_types):
+ raise ValueError(
+ '{} is not an enum member or a name of a member in {}'.format(
+ argument, self.enum_class))
+ key = EnumParser(
+ self._member_names, case_sensitive=self._case_sensitive).parse(argument)
+ if self._case_sensitive:
+ return self.enum_class[key]
else:
- return self.enum_class[argument]
+ # If EnumParser.parse() return a value, we're guaranteed to find it
+ # as a member of the class
+ return next(value for name, value in self.enum_class.__members__.items()
+ if name.lower() == key.lower())
def flag_type(self):
"""See base class."""
@@ -431,13 +463,30 @@ class ListSerializer(ArgumentSerializer):
class EnumClassListSerializer(ListSerializer):
+ """A serializer for MultiEnumClass flags.
+
+ This serializer simply joins the output of `EnumClassSerializer` using a
+ provided seperator.
+ """
+
+ def __init__(self, list_sep, **kwargs):
+ """Initializes EnumClassListSerializer.
+
+ Args:
+ list_sep: String to be used as a separator when serializing
+ **kwargs: Keyword arguments to the `EnumClassSerializer` used to serialize
+ individual values.
+ """
+ super(EnumClassListSerializer, self).__init__(list_sep)
+ self._element_serializer = EnumClassSerializer(**kwargs)
def serialize(self, value):
"""See base class."""
if isinstance(value, list):
- return self.list_sep.join(_helpers.str_or_unicode(x.name) for x in value)
+ return self.list_sep.join(
+ self._element_serializer.serialize(x) for x in value)
else:
- return _helpers.str_or_unicode(value.name)
+ return self._element_serializer.serialize(value)
class CsvListSerializer(ArgumentSerializer):
@@ -466,9 +515,18 @@ class CsvListSerializer(ArgumentSerializer):
class EnumClassSerializer(ArgumentSerializer):
"""Class for generating string representations of an enum class flag value."""
+ def __init__(self, lowercase):
+ """Initializes EnumClassSerializer.
+
+ Args:
+ lowercase: If True, enum member names are lowercased during serialization.
+ """
+ self._lowercase = lowercase
+
def serialize(self, value):
"""Returns a serialized string of the Enum class value."""
- return _helpers.str_or_unicode(value.name)
+ as_string = _helpers.str_or_unicode(value.name)
+ return as_string.lower() if self._lowercase else as_string
class BaseListParser(ArgumentParser):
diff --git a/absl/flags/_argument_parser.pyi b/absl/flags/_argument_parser.pyi
index e237002..62f6738 100644
--- a/absl/flags/_argument_parser.pyi
+++ b/absl/flags/_argument_parser.pyi
@@ -67,16 +67,18 @@ class BooleanParser(ArgumentParser[bool]):
class EnumParser(ArgumentParser[Text]):
- def __init__(self, enum_values: Sequence[Text], case_sensitive: bool=True) -> None:
+ def __init__(self, enum_values: Sequence[Text], case_sensitive: bool=...) -> None:
...
class EnumClassParser(ArgumentParser[_ET]):
- def __init__(self, enum_class: Type[_ET]) -> None:
+ def __init__(self, enum_class: Type[_ET], case_sensitive: bool=...) -> None:
...
+ @property
+ def member_names(self) -> Sequence[Text]: ...
class BaseListParser(ArgumentParser[List[Text]]):
@@ -107,7 +109,8 @@ class ListSerializer(ArgumentSerializer[List[Text]]):
class EnumClassListSerializer(ArgumentSerializer[List[Text]]):
- ...
+ def __init__(self, list_sep: Text, **kwargs: Any) -> None:
+ ...
class CsvListSerializer(ArgumentSerializer[List[Any]]):
diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py
index bb0c170..a8d5470 100644
--- a/absl/flags/_defines.py
+++ b/absl/flags/_defines.py
@@ -410,6 +410,7 @@ def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin
help,
flag_values=_flagvalues.FLAGS,
module_name=None,
+ case_sensitive=False,
**args):
"""Registers a flag whose value can be the name of enum members.
@@ -422,14 +423,21 @@ def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin
registered. This should almost never need to be overridden.
module_name: str, the name of the Python module declaring this flag. If not
provided, it will be computed using the stack trace of this call.
+ case_sensitive: bool, whether to map strings to members of the enum_class
+ without considering case.
**args: dict, the extra keyword args that are passed to Flag __init__.
Returns:
a handle to defined flag.
"""
return DEFINE_flag(
- _flag.EnumClassFlag(name, default, help, enum_class, **args), flag_values,
- module_name)
+ _flag.EnumClassFlag(
+ name,
+ default,
+ help,
+ enum_class,
+ case_sensitive=case_sensitive,
+ **args), flag_values, module_name)
def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin
@@ -682,6 +690,7 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
help,
flag_values=_flagvalues.FLAGS,
module_name=None,
+ case_sensitive=False,
**args):
"""Registers a flag whose value can be a list of enum members.
@@ -701,6 +710,8 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
registered. This should almost never need to be overridden.
module_name: A string, the name of the Python module declaring this flag. If
not provided, it will be computed using the stack trace of this call.
+ case_sensitive: bool, whether to map strings to members of the enum_class
+ without considering case.
**args: Dictionary with extra keyword args that are passed to the Flag
__init__.
@@ -708,8 +719,9 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
a handle to defined flag.
"""
return DEFINE_flag(
- _flag.MultiEnumClassFlag(name, default, help, enum_class), flag_values,
- module_name, **args)
+ _flag.MultiEnumClassFlag(
+ name, default, help, enum_class, case_sensitive=case_sensitive),
+ flag_values, module_name, **args)
def DEFINE_alias( # pylint: disable=invalid-name
diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi
index 8607835..5949f0a 100644
--- a/absl/flags/_defines.pyi
+++ b/absl/flags/_defines.pyi
@@ -160,6 +160,7 @@ def DEFINE_enum_class(
help: Optional[Text],
flag_values: _flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
+ case_sensitive: bool = ...,
**args: Any) -> _flagvalues.FlagHolder[Optional[_ET]]:
...
@@ -171,6 +172,7 @@ def DEFINE_enum_class(
help: Optional[Text],
flag_values: _flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
+ case_sensitive: bool = ...,
**args: Any) -> _flagvalues.FlagHolder[_ET]:
...
diff --git a/absl/flags/_flag.py b/absl/flags/_flag.py
index 74b5d65..a893b22 100644
--- a/absl/flags/_flag.py
+++ b/absl/flags/_flag.py
@@ -345,13 +345,21 @@ class EnumFlag(Flag):
class EnumClassFlag(Flag):
"""Basic enum flag; its value is an enum class's member."""
- def __init__(self, name, default, help, enum_class, # pylint: disable=redefined-builtin
- short_name=None, **args):
- p = _argument_parser.EnumClassParser(enum_class)
- g = _argument_parser.EnumClassSerializer()
+ def __init__(
+ self,
+ name,
+ default,
+ help, # pylint: disable=redefined-builtin
+ enum_class,
+ short_name=None,
+ case_sensitive=False,
+ **args):
+ p = _argument_parser.EnumClassParser(
+ enum_class, case_sensitive=case_sensitive)
+ g = _argument_parser.EnumClassSerializer(lowercase=not case_sensitive)
super(EnumClassFlag, self).__init__(
p, g, name, default, help, short_name, **args)
- self.help = '<%s>: %s' % ('|'.join(enum_class.__members__), self.help)
+ self.help = '<%s>: %s' % ('|'.join(p.member_names), self.help)
def _extra_xml_dom_elements(self, doc):
elements = []
@@ -445,15 +453,22 @@ class MultiEnumClassFlag(MultiFlag):
type.
"""
- def __init__(self, name, default, help_string, enum_class, **args):
- p = _argument_parser.EnumClassParser(enum_class)
- g = _argument_parser.EnumClassListSerializer(list_sep=',')
+ def __init__(self,
+ name,
+ default,
+ help_string,
+ enum_class,
+ case_sensitive=False,
+ **args):
+ p = _argument_parser.EnumClassParser(
+ enum_class, case_sensitive=case_sensitive)
+ g = _argument_parser.EnumClassListSerializer(
+ list_sep=',', lowercase=not case_sensitive)
super(MultiEnumClassFlag, self).__init__(
p, g, name, default, help_string, **args)
self.help = (
'<%s>: %s;\n repeat this option to specify a list of values' %
- ('|'.join(enum_class.__members__),
- help_string or '(no help available)'))
+ ('|'.join(p.member_names), help_string or '(no help available)'))
def _extra_xml_dom_elements(self, doc):
elements = []
diff --git a/absl/flags/tests/_argument_parser_test.py b/absl/flags/tests/_argument_parser_test.py
index 7131a04..e62c903 100644
--- a/absl/flags/tests/_argument_parser_test.py
+++ b/absl/flags/tests/_argument_parser_test.py
@@ -23,6 +23,7 @@ from __future__ import print_function
from absl._enum_module import enum
from absl.flags import _argument_parser
from absl.testing import absltest
+from absl.testing import parameterized
import six
@@ -134,7 +135,13 @@ class EmptyEnum(enum.Enum):
pass
-class EnumClassParserTest(absltest.TestCase):
+class MixedCaseEnum(enum.Enum):
+ APPLE = 1
+ BANANA = 2
+ apple = 3
+
+
+class EnumClassParserTest(parameterized.TestCase):
def test_requires_enum(self):
with self.assertRaises(TypeError):
@@ -144,10 +151,25 @@ class EnumClassParserTest(absltest.TestCase):
with self.assertRaises(ValueError):
_argument_parser.EnumClassParser(EmptyEnum)
+ def test_case_sensitive_rejects_duplicates(self):
+ unused_normal_parser = _argument_parser.EnumClassParser(MixedCaseEnum)
+ with self.assertRaisesRegex(ValueError, 'Duplicate.+apple'):
+ _argument_parser.EnumClassParser(MixedCaseEnum, case_sensitive=False)
+
def test_parse_string(self):
parser = _argument_parser.EnumClassParser(Fruit)
self.assertEqual(Fruit.APPLE, parser.parse('APPLE'))
+ def test_parse_string_case_sensitive(self):
+ parser = _argument_parser.EnumClassParser(Fruit)
+ with self.assertRaises(ValueError):
+ parser.parse('apple')
+
+ @parameterized.parameters('APPLE', 'apple', 'Apple')
+ def test_parse_string_case_insensitive(self, value):
+ parser = _argument_parser.EnumClassParser(Fruit, case_sensitive=False)
+ self.assertIs(Fruit.APPLE, parser.parse(value))
+
def test_parse_literal(self):
parser = _argument_parser.EnumClassParser(Fruit)
self.assertEqual(Fruit.APPLE, parser.parse(Fruit.APPLE))
@@ -157,14 +179,15 @@ class EnumClassParserTest(absltest.TestCase):
with self.assertRaises(ValueError):
parser.parse('ORANGE')
- def test_serialize_parse(self):
- serializer = _argument_parser.EnumClassSerializer()
- val1 = Fruit.BANANA
- parser = _argument_parser.EnumClassParser(Fruit)
- serialized = serializer.serialize(val1)
- self.assertEqual(serialized, 'BANANA')
- val2 = parser.parse(serialized)
- self.assertEqual(val1, val2)
+ @parameterized.parameters((Fruit.BANANA, False, 'BANANA'),
+ (Fruit.BANANA, True, 'banana'))
+ def test_serialize_parse(self, value, lowercase, expected):
+ serializer = _argument_parser.EnumClassSerializer(lowercase=lowercase)
+ parser = _argument_parser.EnumClassParser(
+ Fruit, case_sensitive=not lowercase)
+ serialized = serializer.serialize(value)
+ self.assertEqual(serialized, expected)
+ self.assertEqual(value, parser.parse(expected))
class HelperFunctionsTest(absltest.TestCase):
diff --git a/absl/flags/tests/_flag_test.py b/absl/flags/tests/_flag_test.py
index 084b37b..fb08b51 100644
--- a/absl/flags/tests/_flag_test.py
+++ b/absl/flags/tests/_flag_test.py
@@ -124,10 +124,18 @@ class EmptyEnum(enum.Enum):
class EnumClassFlagTest(parameterized.TestCase):
@parameterized.parameters(
+ ('', '<apple|orange>: (no help available)'),
+ ('Type of fruit.', '<apple|orange>: Type of fruit.'))
+ def test_help_text_case_insensitive(self, helptext_input, helptext_output):
+ f = _flag.EnumClassFlag('fruit', None, helptext_input, Fruit)
+ self.assertEqual(helptext_output, f.help)
+
+ @parameterized.parameters(
('', '<APPLE|ORANGE>: (no help available)'),
('Type of fruit.', '<APPLE|ORANGE>: Type of fruit.'))
- def test_help_text(self, helptext_input, helptext_output):
- f = _flag.EnumClassFlag('fruit', None, helptext_input, Fruit)
+ def test_help_text_case_sensitive(self, helptext_input, helptext_output):
+ f = _flag.EnumClassFlag(
+ 'fruit', None, helptext_input, Fruit, case_sensitive=True)
self.assertEqual(helptext_output, f.help)
def test_requires_enum(self):
@@ -146,6 +154,16 @@ class EnumClassFlagTest(parameterized.TestCase):
f = _flag.EnumClassFlag('fruit', 'ORANGE', 'A sample enum flag.', Fruit)
self.assertEqual(Fruit.ORANGE, f.value)
+ def test_case_sensitive_rejects_default_with_wrong_case(self):
+ with self.assertRaises(_exceptions.IllegalFlagValueError):
+ _flag.EnumClassFlag(
+ 'fruit', 'oranGe', 'A sample enum flag.', Fruit, case_sensitive=True)
+
+ def test_case_insensitive_accepts_string_default(self):
+ f = _flag.EnumClassFlag(
+ 'fruit', 'oranGe', 'A sample enum flag.', Fruit, case_sensitive=False)
+ self.assertEqual(Fruit.ORANGE, f.value)
+
def test_default_value_does_not_exist(self):
with self.assertRaises(_exceptions.IllegalFlagValueError):
_flag.EnumClassFlag('fruit', 'BANANA', 'help', Fruit)
@@ -154,13 +172,14 @@ class EnumClassFlagTest(parameterized.TestCase):
class MultiEnumClassFlagTest(parameterized.TestCase):
@parameterized.named_parameters(
- ('NoHelpSupplied', '', '<APPLE|ORANGE>: (no help available);\n '
- 'repeat this option to specify a list of values'),
+ ('NoHelpSupplied', '', '<apple|orange>: (no help available);\n '
+ 'repeat this option to specify a list of values', False),
('WithHelpSupplied', 'Type of fruit.',
'<APPLE|ORANGE>: Type of fruit.;\n '
- 'repeat this option to specify a list of values'))
- def test_help_text(self, helptext_input, helptext_output):
- f = _flag.MultiEnumClassFlag('fruit', None, helptext_input, Fruit)
+ 'repeat this option to specify a list of values', True))
+ def test_help_text(self, helptext_input, helptext_output, case_sensitive):
+ f = _flag.MultiEnumClassFlag(
+ 'fruit', None, helptext_input, Fruit, case_sensitive=case_sensitive)
self.assertEqual(helptext_output, f.help)
def test_requires_enum(self):
@@ -171,6 +190,20 @@ class MultiEnumClassFlagTest(parameterized.TestCase):
with self.assertRaises(ValueError):
_flag.MultiEnumClassFlag('empty', None, 'help', EmptyEnum)
+ def test_rejects_wrong_case_when_case_sensitive(self):
+ with self.assertRaisesRegex(_exceptions.IllegalFlagValueError,
+ '<APPLE|ORANGE>'):
+ _flag.MultiEnumClassFlag(
+ 'fruit', ['APPLE', 'Orange'],
+ 'A sample enum flag.',
+ Fruit,
+ case_sensitive=True)
+
+ def test_accepts_case_insensitive(self):
+ f = _flag.MultiEnumClassFlag('fruit', ['apple', 'APPLE'],
+ 'A sample enum flag.', Fruit)
+ self.assertListEqual([Fruit.APPLE, Fruit.APPLE], f.value)
+
def test_accepts_literal_default(self):
f = _flag.MultiEnumClassFlag('fruit', Fruit.APPLE, 'A sample enum flag.',
Fruit)
@@ -192,7 +225,8 @@ class MultiEnumClassFlagTest(parameterized.TestCase):
self.assertListEqual([Fruit.ORANGE, Fruit.APPLE], f.value)
def test_default_value_does_not_exist(self):
- with self.assertRaises(_exceptions.IllegalFlagValueError):
+ with self.assertRaisesRegex(_exceptions.IllegalFlagValueError,
+ '<apple|banana>'):
_flag.MultiEnumClassFlag('fruit', 'BANANA', 'help', Fruit)
diff --git a/absl/flags/tests/flags_helpxml_test.py b/absl/flags/tests/flags_helpxml_test.py
index 33b1e84..92b59bb 100644
--- a/absl/flags/tests/flags_helpxml_test.py
+++ b/absl/flags/tests/flags_helpxml_test.py
@@ -217,18 +217,17 @@ class FlagCreateXMLDOMElement(absltest.TestCase):
flags.DEFINE_enum_class('cc_version', 'STABLE', Version,
'Compiler version to use.', flag_values=self.fv)
- expected_output = (
- '<flag>\n'
- ' <file>tool</file>\n'
- ' <name>cc_version</name>\n'
- ' <meaning>&lt;STABLE|EXPERIMENTAL&gt;: '
- 'Compiler version to use.</meaning>\n'
- ' <default>STABLE</default>\n'
- ' <current>Version.STABLE</current>\n'
- ' <type>enum class</type>\n'
- ' <enum_value>STABLE</enum_value>\n'
- ' <enum_value>EXPERIMENTAL</enum_value>\n'
- '</flag>\n')
+ expected_output = ('<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>cc_version</name>\n'
+ ' <meaning>&lt;stable|experimental&gt;: '
+ 'Compiler version to use.</meaning>\n'
+ ' <default>stable</default>\n'
+ ' <current>Version.STABLE</current>\n'
+ ' <type>enum class</type>\n'
+ ' <enum_value>STABLE</enum_value>\n'
+ ' <enum_value>EXPERIMENTAL</enum_value>\n'
+ '</flag>\n')
self._check_flag_help_in_xml('cc_version', 'tool', expected_output)
def test_flag_help_in_xml_comma_separated_list(self):
@@ -374,10 +373,10 @@ class FlagCreateXMLDOMElement(absltest.TestCase):
'<flag>\n'
' <file>tool</file>\n'
' <name>fruit</name>\n'
- ' <meaning>&lt;ORANGE|BANANA&gt;: The fruit flag.;\n'
+ ' <meaning>&lt;orange|banana&gt;: The fruit flag.;\n'
' repeat this option to specify a list of values</meaning>\n'
- ' <default>ORANGE</default>\n'
- ' <current>ORANGE</current>\n'
+ ' <default>orange</default>\n'
+ ' <current>orange</current>\n'
' <type>multi enum class</type>\n'
' <enum_value>ORANGE</enum_value>\n'
' <enum_value>BANANA</enum_value>\n'
@@ -396,10 +395,10 @@ class FlagCreateXMLDOMElement(absltest.TestCase):
'<flag>\n'
' <file>tool</file>\n'
' <name>fruit</name>\n'
- ' <meaning>&lt;ORANGE|BANANA&gt;: The fruit flag.;\n'
+ ' <meaning>&lt;orange|banana&gt;: The fruit flag.;\n'
' repeat this option to specify a list of values</meaning>\n'
- ' <default>ORANGE,BANANA</default>\n'
- ' <current>ORANGE,BANANA</current>\n'
+ ' <default>orange,banana</default>\n'
+ ' <current>orange,banana</current>\n'
' <type>multi enum class</type>\n'
' <enum_value>ORANGE</enum_value>\n'
' <enum_value>BANANA</enum_value>\n'
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py
index af5454b..8608f2f 100644
--- a/absl/flags/tests/flags_test.py
+++ b/absl/flags/tests/flags_test.py
@@ -80,6 +80,11 @@ class FlagDictToArgsTest(absltest.TestCase):
class Fruit(enum.Enum):
+ APPLE = object()
+ ORANGE = object()
+
+
+class CaseSensitiveFruit(enum.Enum):
apple = 1
orange = 2
APPLE = 3
@@ -1328,12 +1333,12 @@ class FlagsUnitTest(absltest.TestCase):
fv = flags.FlagValues()
flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=fv)
- argv = ('./program', '--fruit=apple')
+ argv = ('./program', '--fruit=orange')
argv = fv(argv)
self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(fv['fruit'].present, 1)
- self.assertEqual(fv['fruit'].value, Fruit.apple)
+ self.assertEqual(fv['fruit'].value, Fruit.ORANGE)
fv.unparse_flags()
argv = ('./program', '--fruit=APPLE')
argv = fv(argv)
@@ -1348,7 +1353,7 @@ class FlagsUnitTest(absltest.TestCase):
flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=fv)
helpstr = fv.main_module_help()
- expected_help = '\n%s:\n --fruit: <apple|orange|APPLE>: ?' % sys.argv[0]
+ expected_help = '\n%s:\n --fruit: <apple|orange>: ?' % sys.argv[0]
self.assertEqual(helpstr, expected_help)
@@ -1566,7 +1571,7 @@ class MultiEnumClassFlagsTest(absltest.TestCase):
flag_values=fv)
fv.mark_as_parsed()
- self.assertListEqual(fv.fruit, [Fruit.apple])
+ self.assertListEqual(fv.fruit, [Fruit.APPLE])
def test_define_results_in_registered_flag_with_enum(self):
fv = flags.FlagValues()
@@ -1582,24 +1587,28 @@ class MultiEnumClassFlagsTest(absltest.TestCase):
def test_define_results_in_registered_flag_with_string_list(self):
fv = flags.FlagValues()
enum_defaults = ['apple', 'APPLE']
- flags.DEFINE_multi_enum_class('fruit',
- enum_defaults, Fruit,
- 'Enum option that can occur multiple times',
- flag_values=fv)
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ CaseSensitiveFruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv,
+ case_sensitive=True)
fv.mark_as_parsed()
- self.assertListEqual(fv.fruit, [Fruit.apple, Fruit.APPLE])
+ self.assertListEqual(fv.fruit,
+ [CaseSensitiveFruit.apple, CaseSensitiveFruit.APPLE])
def test_define_results_in_registered_flag_with_enum_list(self):
fv = flags.FlagValues()
- enum_defaults = [Fruit.APPLE, Fruit.orange]
+ enum_defaults = [Fruit.APPLE, Fruit.ORANGE]
flags.DEFINE_multi_enum_class('fruit',
enum_defaults, Fruit,
'Enum option that can occur multiple times',
flag_values=fv)
fv.mark_as_parsed()
- self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.orange])
+ self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.ORANGE])
def test_from_command_line_returns_multiple(self):
fv = flags.FlagValues()
@@ -1608,9 +1617,9 @@ class MultiEnumClassFlagsTest(absltest.TestCase):
enum_defaults, Fruit,
'Enum option that can occur multiple times',
flag_values=fv)
- argv = ('./program', '--fruit=apple', '--fruit=orange')
+ argv = ('./program', '--fruit=Apple', '--fruit=orange')
fv(argv)
- self.assertListEqual(fv.fruit, [Fruit.apple, Fruit.orange])
+ self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.ORANGE])
def test_bad_multi_enum_class_flags_from_definition(self):
with self.assertRaisesRegex(