diff options
author | Siavash Khodadadeh <siavash.khodadadeh@gmail.com> | 2018-12-19 14:05:09 -0800 |
---|---|---|
committer | Copybara-Service <copybara-piper@google.com> | 2018-12-19 14:05:22 -0800 |
commit | de91d0a265b8c067b1d01d63731684819f62e4bc (patch) | |
tree | 135ed474dc654e39249c05e707bcc866459dabfc | |
parent | 8575d76290eada1a3bcc6da4288cc263f7b3e84b (diff) | |
download | absl-py-de91d0a265b8c067b1d01d63731684819f62e4bc.tar.gz |
Allow defining multi fields with any Iterable, not just lists.
This was motivated to allow tuples as input, but was expanded to allow any
iterable. Strings and non-iterables are still special cased and converted
to a single-element list. A copy of the input iterable is made so that
the flag has sole ownership over the set of values.
NOTE: Custom MultiFlag behavior change: The items of the iterable, rather than
the iterable itself, are now passed onto the underlying self.parser object.
Custom flags using DEFINE_multi or MultiFlag should update their custom flag
code as appropriate.
NOTE: Heavily modified from original PR to adjust for various lint/style
errors and from the auto-formatter.
Resolves #78
Closes #80
PiperOrigin-RevId: 226230348
-rw-r--r-- | absl/flags/BUILD | 1 | ||||
-rw-r--r-- | absl/flags/_defines.py | 26 | ||||
-rw-r--r-- | absl/flags/_flag.py | 10 | ||||
-rw-r--r-- | absl/flags/tests/flags_test.py | 20 |
4 files changed, 48 insertions, 9 deletions
diff --git a/absl/flags/BUILD b/absl/flags/BUILD index 111cf64..f23bb76 100644 --- a/absl/flags/BUILD +++ b/absl/flags/BUILD @@ -68,6 +68,7 @@ py_library( ":_argument_parser", ":_exceptions", ":_helpers", + "@six_archive//:six", ], ) diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py index 73ce8c7..b95c419 100644 --- a/absl/flags/_defines.py +++ b/absl/flags/_defines.py @@ -420,7 +420,11 @@ def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin parser: ArgumentParser, used to parse the flag arguments. serializer: ArgumentSerializer, the flag serializer instance. name: str, the flag name. - default: list|str|None, the default value of the flag. + default: Union[Iterable[T], Text, None], the default value of the flag. + If the value is text, it will be parsed as if it was provided from + the command line. If the value is a non-string iterable, it will be + iterated over to create a shallow copy of the values. If it is None, + it is left as-is. help: str, the help message. flag_values: FlagValues, the FlagValues instance with which the flag will be registered. This should almost never need to be overridden. @@ -445,7 +449,8 @@ def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin Args: name: str, the flag name. - default: [str]|str|None, the default value of the flag. + default: Union[Iterable[Text], Text, None], the default value of the flag; + see `DEFINE_multi`. help: str, the help message. flag_values: FlagValues, the FlagValues instance with which the flag will be registered. This should almost never need to be overridden. @@ -469,7 +474,8 @@ def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin Args: name: str, the flag name. - default: [int]|str|None, the default value of the flag. + default: Union[Iterable[int], Text, None], the default value of the flag; + see `DEFINE_multi`. help: str, the help message. lower_bound: int, min values of the flag. upper_bound: int, max values of the flag. @@ -495,7 +501,8 @@ def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin Args: name: str, the flag name. - default: [float]|str|None, the default value of the flag. + default: Union[Iterable[float], Text, None], the default value of the flag; + see `DEFINE_multi`. help: str, the help message. lower_bound: float, min values of the flag. upper_bound: float, max values of the flag. @@ -521,7 +528,8 @@ def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin Args: name: str, the flag name. - default: [str]|str|None, the default value of the flag. + default: Union[Iterable[Text], Text, None], the default value of the flag; + see `DEFINE_multi`. enum_values: [str], a non-empty list of strings with the possible values for the flag. help: str, the help message. @@ -551,8 +559,12 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin Args: name: str, the flag name. - default: [Enum]|Enum|[str]|str|None, the default value of the flag. A string - or a list of strings will be converted to the equivalent Enum objects. + default: Union[Iterable[Enum], Iterable[Text], Enum, Text, None], the + default value of the flag; see + `DEFINE_multi`; only differences are documented here. If the value is + a single Enum, it is treated as a single-item list of that Enum value. + If it is an iterable, text values within the iterable will be converted + to the equivalent Enum objects. enum_class: class, the Enum class with all the possible values for the flag. help: str, the help message. flag_values: FlagValues, the FlagValues instance with which the flag will be diff --git a/absl/flags/_flag.py b/absl/flags/_flag.py index 7568ed7..8e720ec 100644 --- a/absl/flags/_flag.py +++ b/absl/flags/_flag.py @@ -22,12 +22,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import copy import functools from absl.flags import _argument_parser from absl.flags import _exceptions from absl.flags import _helpers +import six @functools.total_ordering @@ -358,8 +360,8 @@ class MultiFlag(Flag): See the __doc__ for Flag for most behavior of this class. Only differences in behavior are described here: - * The default value may be either a single value or a list of values. - A single value is interpreted as the [value] singleton list. + * The default value may be either a single value or an iterable of values. + A single value is transformed into a single-item list of that value. * The value of the flag is always a list, even if the option was only supplied once, and even if the default value is a single @@ -386,6 +388,10 @@ class MultiFlag(Flag): self.present += len(new_values) def _parse(self, arguments): + if (isinstance(arguments, collections.Iterable) and + not isinstance(arguments, six.string_types)): + arguments = list(arguments) + if not isinstance(arguments, list): # Default value may be a list of values. Most other arguments # will not be, so convert them into a single-item list to make diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py index 90094d3..ac77707 100644 --- a/absl/flags/tests/flags_test.py +++ b/absl/flags/tests/flags_test.py @@ -1184,6 +1184,26 @@ class MultiNumericalFlagsTest(absltest.TestCase): expected_floats, FLAGS.get_flag_value('m_float', None)): self.assertAlmostEqual(expected, actual) + def test_multi_numerical_with_tuples(self): + """Verify multi_int/float accept tuples as default values.""" + flags.DEFINE_multi_integer( + 'm_int_tuple', + (77, 88), + 'integer option that can occur multiple times', + short_name='mi_tuple') + self.assertListEqual(FLAGS.get_flag_value('m_int_tuple', None), [77, 88]) + + dict_with_float_keys = {2.2: 'hello', 3: 'happy'} + float_defaults = dict_with_float_keys.keys() + flags.DEFINE_multi_float( + 'm_float_tuple', + float_defaults, + 'float option that can occur multiple times', + short_name='mf_tuple') + for (expected, actual) in zip(float_defaults, + FLAGS.get_flag_value('m_float_tuple', None)): + self.assertAlmostEqual(expected, actual) + def test_single_value_default(self): """Test multi_int and multi_float flags with a single default value.""" int_default = 77 |