aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2020-12-23 11:44:14 -0800
committerCopybara-Service <copybara-worker@google.com>2020-12-23 11:44:36 -0800
commitb03026ac859ed10f6b380d49308213cd2c436e0f (patch)
tree4b47e8eaf213acd0ae47d304c6e5b64d0a575f1b
parentbcff9304a1be4d92b7b251b8ffdb7dd011b9951e (diff)
downloadabsl-py-b03026ac859ed10f6b380d49308213cd2c436e0f.tar.gz
Add a required argument to DEFINE_* methods.
Setting it to true, is functionally equivalent to calling `flags.mark_flag_as_required(flag_name)`, but changes the type of returned flagholder. ``` _A : FlagHolder[Optional[str]] = flags.DEFINE_string( name='a', default=None, help='help') flags.mark_flag_as_required('a') ``` v/s ``` _A : FlagHolder[str] = flags.DEFINE_string( name='a', default=None, help='help', required=True) ``` PiperOrigin-RevId: 348825600 Change-Id: Ia7610af1b5c4649c20aba5cbb620eae1359ad592
-rw-r--r--absl/flags/__init__.pyi102
-rw-r--r--absl/flags/_defines.py154
-rw-r--r--absl/flags/_defines.pyi257
-rw-r--r--absl/flags/tests/flags_test.py1147
4 files changed, 1139 insertions, 521 deletions
diff --git a/absl/flags/__init__.pyi b/absl/flags/__init__.pyi
new file mode 100644
index 0000000..e016d04
--- /dev/null
+++ b/absl/flags/__init__.pyi
@@ -0,0 +1,102 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.flags import _argument_parser
+from absl.flags import _defines
+from absl.flags import _exceptions
+from absl.flags import _flag
+from absl.flags import _flagvalues
+from absl.flags import _helpers
+from absl.flags import _validators
+
+# DEFINE functions. They are explained in more details in the module doc string.
+# pylint: disable=invalid-name
+DEFINE = _defines.DEFINE
+DEFINE_flag = _defines.DEFINE_flag
+DEFINE_string = _defines.DEFINE_string
+DEFINE_boolean = _defines.DEFINE_boolean
+DEFINE_bool = DEFINE_boolean # Match C++ API.
+DEFINE_float = _defines.DEFINE_float
+DEFINE_integer = _defines.DEFINE_integer
+DEFINE_enum = _defines.DEFINE_enum
+DEFINE_enum_class = _defines.DEFINE_enum_class
+DEFINE_list = _defines.DEFINE_list
+DEFINE_spaceseplist = _defines.DEFINE_spaceseplist
+DEFINE_multi = _defines.DEFINE_multi
+DEFINE_multi_string = _defines.DEFINE_multi_string
+DEFINE_multi_integer = _defines.DEFINE_multi_integer
+DEFINE_multi_float = _defines.DEFINE_multi_float
+DEFINE_multi_enum = _defines.DEFINE_multi_enum
+DEFINE_multi_enum_class = _defines.DEFINE_multi_enum_class
+DEFINE_alias = _defines.DEFINE_alias
+# pylint: enable=invalid-name
+
+# Flag validators.
+register_validator = _validators.register_validator
+validator = _validators.validator
+register_multi_flags_validator = _validators.register_multi_flags_validator
+multi_flags_validator = _validators.multi_flags_validator
+mark_flag_as_required = _validators.mark_flag_as_required
+mark_flags_as_required = _validators.mark_flags_as_required
+mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive
+mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive
+
+# Key flag related functions.
+declare_key_flag = _defines.declare_key_flag
+adopt_module_key_flags = _defines.adopt_module_key_flags
+disclaim_key_flags = _defines.disclaim_key_flags
+
+# Module exceptions.
+# pylint: disable=invalid-name
+Error = _exceptions.Error
+CantOpenFlagFileError = _exceptions.CantOpenFlagFileError
+DuplicateFlagError = _exceptions.DuplicateFlagError
+IllegalFlagValueError = _exceptions.IllegalFlagValueError
+UnrecognizedFlagError = _exceptions.UnrecognizedFlagError
+UnparsedFlagAccessError = _exceptions.UnparsedFlagAccessError
+ValidationError = _exceptions.ValidationError
+FlagNameConflictsWithMethodError = _exceptions.FlagNameConflictsWithMethodError
+
+# Public classes.
+Flag = _flag.Flag
+BooleanFlag = _flag.BooleanFlag
+EnumFlag = _flag.EnumFlag
+EnumClassFlag = _flag.EnumClassFlag
+MultiFlag = _flag.MultiFlag
+MultiEnumClassFlag = _flag.MultiEnumClassFlag
+FlagHolder = _flagvalues.FlagHolder
+FlagValues = _flagvalues.FlagValues
+ArgumentParser = _argument_parser.ArgumentParser
+BooleanParser = _argument_parser.BooleanParser
+EnumParser = _argument_parser.EnumParser
+EnumClassParser = _argument_parser.EnumClassParser
+ArgumentSerializer = _argument_parser.ArgumentSerializer
+FloatParser = _argument_parser.FloatParser
+IntegerParser = _argument_parser.IntegerParser
+BaseListParser = _argument_parser.BaseListParser
+ListParser = _argument_parser.ListParser
+ListSerializer = _argument_parser.ListSerializer
+CsvListSerializer = _argument_parser.CsvListSerializer
+WhitespaceSeparatedListParser = _argument_parser.WhitespaceSeparatedListParser
+# pylint: enable=invalid-name
+
+# Helper functions.
+get_help_width = _helpers.get_help_width
+text_wrap = _helpers.text_wrap
+flag_dict_to_args = _helpers.flag_dict_to_args
+doc_to_help = _helpers.doc_to_help
+
+# The global FlagValues instance.
+FLAGS = _flagvalues.FLAGS
+
diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py
index a8d5470..10abb76 100644
--- a/absl/flags/_defines.py
+++ b/absl/flags/_defines.py
@@ -74,6 +74,7 @@ def DEFINE( # pylint: disable=invalid-name
flag_values=_flagvalues.FLAGS,
serializer=None,
module_name=None,
+ required=False,
**args):
"""Registers a generic Flag object.
@@ -93,6 +94,7 @@ def DEFINE( # pylint: disable=invalid-name
serializer: ArgumentSerializer, the flag serializer instance.
module_name: str, the name of the Python module declaring this flag. If not
provided, it will be computed using the stack trace of this call.
+ required: bool, is this a required flag.
**args: dict, the extra keyword args that are passed to Flag __init__.
Returns:
@@ -100,10 +102,14 @@ def DEFINE( # pylint: disable=invalid-name
"""
return DEFINE_flag(
_flag.Flag(parser, serializer, name, default, help, **args), flag_values,
- module_name)
+ module_name, required)
-def DEFINE_flag(flag, flag_values=_flagvalues.FLAGS, module_name=None): # pylint: disable=invalid-name
+def DEFINE_flag( # pylint: disable=invalid-name
+ flag,
+ flag_values=_flagvalues.FLAGS,
+ module_name=None,
+ required=False):
"""Registers a 'Flag' object with a 'FlagValues' object.
By default, the global FLAGS 'FlagValue' object is used.
@@ -119,10 +125,14 @@ def DEFINE_flag(flag, flag_values=_flagvalues.FLAGS, module_name=None): # pylin
registered. This should almost never need to be overridden.
module_name: str, the name of the Python module declaring this flag. If not
provided, it will be computed using the stack trace of this call.
+ required: bool, is this a required flag.
Returns:
a handle to defined flag.
"""
+ if required and flag.default is not None:
+ raise ValueError('Required flag --%s cannot have a non-None default' %
+ flag.name)
# Copying the reference to flag_values prevents pychecker warnings.
fv = flag_values
fv[flag.name] = flag
@@ -133,8 +143,11 @@ def DEFINE_flag(flag, flag_values=_flagvalues.FLAGS, module_name=None): # pylin
module, module_name = _helpers.get_calling_module_object_and_name()
flag_values.register_flag_by_module(module_name, flag)
flag_values.register_flag_by_module_id(id(module), flag)
+ if required:
+ _validators.mark_flag_as_required(flag.name, fv)
+ ensure_non_none_value = (flag.default is not None) or required
return _flagvalues.FlagHolder(
- fv, flag, ensure_non_none_value=flag.default is not None)
+ fv, flag, ensure_non_none_value=ensure_non_none_value)
def _internal_declare_key_flags(flag_names,
@@ -263,11 +276,20 @@ def DEFINE_string( # pylint: disable=invalid-name,redefined-builtin
default,
help,
flag_values=_flagvalues.FLAGS,
+ required=False,
**args):
"""Registers a flag whose value can be any string."""
parser = _argument_parser.ArgumentParser()
serializer = _argument_parser.ArgumentSerializer()
- return DEFINE(parser, name, default, help, flag_values, serializer, **args)
+ return DEFINE(
+ parser,
+ name,
+ default,
+ help,
+ flag_values,
+ serializer,
+ required=required,
+ **args)
def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin
@@ -276,6 +298,7 @@ def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin
help,
flag_values=_flagvalues.FLAGS,
module_name=None,
+ required=False,
**args):
"""Registers a boolean flag.
@@ -295,13 +318,15 @@ def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin
registered. This should almost never need to be overridden.
module_name: str, the name of the Python module declaring this flag. If not
provided, it will be computed using the stack trace of this call.
+ required: bool, is this a required flag.
**args: dict, the extra keyword args that are passed to Flag __init__.
Returns:
a handle to defined flag.
"""
return DEFINE_flag(
- _flag.BooleanFlag(name, default, help, **args), flag_values, module_name)
+ _flag.BooleanFlag(name, default, help, **args), flag_values, module_name,
+ required)
def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin
@@ -311,6 +336,7 @@ def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin
lower_bound=None,
upper_bound=None,
flag_values=_flagvalues.FLAGS,
+ required=False,
**args):
"""Registers a flag whose value must be a float.
@@ -325,6 +351,7 @@ def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin
upper_bound: float, max value of the flag.
flag_values: FlagValues, the FlagValues instance with which the flag will be
registered. This should almost never need to be overridden.
+ required: bool, is this a required flag.
**args: dict, the extra keyword args that are passed to DEFINE.
Returns:
@@ -332,7 +359,15 @@ def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin
"""
parser = _argument_parser.FloatParser(lower_bound, upper_bound)
serializer = _argument_parser.ArgumentSerializer()
- result = DEFINE(parser, name, default, help, flag_values, serializer, **args)
+ result = DEFINE(
+ parser,
+ name,
+ default,
+ help,
+ flag_values,
+ serializer,
+ required=required,
+ **args)
_register_bounds_validator_if_needed(parser, name, flag_values=flag_values)
return result
@@ -344,6 +379,7 @@ def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin
lower_bound=None,
upper_bound=None,
flag_values=_flagvalues.FLAGS,
+ required=False,
**args):
"""Registers a flag whose value must be an integer.
@@ -358,6 +394,7 @@ def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin
upper_bound: int, max value of the flag.
flag_values: FlagValues, the FlagValues instance with which the flag will be
registered. This should almost never need to be overridden.
+ required: bool, is this a required flag.
**args: dict, the extra keyword args that are passed to DEFINE.
Returns:
@@ -365,7 +402,15 @@ def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin
"""
parser = _argument_parser.IntegerParser(lower_bound, upper_bound)
serializer = _argument_parser.ArgumentSerializer()
- result = DEFINE(parser, name, default, help, flag_values, serializer, **args)
+ result = DEFINE(
+ parser,
+ name,
+ default,
+ help,
+ flag_values,
+ serializer,
+ required=required,
+ **args)
_register_bounds_validator_if_needed(parser, name, flag_values=flag_values)
return result
@@ -377,6 +422,7 @@ def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin
help,
flag_values=_flagvalues.FLAGS,
module_name=None,
+ required=False,
**args):
"""Registers a flag whose value can be any string from enum_values.
@@ -393,6 +439,7 @@ def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin
registered. This should almost never need to be overridden.
module_name: str, the name of the Python module declaring this flag. If not
provided, it will be computed using the stack trace of this call.
+ required: bool, is this a required flag.
**args: dict, the extra keyword args that are passed to Flag __init__.
Returns:
@@ -400,7 +447,7 @@ def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin
"""
return DEFINE_flag(
_flag.EnumFlag(name, default, help, enum_values, **args), flag_values,
- module_name)
+ module_name, required)
def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin
@@ -411,6 +458,7 @@ def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin
flag_values=_flagvalues.FLAGS,
module_name=None,
case_sensitive=False,
+ required=False,
**args):
"""Registers a flag whose value can be the name of enum members.
@@ -425,6 +473,7 @@ def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin
provided, it will be computed using the stack trace of this call.
case_sensitive: bool, whether to map strings to members of the enum_class
without considering case.
+ required: bool, is this a required flag.
**args: dict, the extra keyword args that are passed to Flag __init__.
Returns:
@@ -437,7 +486,7 @@ def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin
help,
enum_class,
case_sensitive=case_sensitive,
- **args), flag_values, module_name)
+ **args), flag_values, module_name, required)
def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin
@@ -445,6 +494,7 @@ def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin
default,
help,
flag_values=_flagvalues.FLAGS,
+ required=False,
**args):
"""Registers a flag whose value is a comma-separated list of strings.
@@ -456,6 +506,7 @@ def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin
help: str, the help message.
flag_values: FlagValues, the FlagValues instance with which the flag will be
registered. This should almost never need to be overridden.
+ required: bool, is this a required flag.
**args: Dictionary with extra keyword args that are passed to the Flag
__init__.
@@ -464,7 +515,15 @@ def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin
"""
parser = _argument_parser.ListParser()
serializer = _argument_parser.CsvListSerializer(',')
- return DEFINE(parser, name, default, help, flag_values, serializer, **args)
+ return DEFINE(
+ parser,
+ name,
+ default,
+ help,
+ flag_values,
+ serializer,
+ required=required,
+ **args)
def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin
@@ -473,6 +532,7 @@ def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin
help,
comma_compat=False,
flag_values=_flagvalues.FLAGS,
+ required=False,
**args):
"""Registers a flag whose value is a whitespace-separated list of strings.
@@ -487,6 +547,7 @@ def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin
backwards compatibility with flags that used to be comma-separated.
flag_values: FlagValues, the FlagValues instance with which the flag will be
registered. This should almost never need to be overridden.
+ required: bool, is this a required flag.
**args: Dictionary with extra keyword args that are passed to the Flag
__init__.
@@ -496,7 +557,15 @@ def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin
parser = _argument_parser.WhitespaceSeparatedListParser(
comma_compat=comma_compat)
serializer = _argument_parser.ListSerializer(' ')
- return DEFINE(parser, name, default, help, flag_values, serializer, **args)
+ return DEFINE(
+ parser,
+ name,
+ default,
+ help,
+ flag_values,
+ serializer,
+ required=required,
+ **args)
def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin
@@ -507,6 +576,7 @@ def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin
help,
flag_values=_flagvalues.FLAGS,
module_name=None,
+ required=False,
**args):
"""Registers a generic MultiFlag that parses its args with a given parser.
@@ -530,6 +600,7 @@ def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin
registered. This should almost never need to be overridden.
module_name: A string, the name of the Python module declaring this flag. If
not provided, it will be computed using the stack trace of this call.
+ required: bool, is this a required flag.
**args: Dictionary with extra keyword args that are passed to the Flag
__init__.
@@ -538,7 +609,7 @@ def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin
"""
return DEFINE_flag(
_flag.MultiFlag(parser, serializer, name, default, help, **args),
- flag_values, module_name)
+ flag_values, module_name, required)
def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin
@@ -546,6 +617,7 @@ def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin
default,
help,
flag_values=_flagvalues.FLAGS,
+ required=False,
**args):
"""Registers a flag whose value can be a list of any strings.
@@ -562,6 +634,7 @@ def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin
help: str, the help message.
flag_values: FlagValues, the FlagValues instance with which the flag will be
registered. This should almost never need to be overridden.
+ required: bool, is this a required flag.
**args: Dictionary with extra keyword args that are passed to the Flag
__init__.
@@ -570,8 +643,15 @@ def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin
"""
parser = _argument_parser.ArgumentParser()
serializer = _argument_parser.ArgumentSerializer()
- return DEFINE_multi(parser, serializer, name, default, help, flag_values,
- **args)
+ return DEFINE_multi(
+ parser,
+ serializer,
+ name,
+ default,
+ help,
+ flag_values,
+ required=required,
+ **args)
def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin
@@ -581,6 +661,7 @@ def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin
lower_bound=None,
upper_bound=None,
flag_values=_flagvalues.FLAGS,
+ required=False,
**args):
"""Registers a flag whose value can be a list of arbitrary integers.
@@ -598,6 +679,7 @@ def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin
upper_bound: int, max values of the flag.
flag_values: FlagValues, the FlagValues instance with which the flag will be
registered. This should almost never need to be overridden.
+ required: bool, is this a required flag.
**args: Dictionary with extra keyword args that are passed to the Flag
__init__.
@@ -606,8 +688,15 @@ def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin
"""
parser = _argument_parser.IntegerParser(lower_bound, upper_bound)
serializer = _argument_parser.ArgumentSerializer()
- return DEFINE_multi(parser, serializer, name, default, help, flag_values,
- **args)
+ return DEFINE_multi(
+ parser,
+ serializer,
+ name,
+ default,
+ help,
+ flag_values,
+ required=required,
+ **args)
def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin
@@ -617,6 +706,7 @@ def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin
lower_bound=None,
upper_bound=None,
flag_values=_flagvalues.FLAGS,
+ required=False,
**args):
"""Registers a flag whose value can be a list of arbitrary floats.
@@ -634,6 +724,7 @@ def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin
upper_bound: float, max values of the flag.
flag_values: FlagValues, the FlagValues instance with which the flag will be
registered. This should almost never need to be overridden.
+ required: bool, is this a required flag.
**args: Dictionary with extra keyword args that are passed to the Flag
__init__.
@@ -642,8 +733,15 @@ def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin
"""
parser = _argument_parser.FloatParser(lower_bound, upper_bound)
serializer = _argument_parser.ArgumentSerializer()
- return DEFINE_multi(parser, serializer, name, default, help, flag_values,
- **args)
+ return DEFINE_multi(
+ parser,
+ serializer,
+ name,
+ default,
+ help,
+ flag_values,
+ required=required,
+ **args)
def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin
@@ -653,6 +751,7 @@ def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin
help,
flag_values=_flagvalues.FLAGS,
case_sensitive=True,
+ required=False,
**args):
"""Registers a flag whose value can be a list strings from enum_values.
@@ -671,6 +770,7 @@ def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin
flag_values: FlagValues, the FlagValues instance with which the flag will be
registered. This should almost never need to be overridden.
case_sensitive: Whether or not the enum is to be case-sensitive.
+ required: bool, is this a required flag.
**args: Dictionary with extra keyword args that are passed to the Flag
__init__.
@@ -679,8 +779,15 @@ def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin
"""
parser = _argument_parser.EnumParser(enum_values, case_sensitive)
serializer = _argument_parser.ArgumentSerializer()
- return DEFINE_multi(parser, serializer, name, default, help, flag_values,
- **args)
+ return DEFINE_multi(
+ parser,
+ serializer,
+ name,
+ default,
+ help,
+ flag_values,
+ required=required,
+ **args)
def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
@@ -691,6 +798,7 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
flag_values=_flagvalues.FLAGS,
module_name=None,
case_sensitive=False,
+ required=False,
**args):
"""Registers a flag whose value can be a list of enum members.
@@ -712,6 +820,7 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
not provided, it will be computed using the stack trace of this call.
case_sensitive: bool, whether to map strings to members of the enum_class
without considering case.
+ required: bool, is this a required flag.
**args: Dictionary with extra keyword args that are passed to the Flag
__init__.
@@ -721,7 +830,10 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
return DEFINE_flag(
_flag.MultiEnumClassFlag(
name, default, help, enum_class, case_sensitive=case_sensitive),
- flag_values, module_name, **args)
+ flag_values,
+ module_name,
+ required=required,
+ **args)
def DEFINE_alias( # pylint: disable=invalid-name
diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi
index 5949f0a..2d8a201 100644
--- a/absl/flags/_defines.pyi
+++ b/absl/flags/_defines.pyi
@@ -20,12 +20,13 @@ from absl.flags import _flagvalues
import enum
-from typing import Text, List, Any, TypeVar, Optional, Union, Type, Iterable, overload
+from typing import Text, List, Any, TypeVar, Optional, Union, Type, Iterable, overload, Literal
_T = TypeVar('_T')
_ET = TypeVar('_ET', bound=enum.Enum)
+@overload
def DEFINE(
parser: _argument_parser.ArgumentParser[_T],
name: Text,
@@ -36,15 +37,88 @@ def DEFINE(
flag_values : _flagvalues.FlagValues = ...,
serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = None,
module_name: Optional[Text] = None,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[_T]:
+ ...
+
+
+@overload
+def DEFINE(
+ parser: _argument_parser.ArgumentParser[_T],
+ name: Text,
+ default: Any,
+ help: Optional[Text],
+ # Explicitly replacing ... with _flagvalues.FLAGS causes pytype to
+ # not like the syntax.
+ flag_values : _flagvalues.FlagValues = ...,
+ serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = None,
+ module_name: Optional[Text] = None,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[_T]]:
...
+@overload
+def DEFINE_flag(
+ flag: _flag.Flag[_T],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = None,
+ required: Literal[True]) -> _flagvalues.FlagHolder[_T]:
+ ...
+@overload
def DEFINE_flag(
flag: _flag.Flag[_T],
flag_values: _flagvalues.FlagValues = ...,
- module_name: Optional[Text] = None) -> _flagvalues.FlagHolder[Optional[_T]]:
+ module_name: Optional[Text] = None,
+ required: bool = False) -> _flagvalues.FlagHolder[Optional[_T]]:
+ ...
+
+# typing overloads for DEFINE_* methods...
+#
+# - DEFINE_* method return FlagHolder[Optional[T]] or FlagHolder[T] depending
+# on the arguments.
+# - If the flag value is guaranteed to be not None, the return type is
+# FlagHolder[T].
+# - If the flag is required OR has a non-None default, the flag value i
+# guaranteed to be not None after flag parsing has finished.
+# The information above is captured with three overloads as follows.
+#
+# (if required=True, return type is FlagHolder[Y])
+# @overload
+# def DEFINE_xxx(
+# ... arguments...
+# default : Union[None, X],
+# required: Literal[True]) -> _flagvalues.FlagHolder[Y]:
+# ...
+#
+# (if default=None, return type is FlagHolder[Optional[Y]])
+# @overload
+# def DEFINE_xxx(
+# ... arguments...
+# default : None,
+# required: bool = False) -> _flagvalues.FlagHolder[Optional[Y]]:
+# ...
+#
+# (if default!=None, return type is FlagHolder[Y])#
+# @overload
+# def DEFINE_xxx(
+# ... arguments...
+# default: X,
+# required: bool = False) -> _flagvalues.FlagHolder[Y]:
+# ...
+#
+# where X = type of non-None default values for the flag
+# and Y = non-None type for flag value
+
+@overload
+def DEFINE_string(
+ name: Text,
+ default: Optional[Text],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[Text]:
...
@overload
@@ -53,6 +127,7 @@ def DEFINE_string(
default: None,
help: Optional[Text],
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[Text]]:
...
@@ -62,16 +137,29 @@ def DEFINE_string(
default: Text,
help: Optional[Text],
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Text]:
...
@overload
def DEFINE_boolean(
name : Text,
+ default: Union[None, Text, bool, int],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = None,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[bool]:
+ ...
+
+@overload
+def DEFINE_boolean(
+ name : Text,
default: None,
help: Optional[Text],
flag_values: _flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[bool]]:
...
@@ -82,17 +170,31 @@ def DEFINE_boolean(
help: Optional[Text],
flag_values: _flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[bool]:
...
@overload
def DEFINE_float(
name: Text,
+ default: Union[None, float, Text],
+ help: Optional[Text],
+ lower_bound: Optional[float] = None,
+ upper_bound: Optional[float] = None,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[float]:
+ ...
+
+@overload
+def DEFINE_float(
+ name: Text,
default: None,
help: Optional[Text],
lower_bound: Optional[float] = None,
upper_bound: Optional[float] = None,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[float]]:
...
@@ -104,6 +206,7 @@ def DEFINE_float(
lower_bound: Optional[float] = None,
upper_bound: Optional[float] = None,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[float]:
...
@@ -111,11 +214,24 @@ def DEFINE_float(
@overload
def DEFINE_integer(
name: Text,
+ default: Union[None, int, Text],
+ help: Optional[Text],
+ lower_bound: Optional[int] = None,
+ upper_bound: Optional[int] = None,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[int]:
+ ...
+
+@overload
+def DEFINE_integer(
+ name: Text,
default: None,
help: Optional[Text],
lower_bound: Optional[int] = None,
upper_bound: Optional[int] = None,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[int]]:
...
@@ -127,17 +243,31 @@ def DEFINE_integer(
lower_bound: Optional[int] = None,
upper_bound: Optional[int] = None,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[int]:
...
@overload
def DEFINE_enum(
name : Text,
+ default: Optional[Text],
+ enum_values: Iterable[Text],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = None,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[Text]:
+ ...
+
+@overload
+def DEFINE_enum(
+ name : Text,
default: None,
enum_values: Iterable[Text],
help: Optional[Text],
flag_values: _flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[Text]]:
...
@@ -149,18 +279,33 @@ def DEFINE_enum(
help: Optional[Text],
flag_values: _flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Text]:
...
@overload
def DEFINE_enum_class(
name: Text,
+ default: Union[None, _ET, Text],
+ enum_class: Type[_ET],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = None,
+ case_sensitive: bool = ...,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[_ET]:
+ ...
+
+@overload
+def DEFINE_enum_class(
+ name: Text,
default: None,
enum_class: Type[_ET],
help: Optional[Text],
flag_values: _flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
case_sensitive: bool = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[_ET]]:
...
@@ -173,6 +318,7 @@ def DEFINE_enum_class(
flag_values: _flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
case_sensitive: bool = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[_ET]:
...
@@ -180,9 +326,20 @@ def DEFINE_enum_class(
@overload
def DEFINE_list(
name: Text,
+ default: Union[None, Iterable[Text], Text],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_list(
+ name: Text,
default: None,
help: Text,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
...
@@ -192,6 +349,18 @@ def DEFINE_list(
default: Union[Iterable[Text], Text],
help: Text,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_spaceseplist(
+ name: Text,
+ default: Union[None, Iterable[Text], Text],
+ help: Text,
+ comma_compat: bool = False,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: Literal[True],
**args: Any) -> _flagvalues.FlagHolder[List[Text]]:
...
@@ -202,6 +371,7 @@ def DEFINE_spaceseplist(
help: Text,
comma_compat: bool = False,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
...
@@ -212,6 +382,7 @@ def DEFINE_spaceseplist(
help: Text,
comma_compat: bool = False,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[List[Text]]:
...
@@ -220,10 +391,24 @@ def DEFINE_multi(
parser : _argument_parser.ArgumentParser[_T],
serializer: _argument_parser.ArgumentSerializer[_T],
name: Text,
+ default: Union[None, Iterable[_T], _T, Text],
+ help: Text,
+ flag_values:_flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = None,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[_T]]:
+ ...
+
+@overload
+def DEFINE_multi(
+ parser : _argument_parser.ArgumentParser[_T],
+ serializer: _argument_parser.ArgumentSerializer[_T],
+ name: Text,
default: None,
help: Text,
flag_values:_flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[List[_T]]]:
...
@@ -236,15 +421,27 @@ def DEFINE_multi(
help: Text,
flag_values:_flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[List[_T]]:
...
@overload
def DEFINE_multi_string(
name: Text,
+ default: Union[None, Iterable[Text], Text],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_multi_string(
+ name: Text,
default: None,
help: Text,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
...
@@ -254,17 +451,31 @@ def DEFINE_multi_string(
default: Union[Iterable[Text], Text],
help: Text,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[List[Text]]:
...
@overload
def DEFINE_multi_integer(
name: Text,
+ default: Union[None, Iterable[int], int, Text],
+ help: Text,
+ lower_bound: Optional[int] = None,
+ upper_bound: Optional[int] = None,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[int]]:
+ ...
+
+@overload
+def DEFINE_multi_integer(
+ name: Text,
default: None,
help: Text,
lower_bound: Optional[int] = None,
upper_bound: Optional[int] = None,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[List[int]]]:
...
@@ -276,17 +487,31 @@ def DEFINE_multi_integer(
lower_bound: Optional[int] = None,
upper_bound: Optional[int] = None,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[List[int]]:
...
@overload
def DEFINE_multi_float(
name: Text,
+ default: Union[None, Iterable[float], float, Text],
+ help: Text,
+ lower_bound: Optional[float] = None,
+ upper_bound: Optional[float] = None,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[float]]:
+ ...
+
+@overload
+def DEFINE_multi_float(
+ name: Text,
default: None,
help: Text,
lower_bound: Optional[float] = None,
upper_bound: Optional[float] = None,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[List[float]]]:
...
@@ -298,6 +523,7 @@ def DEFINE_multi_float(
lower_bound: Optional[float] = None,
upper_bound: Optional[float] = None,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[List[float]]:
...
@@ -305,10 +531,22 @@ def DEFINE_multi_float(
@overload
def DEFINE_multi_enum(
name: Text,
+ default: Union[None, Iterable[Text], Text],
+ enum_values: Iterable[Text],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_multi_enum(
+ name: Text,
default: None,
enum_values: Iterable[Text],
help: Text,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
...
@@ -319,17 +557,31 @@ def DEFINE_multi_enum(
enum_values: Iterable[Text],
help: Text,
flag_values: _flagvalues.FlagValues = ...,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[List[Text]]:
...
@overload
def DEFINE_multi_enum_class(
name: Text,
+ default: Union[None, Iterable[_ET], _ET, Text],
+ enum_class: Type[_ET],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = None,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[_ET]]:
+ ...
+
+@overload
+def DEFINE_multi_enum_class(
+ name: Text,
default: None,
enum_class: Type[_ET],
help: Text,
flag_values: _flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[Optional[List[_ET]]]:
...
@@ -341,6 +593,7 @@ def DEFINE_multi_enum_class(
help: Text,
flag_values: _flagvalues.FlagValues = ...,
module_name: Optional[Text] = None,
+ required: bool = False,
**args: Any) -> _flagvalues.FlagHolder[List[_ET]]:
...
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py
index 8608f2f..2d63c40 100644
--- a/absl/flags/tests/flags_test.py
+++ b/absl/flags/tests/flags_test.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Tests for absl.flags used as a package."""
from __future__ import absolute_import
@@ -36,7 +35,6 @@ from absl.flags.tests import module_foo
from absl.testing import absltest
import six
-
FLAGS = flags.FLAGS
@@ -61,10 +59,8 @@ class FlagDictToArgsTest(absltest.TestCase):
'loadthatstuff': [42, 'hello', 'goodbye'],
}
self.assertSameElements(
- (
- '--week-end', '--noestudia', '--notrabaja',
- '--party', '--monday=party', '--score=42',
- '--loadthatstuff=42,hello,goodbye'),
+ ('--week-end', '--noestudia', '--notrabaja', '--party',
+ '--monday=party', '--score=42', '--loadthatstuff=42,hello,goodbye'),
flags.flag_dict_to_args(arg_dict))
def test_flatten_google_flag_map_with_multi_flag(self):
@@ -73,9 +69,8 @@ class FlagDictToArgsTest(absltest.TestCase):
'some_multi_string': ['value3', 'value4'],
}
self.assertSameElements(
- (
- '--some_list=value1,value2', '--some_multi_string=value3',
- '--some_multi_string=value4'),
+ ('--some_list=value1,value2', '--some_multi_string=value3',
+ '--some_multi_string=value4'),
flags.flag_dict_to_args(arg_dict, multi_flags={'some_multi_string'}))
@@ -162,8 +157,7 @@ class AliasFlagsTest(absltest.TestCase):
self.assertEqual('--alias=[0, 1]', actual)
def test_allow_overwrite_false(self):
- self.define_integer('aliased', None, 'help',
- allow_overwrite=False)
+ self.define_integer('aliased', None, 'help', allow_overwrite=False)
self.define_alias('alias', 'aliased')
with self.assertRaisesRegex(flags.IllegalFlagValueError, 'already defined'):
@@ -173,6 +167,7 @@ class AliasFlagsTest(absltest.TestCase):
self.assertEqual(1, self.aliased.value)
def test_aliasing_multi_no_default(self):
+
def define_flags():
self.flags = flags.FlagValues()
self.define_multi_integer('aliased', None, 'help')
@@ -202,6 +197,7 @@ class AliasFlagsTest(absltest.TestCase):
self.assert_alias_mirrors_aliased(self.alias, self.aliased)
def test_aliasing_multi_with_default(self):
+
def define_flags():
self.flags = flags.FlagValues()
self.define_multi_integer('aliased', [0], 'help')
@@ -243,6 +239,7 @@ class AliasFlagsTest(absltest.TestCase):
self.assertEqual(0, self.aliased.present)
def test_aliasing_regular(self):
+
def define_flags():
self.flags = flags.FlagValues()
self.define_string('aliased', '', 'help')
@@ -296,8 +293,8 @@ class FlagsUnitTest(absltest.TestCase):
# Define flags
number_test_framework_flags = len(FLAGS)
repeat_help = 'how many times to repeat (0-5)'
- flags.DEFINE_integer('repeat', 4, repeat_help,
- lower_bound=0, short_name='r')
+ flags.DEFINE_integer(
+ 'repeat', 4, repeat_help, lower_bound=0, short_name='r')
flags.DEFINE_string('name', 'Bob', 'namehelp')
flags.DEFINE_boolean('debug', 0, 'debughelp')
flags.DEFINE_boolean('q', 1, 'quiet mode')
@@ -314,24 +311,35 @@ class FlagsUnitTest(absltest.TestCase):
flags.DEFINE_list('numbers', [1, 2, 3], 'a list of numbers')
flags.DEFINE_enum('kwery', None, ['who', 'what', 'Why', 'where', 'when'],
'?')
- flags.DEFINE_enum('sense', None, ['Case', 'case', 'CASE'],
- '?', case_sensitive=True)
- flags.DEFINE_enum('cases', None, ['UPPER', 'lower', 'Initial', 'Ot_HeR'],
- '?', case_sensitive=False)
- flags.DEFINE_enum('funny', None, ['Joke', 'ha', 'ha', 'ha', 'ha'],
- '?', case_sensitive=True)
- flags.DEFINE_enum('blah', None, ['bla', 'Blah', 'BLAH', 'blah'],
- '?', case_sensitive=False)
- flags.DEFINE_string('only_once', None, 'test only sets this once',
- allow_overwrite=False)
- flags.DEFINE_string('universe', None, 'test tries to set this three times',
- allow_overwrite=False)
+ flags.DEFINE_enum(
+ 'sense', None, ['Case', 'case', 'CASE'], '?', case_sensitive=True)
+ flags.DEFINE_enum(
+ 'cases',
+ None, ['UPPER', 'lower', 'Initial', 'Ot_HeR'],
+ '?',
+ case_sensitive=False)
+ flags.DEFINE_enum(
+ 'funny',
+ None, ['Joke', 'ha', 'ha', 'ha', 'ha'],
+ '?',
+ case_sensitive=True)
+ flags.DEFINE_enum(
+ 'blah',
+ None, ['bla', 'Blah', 'BLAH', 'blah'],
+ '?',
+ case_sensitive=False)
+ flags.DEFINE_string(
+ 'only_once', None, 'test only sets this once', allow_overwrite=False)
+ flags.DEFINE_string(
+ 'universe',
+ None,
+ 'test tries to set this three times',
+ allow_overwrite=False)
# Specify number of flags defined above. The short_name defined
# for 'repeat' counts as an extra flag.
number_defined_flags = 22 + 1
- self.assertEqual(len(FLAGS),
- number_defined_flags + number_test_framework_flags)
+ self.assertLen(FLAGS, number_defined_flags + number_test_framework_flags)
self.assertEqual(FLAGS.repeat, 4)
self.assertEqual(FLAGS.name, 'Bob')
@@ -393,13 +401,13 @@ class FlagsUnitTest(absltest.TestCase):
# .. empty command line
argv = ('./program',)
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
# .. non-empty command line
argv = ('./program', '--debug', '--name=Bob', '-q', '--x=8')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['debug'].present, 1)
FLAGS['debug'].present = 0 # Reset
@@ -411,8 +419,7 @@ class FlagsUnitTest(absltest.TestCase):
FLAGS['x'].present = 0 # Reset
# Flags list.
- self.assertEqual(len(FLAGS),
- number_defined_flags + number_test_framework_flags)
+ self.assertLen(FLAGS, number_defined_flags + number_test_framework_flags)
self.assertIn('name', FLAGS)
self.assertIn('debug', FLAGS)
self.assertIn('repeat', FLAGS)
@@ -431,14 +438,14 @@ class FlagsUnitTest(absltest.TestCase):
# try deleting a flag
del FLAGS.r
- self.assertEqual(len(FLAGS),
- number_defined_flags - 1 + number_test_framework_flags)
+ self.assertLen(FLAGS,
+ number_defined_flags - 1 + number_test_framework_flags)
self.assertNotIn('r', FLAGS)
# .. command line with extra stuff
argv = ('./program', '--debug', '--name=Bob', 'extra')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
+ self.assertLen(argv, 2, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
self.assertEqual(FLAGS['debug'].present, 1)
@@ -449,7 +456,7 @@ class FlagsUnitTest(absltest.TestCase):
# Test reset
argv = ('./program', '--debug')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['debug'].present, 1)
self.assertTrue(FLAGS['debug'].value)
@@ -460,14 +467,14 @@ class FlagsUnitTest(absltest.TestCase):
# Test that reset restores default value when default value is None.
argv = ('./program', '--kwery=who')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['kwery'].present, 1)
self.assertEqual(FLAGS['kwery'].value, 'who')
FLAGS.unparse_flags()
argv = ('./program', '--kwery=Why')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['kwery'].present, 1)
self.assertEqual(FLAGS['kwery'].value, 'Why')
@@ -478,14 +485,14 @@ class FlagsUnitTest(absltest.TestCase):
# Test case sensitive enum.
argv = ('./program', '--sense=CASE')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['sense'].present, 1)
self.assertEqual(FLAGS['sense'].value, 'CASE')
FLAGS.unparse_flags()
argv = ('./program', '--sense=Case')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['sense'].present, 1)
self.assertEqual(FLAGS['sense'].value, 'Case')
@@ -494,7 +501,7 @@ class FlagsUnitTest(absltest.TestCase):
# Test case insensitive enum.
argv = ('./program', '--cases=upper')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['cases'].present, 1)
self.assertEqual(FLAGS['cases'].value, 'UPPER')
@@ -503,7 +510,7 @@ class FlagsUnitTest(absltest.TestCase):
# Test case sensitive enum with duplicates.
argv = ('./program', '--funny=ha')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['funny'].present, 1)
self.assertEqual(FLAGS['funny'].value, 'ha')
@@ -512,14 +519,14 @@ class FlagsUnitTest(absltest.TestCase):
# Test case insensitive enum with duplicates.
argv = ('./program', '--blah=bLah')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['blah'].present, 1)
self.assertEqual(FLAGS['blah'].value, 'Blah')
FLAGS.unparse_flags()
argv = ('./program', '--blah=BLAH')
argv = FLAGS(argv)
- self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['blah'].present, 1)
self.assertEqual(FLAGS['blah'].value, 'Blah')
@@ -617,22 +624,21 @@ class FlagsUnitTest(absltest.TestCase):
self.assertEqual(FLAGS.get_flag_value('testget4', 'foo'), 'foo')
# test list code
- lists = [['hello', 'moo', 'boo', '1'],
- []]
+ lists = [['hello', 'moo', 'boo', '1'], []]
flags.DEFINE_list('testcomma_list', '', 'test comma list parsing')
flags.DEFINE_spaceseplist('testspace_list', '', 'tests space list parsing')
flags.DEFINE_spaceseplist(
- 'testspace_or_comma_list', '',
- 'tests space list parsing with comma compatibility', comma_compat=True)
-
- for name, sep in (
- ('testcomma_list', ','),
- ('testspace_list', ' '),
- ('testspace_list', '\n'),
- ('testspace_or_comma_list', ' '),
- ('testspace_or_comma_list', '\n'),
- ('testspace_or_comma_list', ',')):
+ 'testspace_or_comma_list',
+ '',
+ 'tests space list parsing with comma compatibility',
+ comma_compat=True)
+
+ for name, sep in (('testcomma_list', ','), ('testspace_list',
+ ' '), ('testspace_list', '\n'),
+ ('testspace_or_comma_list',
+ ' '), ('testspace_or_comma_list',
+ '\n'), ('testspace_or_comma_list', ',')):
for lst in lists:
argv = ('./program', '--%s=%s' % (name, sep.join(lst)))
argv = FLAGS(argv)
@@ -640,10 +646,10 @@ class FlagsUnitTest(absltest.TestCase):
# Test help text
flags_help = str(FLAGS)
- self.assertNotEqual(flags_help.find('repeat'), -1,
- 'cannot find flag in help')
- self.assertNotEqual(flags_help.find(repeat_help), -1,
- 'cannot find help string in help')
+ self.assertNotEqual(
+ flags_help.find('repeat'), -1, 'cannot find flag in help')
+ self.assertNotEqual(
+ flags_help.find(repeat_help), -1, 'cannot find help string in help')
# Test flag specified twice
argv = ('./program', '--repeat=4', '--repeat=2', '--debug', '--nodebug')
@@ -652,16 +658,20 @@ class FlagsUnitTest(absltest.TestCase):
self.assertEqual(FLAGS.get_flag_value('debug', None), 0)
# Test MultiFlag with single default value
- flags.DEFINE_multi_string('s_str', 'sing1',
- 'string option that can occur multiple times',
- short_name='s')
+ flags.DEFINE_multi_string(
+ 's_str',
+ 'sing1',
+ 'string option that can occur multiple times',
+ short_name='s')
self.assertEqual(FLAGS.get_flag_value('s_str', None), ['sing1'])
# Test MultiFlag with list of default values
multi_string_defs = ['def1', 'def2']
- flags.DEFINE_multi_string('m_str', multi_string_defs,
- 'string option that can occur multiple times',
- short_name='m')
+ flags.DEFINE_multi_string(
+ 'm_str',
+ multi_string_defs,
+ 'string option that can occur multiple times',
+ short_name='m')
self.assertEqual(FLAGS.get_flag_value('m_str', None), multi_string_defs)
# Test flag specified multiple times with a MultiFlag
@@ -677,12 +687,11 @@ class FlagsUnitTest(absltest.TestCase):
# A flag with allow_overwrite set to False should complain when it is
# specified more than once
- argv = ('./program', '--universe=ptolemaic',
- '--universe=copernicean', '--universe=euclidean')
+ argv = ('./program', '--universe=ptolemaic', '--universe=copernicean',
+ '--universe=euclidean')
self.assertRaisesWithLiteralMatch(
flags.IllegalFlagValueError,
- 'flag --universe=copernicean: already defined as ptolemaic',
- FLAGS,
+ 'flag --universe=copernicean: already defined as ptolemaic', FLAGS,
argv)
# Test single-letter flags; should support both single and double dash
@@ -709,9 +718,7 @@ class FlagsUnitTest(absltest.TestCase):
old_testspace_list = FLAGS.testspace_list
old_testspace_or_comma_list = FLAGS.testspace_or_comma_list
- argv = ('./program',
- FLAGS['test0'].serialize(),
- FLAGS['test1'].serialize(),
+ argv = ('./program', FLAGS['test0'].serialize(), FLAGS['test1'].serialize(),
FLAGS['s_str'].serialize())
argv = FLAGS(argv)
@@ -727,8 +734,7 @@ class FlagsUnitTest(absltest.TestCase):
FLAGS.testcomma_list = list(testcomma_list1)
FLAGS.testspace_list = list(testspace_list1)
FLAGS.testspace_or_comma_list = list(testspace_or_comma_list1)
- argv = ('./program',
- FLAGS['testcomma_list'].serialize(),
+ argv = ('./program', FLAGS['testcomma_list'].serialize(),
FLAGS['testspace_list'].serialize(),
FLAGS['testspace_or_comma_list'].serialize())
argv = FLAGS(argv)
@@ -742,8 +748,7 @@ class FlagsUnitTest(absltest.TestCase):
FLAGS.testcomma_list = list(testcomma_list1)
FLAGS.testspace_list = list(testspace_list1)
FLAGS.testspace_or_comma_list = list(testspace_or_comma_list1)
- argv = ('./program',
- FLAGS['testcomma_list'].serialize(),
+ argv = ('./program', FLAGS['testcomma_list'].serialize(),
FLAGS['testspace_list'].serialize(),
FLAGS['testspace_or_comma_list'].serialize())
argv = FLAGS(argv)
@@ -767,18 +772,17 @@ class FlagsUnitTest(absltest.TestCase):
flags_to_exclude = {'log_dir', 'test_srcdir', 'test_tmpdir'}
flagnames = set(FLAGS) - flags_to_exclude
- nonbool_flags = ['--%s %s' % (name, FLAGS.get_flag_value(name, None))
- for name in flagnames
- if not isinstance(FLAGS[name], flags.BooleanFlag)]
-
- truebool_flags = ['--%s' % (name)
- for name in flagnames
- if isinstance(FLAGS[name], flags.BooleanFlag) and
- FLAGS.get_flag_value(name, None)]
- falsebool_flags = ['--no%s' % (name)
- for name in flagnames
- if isinstance(FLAGS[name], flags.BooleanFlag) and
- not FLAGS.get_flag_value(name, None)]
+ nonbool_flags = []
+ truebool_flags = []
+ falsebool_flags = []
+ for name in flagnames:
+ flag_value = FLAGS.get_flag_value(name, None)
+ if not isinstance(FLAGS[name], flags.BooleanFlag):
+ nonbool_flags.append('--%s %s' % (name, flag_value))
+ elif flag_value:
+ truebool_flags.append('--%s' % name)
+ else:
+ falsebool_flags.append('--no%s' % name)
all_flags = nonbool_flags + truebool_flags + falsebool_flags
all_flags.sort()
return all_flags
@@ -953,8 +957,11 @@ class FlagsUnitTest(absltest.TestCase):
# correctly.
flagnames = ['repeated']
original_flags = flags.FlagValues()
- flags.DEFINE_boolean(flagnames[0], False, 'Flag about to be repeated.',
- flag_values=original_flags)
+ flags.DEFINE_boolean(
+ flagnames[0],
+ False,
+ 'Flag about to be repeated.',
+ flag_values=original_flags)
duplicate_flags = module_foo.duplicate_flags(flagnames)
with self.assertRaisesRegex(flags.DuplicateFlagError,
'flags_test.*module_foo'):
@@ -962,13 +969,13 @@ class FlagsUnitTest(absltest.TestCase):
# Make sure allow_override works
try:
- flags.DEFINE_boolean('dup1', 0, 'runhelp d11', short_name='u',
- allow_override=0)
+ flags.DEFINE_boolean(
+ 'dup1', 0, 'runhelp d11', short_name='u', allow_override=0)
flag = FLAGS._flags()['dup1']
self.assertEqual(flag.default, 0)
- flags.DEFINE_boolean('dup1', 1, 'runhelp d12', short_name='u',
- allow_override=1)
+ flags.DEFINE_boolean(
+ 'dup1', 1, 'runhelp d12', short_name='u', allow_override=1)
flag = FLAGS._flags()['dup1']
self.assertEqual(flag.default, 1)
except flags.DuplicateFlagError:
@@ -976,13 +983,13 @@ class FlagsUnitTest(absltest.TestCase):
# Make sure allow_override works
try:
- flags.DEFINE_boolean('dup2', 0, 'runhelp d21', short_name='u',
- allow_override=1)
+ flags.DEFINE_boolean(
+ 'dup2', 0, 'runhelp d21', short_name='u', allow_override=1)
flag = FLAGS._flags()['dup2']
self.assertEqual(flag.default, 0)
- flags.DEFINE_boolean('dup2', 1, 'runhelp d22', short_name='u',
- allow_override=0)
+ flags.DEFINE_boolean(
+ 'dup2', 1, 'runhelp d22', short_name='u', allow_override=0)
flag = FLAGS._flags()['dup2']
self.assertEqual(flag.default, 1)
except flags.DuplicateFlagError:
@@ -991,18 +998,17 @@ class FlagsUnitTest(absltest.TestCase):
# Make sure that re-importing a module does not cause a DuplicateFlagError
# to be raised.
try:
- sys.modules.pop(
- 'absl.flags.tests.module_baz')
+ sys.modules.pop('absl.flags.tests.module_baz')
import absl.flags.tests.module_baz
del absl
except flags.DuplicateFlagError:
raise AssertionError('Module reimport caused flag duplication error')
# Make sure that when we override, the help string gets updated correctly
- flags.DEFINE_boolean('dup3', 0, 'runhelp d31', short_name='u',
- allow_override=1)
- flags.DEFINE_boolean('dup3', 1, 'runhelp d32', short_name='u',
- allow_override=1)
+ flags.DEFINE_boolean(
+ 'dup3', 0, 'runhelp d31', short_name='u', allow_override=1)
+ flags.DEFINE_boolean(
+ 'dup3', 1, 'runhelp d32', short_name='u', allow_override=1)
self.assertEqual(str(FLAGS).find('runhelp d31'), -1)
self.assertNotEqual(str(FLAGS).find('runhelp d32'), -1)
@@ -1026,8 +1032,8 @@ class FlagsUnitTest(absltest.TestCase):
# Make sure append_flag_values works with flags with shortnames.
new_flags = flags.FlagValues()
flags.DEFINE_boolean('new3', 0, 'runhelp n3', flag_values=new_flags)
- flags.DEFINE_boolean('new4', 0, 'runhelp n4', flag_values=new_flags,
- short_name='n4')
+ flags.DEFINE_boolean(
+ 'new4', 0, 'runhelp n4', flag_values=new_flags, short_name='n4')
self.assertEqual(len(new_flags._flags()), 3)
old_len = len(FLAGS._flags())
FLAGS.append_flag_values(new_flags)
@@ -1101,18 +1107,12 @@ class FlagsUnitTest(absltest.TestCase):
flags.DEFINE_alias('alias_letters', 'letters')
self.assertEqual(FLAGS['name'].default, FLAGS.alias_name)
self.assertEqual(FLAGS['debug'].default, FLAGS.alias_debug)
- self.assertEqual(
- int(FLAGS['decimal'].default), FLAGS.alias_decimal)
- self.assertEqual(
- float(FLAGS['float'].default), FLAGS.alias_float)
- self.assertSameElements(
- FLAGS['letters'].default, FLAGS.alias_letters)
+ self.assertEqual(int(FLAGS['decimal'].default), FLAGS.alias_decimal)
+ self.assertEqual(float(FLAGS['float'].default), FLAGS.alias_float)
+ self.assertSameElements(FLAGS['letters'].default, FLAGS.alias_letters)
# Original flags set on comand line
- argv = ('./program',
- '--name=Martin',
- '--debug=True',
- '--decimal=777',
+ argv = ('./program', '--name=Martin', '--debug=True', '--decimal=777',
'--letters=x,y,z')
FLAGS(argv)
self.assertEqual('Martin', FLAGS.name)
@@ -1125,11 +1125,8 @@ class FlagsUnitTest(absltest.TestCase):
self.assertSameElements(['x', 'y', 'z'], FLAGS.alias_letters)
# Alias flags set on command line
- argv = ('./program',
- '--alias_name=Auston',
- '--alias_debug=False',
- '--alias_decimal=888',
- '--alias_letters=l,m,n')
+ argv = ('./program', '--alias_name=Auston', '--alias_debug=False',
+ '--alias_decimal=888', '--alias_letters=l,m,n')
FLAGS(argv)
self.assertEqual('Auston', FLAGS.name)
self.assertEqual('Auston', FLAGS.alias_name)
@@ -1142,36 +1139,36 @@ class FlagsUnitTest(absltest.TestCase):
# Make sure importing a module does not change flag value parsed
# from commandline.
- flags.DEFINE_integer('dup5', 1, 'runhelp d51', short_name='u5',
- allow_override=0)
+ flags.DEFINE_integer(
+ 'dup5', 1, 'runhelp d51', short_name='u5', allow_override=0)
self.assertEqual(FLAGS.dup5, 1)
self.assertEqual(FLAGS.dup5, 1)
argv = ('./program', '--dup5=3')
FLAGS(argv)
self.assertEqual(FLAGS.dup5, 3)
- flags.DEFINE_integer('dup5', 2, 'runhelp d52', short_name='u5',
- allow_override=1)
+ flags.DEFINE_integer(
+ 'dup5', 2, 'runhelp d52', short_name='u5', allow_override=1)
self.assertEqual(FLAGS.dup5, 3)
# Make sure importing a module does not change user defined flag value.
- flags.DEFINE_integer('dup6', 1, 'runhelp d61', short_name='u6',
- allow_override=0)
+ flags.DEFINE_integer(
+ 'dup6', 1, 'runhelp d61', short_name='u6', allow_override=0)
self.assertEqual(FLAGS.dup6, 1)
FLAGS.dup6 = 3
self.assertEqual(FLAGS.dup6, 3)
- flags.DEFINE_integer('dup6', 2, 'runhelp d62', short_name='u6',
- allow_override=1)
+ flags.DEFINE_integer(
+ 'dup6', 2, 'runhelp d62', short_name='u6', allow_override=1)
self.assertEqual(FLAGS.dup6, 3)
# Make sure importing a module does not change user defined flag value
# even if it is the 'default' value.
- flags.DEFINE_integer('dup7', 1, 'runhelp d71', short_name='u7',
- allow_override=0)
+ flags.DEFINE_integer(
+ 'dup7', 1, 'runhelp d71', short_name='u7', allow_override=0)
self.assertEqual(FLAGS.dup7, 1)
FLAGS.dup7 = 1
self.assertEqual(FLAGS.dup7, 1)
- flags.DEFINE_integer('dup7', 2, 'runhelp d72', short_name='u7',
- allow_override=1)
+ flags.DEFINE_integer(
+ 'dup7', 2, 'runhelp d72', short_name='u7', allow_override=1)
self.assertEqual(FLAGS.dup7, 1)
# Test module_help().
@@ -1360,55 +1357,80 @@ class FlagsUnitTest(absltest.TestCase):
def test_enum_class_flag_with_wrong_default_value_type(self):
fv = flags.FlagValues()
with self.assertRaises(_exceptions.IllegalFlagValueError):
- flags.DEFINE_enum_class('fruit', 1, Fruit, 'help',
- flag_values=fv)
+ flags.DEFINE_enum_class('fruit', 1, Fruit, 'help', flag_values=fv)
def test_enum_class_flag_requires_enum_class(self):
fv = flags.FlagValues()
with self.assertRaises(TypeError):
- flags.DEFINE_enum_class('fruit', None, ['apple', 'orange'], 'help',
- flag_values=fv)
+ flags.DEFINE_enum_class(
+ 'fruit', None, ['apple', 'orange'], 'help', flag_values=fv)
def test_enum_class_flag_requires_non_empty_enum_class(self):
fv = flags.FlagValues()
with self.assertRaises(ValueError):
- flags.DEFINE_enum_class('empty', None, EmptyEnum, 'help',
- flag_values=fv)
+ flags.DEFINE_enum_class('empty', None, EmptyEnum, 'help', flag_values=fv)
+
+ def test_required_flag(self):
+ fv = flags.FlagValues()
+ fl = flags.DEFINE_integer(
+ name='int_flag',
+ default=None,
+ help='help',
+ required=True,
+ flag_values=fv)
+ # Since the flag is required, the FlagHolder should ensure value returned
+ # is not None.
+ self.assertTrue(fl._ensure_non_none_value)
+
+ def test_illegal_required_flag(self):
+ fv = flags.FlagValues()
+ with self.assertRaises(ValueError):
+ flags.DEFINE_integer(
+ name='int_flag',
+ default=3,
+ help='help',
+ required=True,
+ flag_values=fv)
class MultiNumericalFlagsTest(absltest.TestCase):
def test_multi_numerical_flags(self):
"""Test multi_int and multi_float flags."""
-
+ fv = flags.FlagValues()
int_defaults = [77, 88]
- flags.DEFINE_multi_integer('m_int', int_defaults,
- 'integer option that can occur multiple times',
- short_name='mi')
- self.assertListEqual(FLAGS.get_flag_value('m_int', None), int_defaults)
+ flags.DEFINE_multi_integer(
+ 'm_int',
+ int_defaults,
+ 'integer option that can occur multiple times',
+ short_name='mi',
+ flag_values=fv)
+ self.assertListEqual(fv['m_int'].default, int_defaults)
argv = ('./program', '--m_int=-99', '--mi=101')
- FLAGS(argv)
- self.assertListEqual(FLAGS.get_flag_value('m_int', None), [-99, 101])
+ fv(argv)
+ self.assertListEqual(fv.get_flag_value('m_int', None), [-99, 101])
float_defaults = [2.2, 3]
- flags.DEFINE_multi_float('m_float', float_defaults,
- 'float option that can occur multiple times',
- short_name='mf')
- for (expected, actual) in zip(
- float_defaults, FLAGS.get_flag_value('m_float', None)):
+ flags.DEFINE_multi_float(
+ 'm_float',
+ float_defaults,
+ 'float option that can occur multiple times',
+ short_name='mf',
+ flag_values=fv)
+ for (expected, actual) in zip(float_defaults,
+ fv.get_flag_value('m_float', None)):
self.assertAlmostEqual(expected, actual)
argv = ('./program', '--m_float=-17', '--mf=2.78e9')
- FLAGS(argv)
+ fv(argv)
expected_floats = [-17.0, 2.78e9]
- for (expected, actual) in zip(
- expected_floats, FLAGS.get_flag_value('m_float', None)):
+ for (expected, actual) in zip(expected_floats,
+ fv.get_flag_value('m_float', None)):
self.assertAlmostEqual(expected, actual)
def test_multi_numerical_with_tuples(self):
"""Verify multi_int/float accept tuples as default values."""
flags.DEFINE_multi_integer(
- 'm_int_tuple',
- (77, 88),
+ 'm_int_tuple', (77, 88),
'integer option that can occur multiple times',
short_name='mi_tuple')
self.assertListEqual(FLAGS.get_flag_value('m_int_tuple', None), [77, 88])
@@ -1448,79 +1470,90 @@ class MultiNumericalFlagsTest(absltest.TestCase):
flags.DEFINE_multi_integer, 'm_int2', ['abc'], 'desc')
self.assertRaisesRegex(
- flags.IllegalFlagValueError,
- r'flag --m_float2=abc: '
+ flags.IllegalFlagValueError, r'flag --m_float2=abc: '
r'(invalid literal for float\(\)|could not convert string to float): '
- r"'?abc'?",
- flags.DEFINE_multi_float, 'm_float2', ['abc'], 'desc')
+ r"'?abc'?", flags.DEFINE_multi_float, 'm_float2', ['abc'], 'desc')
# Test non-parseable command line values.
- flags.DEFINE_multi_integer('m_int2', '77',
- 'integer option that can occur multiple times')
+ fv = flags.FlagValues()
+ flags.DEFINE_multi_integer(
+ 'm_int2',
+ '77',
+ 'integer option that can occur multiple times',
+ flag_values=fv)
argv = ('./program', '--m_int2=def')
self.assertRaisesRegex(
flags.IllegalFlagValueError,
r"flag --m_int2=def: invalid literal for int\(\) with base 10: 'def'",
- FLAGS, argv)
+ fv, argv)
- flags.DEFINE_multi_float('m_float2', 2.2,
- 'float option that can occur multiple times')
+ flags.DEFINE_multi_float(
+ 'm_float2',
+ 2.2,
+ 'float option that can occur multiple times',
+ flag_values=fv)
argv = ('./program', '--m_float2=def')
self.assertRaisesRegex(
- flags.IllegalFlagValueError,
- r'flag --m_float2=def: '
+ flags.IllegalFlagValueError, r'flag --m_float2=def: '
r'(invalid literal for float\(\)|could not convert string to float): '
- r"'?def'?",
- FLAGS, argv)
+ r"'?def'?", fv, argv)
class MultiEnumFlagsTest(absltest.TestCase):
def test_multi_enum_flags(self):
"""Test multi_enum flags."""
+ fv = flags.FlagValues()
enum_defaults = ['FOO', 'BAZ']
- flags.DEFINE_multi_enum('m_enum', enum_defaults,
- ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
- 'Enum option that can occur multiple times',
- short_name='me')
- self.assertListEqual(FLAGS.get_flag_value('m_enum', None), enum_defaults)
+ flags.DEFINE_multi_enum(
+ 'm_enum',
+ enum_defaults, ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
+ 'Enum option that can occur multiple times',
+ short_name='me',
+ flag_values=fv)
+ self.assertListEqual(fv['m_enum'].default, enum_defaults)
argv = ('./program', '--m_enum=WHOOSH', '--me=FOO')
- FLAGS(argv)
- self.assertListEqual(
- FLAGS.get_flag_value('m_enum', None), ['WHOOSH', 'FOO'])
+ fv(argv)
+ self.assertListEqual(fv.get_flag_value('m_enum', None), ['WHOOSH', 'FOO'])
def test_single_value_default(self):
"""Test multi_enum flags with a single default value."""
+ fv = flags.FlagValues()
enum_default = 'FOO'
- flags.DEFINE_multi_enum('m_enum1', enum_default,
- ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
- 'enum option that can occur multiple times')
- self.assertListEqual(FLAGS.get_flag_value('m_enum1', None), [enum_default])
+ flags.DEFINE_multi_enum(
+ 'm_enum1',
+ enum_default, ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
+ 'enum option that can occur multiple times',
+ flag_values=fv)
+ self.assertListEqual(fv['m_enum1'].default, [enum_default])
def test_case_sensitivity(self):
"""Test case sensitivity of multi_enum flag."""
+ fv = flags.FlagValues()
# Test case insensitive enum.
- flags.DEFINE_multi_enum('m_enum2', ['whoosh'],
- ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
- 'Enum option that can occur multiple times',
- short_name='me2',
- case_sensitive=False)
+ flags.DEFINE_multi_enum(
+ 'm_enum2', ['whoosh'], ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
+ 'Enum option that can occur multiple times',
+ short_name='me2',
+ case_sensitive=False,
+ flag_values=fv)
argv = ('./program', '--m_enum2=bar', '--me2=fOo')
- FLAGS(argv)
- self.assertListEqual(FLAGS.get_flag_value('m_enum2', None), ['BAR', 'FOO'])
+ fv(argv)
+ self.assertListEqual(fv.get_flag_value('m_enum2', None), ['BAR', 'FOO'])
# Test case sensitive enum.
- flags.DEFINE_multi_enum('m_enum3', ['BAR'],
- ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
- 'Enum option that can occur multiple times',
- short_name='me3',
- case_sensitive=True)
+ flags.DEFINE_multi_enum(
+ 'm_enum3', ['BAR'], ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
+ 'Enum option that can occur multiple times',
+ short_name='me3',
+ case_sensitive=True,
+ flag_values=fv)
argv = ('./program', '--m_enum3=bar', '--me3=fOo')
self.assertRaisesRegex(
flags.IllegalFlagValueError,
r'flag --m_enum3=invalid: value should be one of <FOO|BAR|BAZ|WHOOSH>',
- FLAGS, argv)
+ fv, argv)
def test_bad_multi_enum_flags(self):
"""Test multi_enum with invalid values."""
@@ -1529,24 +1562,23 @@ class MultiEnumFlagsTest(absltest.TestCase):
self.assertRaisesRegex(
flags.IllegalFlagValueError,
r'flag --m_enum=INVALID: value should be one of <FOO|BAR|BAZ>',
- flags.DEFINE_multi_enum, 'm_enum', ['INVALID'],
- ['FOO', 'BAR', 'BAZ'], 'desc')
+ flags.DEFINE_multi_enum, 'm_enum', ['INVALID'], ['FOO', 'BAR', 'BAZ'],
+ 'desc')
self.assertRaisesRegex(
flags.IllegalFlagValueError,
r'flag --m_enum=1234: value should be one of <FOO|BAR|BAZ>',
- flags.DEFINE_multi_enum, 'm_enum2', [1234],
- ['FOO', 'BAR', 'BAZ'], 'desc')
+ flags.DEFINE_multi_enum, 'm_enum2', [1234], ['FOO', 'BAR', 'BAZ'],
+ 'desc')
# Test command-line values that are not in the permitted list of enums.
- flags.DEFINE_multi_enum('m_enum4', 'FOO',
- ['FOO', 'BAR', 'BAZ'],
+ flags.DEFINE_multi_enum('m_enum4', 'FOO', ['FOO', 'BAR', 'BAZ'],
'enum option that can occur multiple times')
argv = ('./program', '--m_enum4=INVALID')
self.assertRaisesRegex(
flags.IllegalFlagValueError,
- r'flag --m_enum4=invalid: value should be one of <FOO|BAR|BAZ>',
- FLAGS, argv)
+ r'flag --m_enum4=invalid: value should be one of <FOO|BAR|BAZ>', FLAGS,
+ argv)
class MultiEnumClassFlagsTest(absltest.TestCase):
@@ -1554,10 +1586,12 @@ class MultiEnumClassFlagsTest(absltest.TestCase):
def test_define_results_in_registered_flag_with_none(self):
fv = flags.FlagValues()
enum_defaults = None
- flags.DEFINE_multi_enum_class('fruit',
- enum_defaults, Fruit,
- 'Enum option that can occur multiple times',
- flag_values=fv)
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
fv.mark_as_parsed()
self.assertIsNone(fv.fruit)
@@ -1565,10 +1599,12 @@ class MultiEnumClassFlagsTest(absltest.TestCase):
def test_define_results_in_registered_flag_with_string(self):
fv = flags.FlagValues()
enum_defaults = 'apple'
- flags.DEFINE_multi_enum_class('fruit',
- enum_defaults, Fruit,
- 'Enum option that can occur multiple times',
- flag_values=fv)
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
fv.mark_as_parsed()
self.assertListEqual(fv.fruit, [Fruit.APPLE])
@@ -1576,10 +1612,12 @@ class MultiEnumClassFlagsTest(absltest.TestCase):
def test_define_results_in_registered_flag_with_enum(self):
fv = flags.FlagValues()
enum_defaults = Fruit.APPLE
- flags.DEFINE_multi_enum_class('fruit',
- enum_defaults, Fruit,
- 'Enum option that can occur multiple times',
- flag_values=fv)
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
fv.mark_as_parsed()
self.assertListEqual(fv.fruit, [Fruit.APPLE])
@@ -1602,10 +1640,12 @@ class MultiEnumClassFlagsTest(absltest.TestCase):
def test_define_results_in_registered_flag_with_enum_list(self):
fv = flags.FlagValues()
enum_defaults = [Fruit.APPLE, Fruit.ORANGE]
- flags.DEFINE_multi_enum_class('fruit',
- enum_defaults, Fruit,
- 'Enum option that can occur multiple times',
- flag_values=fv)
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
fv.mark_as_parsed()
self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.ORANGE])
@@ -1613,10 +1653,12 @@ class MultiEnumClassFlagsTest(absltest.TestCase):
def test_from_command_line_returns_multiple(self):
fv = flags.FlagValues()
enum_defaults = [Fruit.APPLE]
- flags.DEFINE_multi_enum_class('fruit',
- enum_defaults, Fruit,
- 'Enum option that can occur multiple times',
- flag_values=fv)
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
argv = ('./program', '--fruit=Apple', '--fruit=orange')
fv(argv)
self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.ORANGE])
@@ -1630,8 +1672,8 @@ class MultiEnumClassFlagsTest(absltest.TestCase):
def test_bad_multi_enum_class_flags_from_commandline(self):
fv = flags.FlagValues()
enum_defaults = [Fruit.APPLE]
- flags.DEFINE_multi_enum_class('fruit', enum_defaults, Fruit, 'desc',
- flag_values=fv)
+ flags.DEFINE_multi_enum_class(
+ 'fruit', enum_defaults, Fruit, 'desc', flag_values=fv)
argv = ('./program', '--fruit=INVALID')
with self.assertRaisesRegex(
flags.IllegalFlagValueError,
@@ -1643,65 +1685,81 @@ class UnicodeFlagsTest(absltest.TestCase):
"""Testing proper unicode support for flags."""
def test_unicode_default_and_helpstring(self):
- flags.DEFINE_string('unicode_str', b'\xC3\x80\xC3\xBD'.decode('utf-8'),
- b'help:\xC3\xAA'.decode('utf-8'))
+ fv = flags.FlagValues()
+ flags.DEFINE_string(
+ 'unicode_str',
+ b'\xC3\x80\xC3\xBD'.decode('utf-8'),
+ b'help:\xC3\xAA'.decode('utf-8'),
+ flag_values=fv)
argv = ('./program',)
- FLAGS(argv) # should not raise any exceptions
+ fv(argv) # should not raise any exceptions
argv = ('./program', '--unicode_str=foo')
- FLAGS(argv) # should not raise any exceptions
+ fv(argv) # should not raise any exceptions
def test_unicode_in_list(self):
- flags.DEFINE_list('unicode_list', ['abc', b'\xC3\x80'.decode('utf-8'),
- b'\xC3\xBD'.decode('utf-8')],
- b'help:\xC3\xAB'.decode('utf-8'))
+ fv = flags.FlagValues()
+ flags.DEFINE_list(
+ 'unicode_list',
+ ['abc', b'\xC3\x80'.decode('utf-8'), b'\xC3\xBD'.decode('utf-8')],
+ b'help:\xC3\xAB'.decode('utf-8'),
+ flag_values=fv)
argv = ('./program',)
- FLAGS(argv) # should not raise any exceptions
+ fv(argv) # should not raise any exceptions
argv = ('./program', '--unicode_list=hello,there')
- FLAGS(argv) # should not raise any exceptions
+ fv(argv) # should not raise any exceptions
def test_xmloutput(self):
- flags.DEFINE_string('unicode1', b'\xC3\x80\xC3\xBD'.decode('utf-8'),
- b'help:\xC3\xAC'.decode('utf-8'))
- flags.DEFINE_list('unicode2', ['abc', b'\xC3\x80'.decode('utf-8'),
- b'\xC3\xBD'.decode('utf-8')],
- b'help:\xC3\xAD'.decode('utf-8'))
- flags.DEFINE_list('non_unicode', ['abc', 'def', 'ghi'],
- b'help:\xC3\xAD'.decode('utf-8'))
+ fv = flags.FlagValues()
+ flags.DEFINE_string(
+ 'unicode1',
+ b'\xC3\x80\xC3\xBD'.decode('utf-8'),
+ b'help:\xC3\xAC'.decode('utf-8'),
+ flag_values=fv)
+ flags.DEFINE_list(
+ 'unicode2',
+ ['abc', b'\xC3\x80'.decode('utf-8'), b'\xC3\xBD'.decode('utf-8')],
+ b'help:\xC3\xAD'.decode('utf-8'),
+ flag_values=fv)
+ flags.DEFINE_list(
+ 'non_unicode', ['abc', 'def', 'ghi'],
+ b'help:\xC3\xAD'.decode('utf-8'),
+ flag_values=fv)
outfile = io.StringIO() if six.PY3 else io.BytesIO()
- FLAGS.write_help_in_xml_format(outfile)
+ fv.write_help_in_xml_format(outfile)
actual_output = outfile.getvalue()
if six.PY2:
actual_output = actual_output.decode('utf-8')
# The xml output is large, so we just check parts of it.
- self.assertIn(b'<name>unicode1</name>\n'
- b' <meaning>help:\xc3\xac</meaning>\n'
- b' <default>\xc3\x80\xc3\xbd</default>\n'
- b' <current>\xc3\x80\xc3\xbd</current>'.decode('utf-8'),
- actual_output)
+ self.assertIn(
+ b'<name>unicode1</name>\n'
+ b' <meaning>help:\xc3\xac</meaning>\n'
+ b' <default>\xc3\x80\xc3\xbd</default>\n'
+ b' <current>\xc3\x80\xc3\xbd</current>'.decode('utf-8'),
+ actual_output)
if six.PY2:
- self.assertIn(b'<name>unicode2</name>\n'
- b' <meaning>help:\xc3\xad</meaning>\n'
- b' <default>abc,\xc3\x80,\xc3\xbd</default>\n'
- b" <current>['abc', u'\\xc0', u'\\xfd']"
- b'</current>'.decode('utf-8'),
- actual_output)
+ self.assertIn(
+ b'<name>unicode2</name>\n'
+ b' <meaning>help:\xc3\xad</meaning>\n'
+ b' <default>abc,\xc3\x80,\xc3\xbd</default>\n'
+ b" <current>['abc', u'\\xc0', u'\\xfd']"
+ b'</current>'.decode('utf-8'), actual_output)
else:
- self.assertIn(b'<name>unicode2</name>\n'
- b' <meaning>help:\xc3\xad</meaning>\n'
- b' <default>abc,\xc3\x80,\xc3\xbd</default>\n'
- b" <current>['abc', '\xc3\x80', '\xc3\xbd']"
- b'</current>'.decode('utf-8'),
- actual_output)
- self.assertIn(b'<name>non_unicode</name>\n'
- b' <meaning>help:\xc3\xad</meaning>\n'
- b' <default>abc,def,ghi</default>\n'
- b" <current>['abc', 'def', 'ghi']"
- b'</current>'.decode('utf-8'),
- actual_output)
+ self.assertIn(
+ b'<name>unicode2</name>\n'
+ b' <meaning>help:\xc3\xad</meaning>\n'
+ b' <default>abc,\xc3\x80,\xc3\xbd</default>\n'
+ b" <current>['abc', '\xc3\x80', '\xc3\xbd']"
+ b'</current>'.decode('utf-8'), actual_output)
+ self.assertIn(
+ b'<name>non_unicode</name>\n'
+ b' <meaning>help:\xc3\xad</meaning>\n'
+ b' <default>abc,def,ghi</default>\n'
+ b" <current>['abc', 'def', 'ghi']"
+ b'</current>'.decode('utf-8'), actual_output)
class LoadFromFlagFileTest(absltest.TestCase):
@@ -1709,16 +1767,29 @@ class LoadFromFlagFileTest(absltest.TestCase):
def setUp(self):
self.flag_values = flags.FlagValues()
- flags.DEFINE_string('unittest_message1', 'Foo!', 'You Add Here.',
- flag_values=self.flag_values)
- flags.DEFINE_string('unittest_message2', 'Bar!', 'Hello, Sailor!',
- flag_values=self.flag_values)
- flags.DEFINE_boolean('unittest_boolflag', 0, 'Some Boolean thing',
- flag_values=self.flag_values)
- flags.DEFINE_integer('unittest_number', 12345, 'Some integer',
- lower_bound=0, flag_values=self.flag_values)
- flags.DEFINE_list('UnitTestList', '1,2,3', 'Some list',
- flag_values=self.flag_values)
+ flags.DEFINE_string(
+ 'unittest_message1',
+ 'Foo!',
+ 'You Add Here.',
+ flag_values=self.flag_values)
+ flags.DEFINE_string(
+ 'unittest_message2',
+ 'Bar!',
+ 'Hello, Sailor!',
+ flag_values=self.flag_values)
+ flags.DEFINE_boolean(
+ 'unittest_boolflag',
+ 0,
+ 'Some Boolean thing',
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'unittest_number',
+ 12345,
+ 'Some integer',
+ lower_bound=0,
+ flag_values=self.flag_values)
+ flags.DEFINE_list(
+ 'UnitTestList', '1,2,3', 'Some list', flag_values=self.flag_values)
self.tmp_path = None
self.flag_values.mark_as_parsed()
@@ -1788,8 +1859,8 @@ class LoadFromFlagFileTest(absltest.TestCase):
fake_argv = fake_cmd_line.split(' ')
self.flag_values(fake_argv)
self.assertEqual(self.flag_values.unittest_boolflag, 1)
- self.assertListEqual(
- fake_argv, self._read_flags_from_files(fake_argv, False))
+ self.assertListEqual(fake_argv,
+ self._read_flags_from_files(fake_argv, False))
def test_method_flagfiles_2(self):
"""Tests parsing one file + arguments off simulated argv."""
@@ -1801,32 +1872,31 @@ class LoadFromFlagFileTest(absltest.TestCase):
# We should see the original cmd line with the file's contents spliced in.
# Flags from the file will appear in the order order they are sepcified
# in the file, in the same position as the flagfile argument.
- expected_results = ['fooScript',
- '--q',
- '--unittest_message1=tempFile1!',
- '--unittest_number=54321',
- '--nounittest_boolflag']
+ expected_results = [
+ 'fooScript', '--q', '--unittest_message1=tempFile1!',
+ '--unittest_number=54321', '--nounittest_boolflag'
+ ]
test_results = self._read_flags_from_files(fake_argv, False)
self.assertListEqual(expected_results, test_results)
+
# end testTwo def
def test_method_flagfiles_3(self):
"""Tests parsing nested files + arguments of simulated argv."""
tmp_files = self._setup_test_files()
# specify our temp file on the fake cmd line
- fake_cmd_line = ('fooScript --unittest_number=77 --flagfile=%s'
- % tmp_files[1])
+ fake_cmd_line = ('fooScript --unittest_number=77 --flagfile=%s' %
+ tmp_files[1])
fake_argv = fake_cmd_line.split(' ')
- expected_results = ['fooScript',
- '--unittest_number=77',
- '--unittest_message1=tempFile1!',
- '--unittest_number=54321',
- '--nounittest_boolflag',
- '--unittest_message2=setFromTempFile2',
- '--unittest_number=6789a']
+ expected_results = [
+ 'fooScript', '--unittest_number=77', '--unittest_message1=tempFile1!',
+ '--unittest_number=54321', '--nounittest_boolflag',
+ '--unittest_message2=setFromTempFile2', '--unittest_number=6789a'
+ ]
test_results = self._read_flags_from_files(fake_argv, False)
self.assertListEqual(expected_results, test_results)
+
# end testThree def
def test_method_flagfiles_3_spaces(self):
@@ -1837,18 +1907,16 @@ class LoadFromFlagFileTest(absltest.TestCase):
"""
tmp_files = self._setup_test_files()
# specify our temp file on the fake cmd line
- fake_cmd_line = ('fooScript --unittest_number 77 --flagfile=%s'
- % tmp_files[1])
+ fake_cmd_line = ('fooScript --unittest_number 77 --flagfile=%s' %
+ tmp_files[1])
fake_argv = fake_cmd_line.split(' ')
- expected_results = ['fooScript',
- '--unittest_number',
- '77',
- '--unittest_message1=tempFile1!',
- '--unittest_number=54321',
- '--nounittest_boolflag',
- '--unittest_message2=setFromTempFile2',
- '--unittest_number=6789a']
+ expected_results = [
+ 'fooScript', '--unittest_number', '77',
+ '--unittest_message1=tempFile1!', '--unittest_number=54321',
+ '--nounittest_boolflag', '--unittest_message2=setFromTempFile2',
+ '--unittest_number=6789a'
+ ]
test_results = self._read_flags_from_files(fake_argv, False)
self.assertListEqual(expected_results, test_results)
@@ -1860,14 +1928,14 @@ class LoadFromFlagFileTest(absltest.TestCase):
"""
tmp_files = self._setup_test_files()
# specify our temp file on the fake cmd line
- fake_cmd_line = ('fooScript --unittest_boolflag 77 --flagfile=%s'
- % tmp_files[1])
+ fake_cmd_line = ('fooScript --unittest_boolflag 77 --flagfile=%s' %
+ tmp_files[1])
fake_argv = fake_cmd_line.split(' ')
- expected_results = ['fooScript',
- '--unittest_boolflag',
- '77',
- '--flagfile=%s' % tmp_files[1]]
+ expected_results = [
+ 'fooScript', '--unittest_boolflag', '77',
+ '--flagfile=%s' % tmp_files[1]
+ ]
with _use_gnu_getopt(self.flag_values, False):
test_results = self._read_flags_from_files(fake_argv, False)
self.assertListEqual(expected_results, test_results)
@@ -1879,13 +1947,13 @@ class LoadFromFlagFileTest(absltest.TestCase):
"""
tmp_files = self._setup_test_files()
# specify our temp file on the fake cmd line
- fake_cmd_line = ('fooScript --flagfile=%s --nounittest_boolflag'
- % tmp_files[2])
+ fake_cmd_line = ('fooScript --flagfile=%s --nounittest_boolflag' %
+ tmp_files[2])
fake_argv = fake_cmd_line.split(' ')
- expected_results = ['fooScript',
- '--unittest_message1=setFromTempFile3',
- '--unittest_boolflag',
- '--nounittest_boolflag']
+ expected_results = [
+ 'fooScript', '--unittest_message1=setFromTempFile3',
+ '--unittest_boolflag', '--nounittest_boolflag'
+ ]
test_results = self._read_flags_from_files(fake_argv, False)
self.assertListEqual(expected_results, test_results)
@@ -1896,10 +1964,10 @@ class LoadFromFlagFileTest(absltest.TestCase):
# specify our temp file on the fake cmd line
fake_cmd_line = 'fooScript --some_flag -- --flagfile=%s' % tmp_files[0]
fake_argv = fake_cmd_line.split(' ')
- expected_results = ['fooScript',
- '--some_flag',
- '--',
- '--flagfile=%s' % tmp_files[0]]
+ expected_results = [
+ 'fooScript', '--some_flag', '--',
+ '--flagfile=%s' % tmp_files[0]
+ ]
test_results = self._read_flags_from_files(fake_argv, False)
self.assertListEqual(expected_results, test_results)
@@ -1908,13 +1976,13 @@ class LoadFromFlagFileTest(absltest.TestCase):
"""Test that --flagfile parsing stops at non-options (non-GNU behavior)."""
tmp_files = self._setup_test_files()
# specify our temp file on the fake cmd line
- fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s'
- % tmp_files[0])
+ fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' %
+ tmp_files[0])
fake_argv = fake_cmd_line.split(' ')
- expected_results = ['fooScript',
- '--some_flag',
- 'some_arg',
- '--flagfile=%s' % tmp_files[0]]
+ expected_results = [
+ 'fooScript', '--some_flag', 'some_arg',
+ '--flagfile=%s' % tmp_files[0]
+ ]
with _use_gnu_getopt(self.flag_values, False):
test_results = self._read_flags_from_files(fake_argv, False)
@@ -1925,15 +1993,14 @@ class LoadFromFlagFileTest(absltest.TestCase):
self.flag_values.set_gnu_getopt()
tmp_files = self._setup_test_files()
# specify our temp file on the fake cmd line
- fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s'
- % tmp_files[0])
+ fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' %
+ tmp_files[0])
fake_argv = fake_cmd_line.split(' ')
- expected_results = ['fooScript',
- '--some_flag',
- 'some_arg',
- '--unittest_message1=tempFile1!',
- '--unittest_number=54321',
- '--nounittest_boolflag']
+ expected_results = [
+ 'fooScript', '--some_flag', 'some_arg',
+ '--unittest_message1=tempFile1!', '--unittest_number=54321',
+ '--nounittest_boolflag'
+ ]
test_results = self._read_flags_from_files(fake_argv, False)
self.assertListEqual(expected_results, test_results)
@@ -1942,15 +2009,14 @@ class LoadFromFlagFileTest(absltest.TestCase):
"""Test that --flagfile parsing respects force_gnu=True."""
tmp_files = self._setup_test_files()
# specify our temp file on the fake cmd line
- fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s'
- % tmp_files[0])
+ fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' %
+ tmp_files[0])
fake_argv = fake_cmd_line.split(' ')
- expected_results = ['fooScript',
- '--some_flag',
- 'some_arg',
- '--unittest_message1=tempFile1!',
- '--unittest_number=54321',
- '--nounittest_boolflag']
+ expected_results = [
+ 'fooScript', '--some_flag', 'some_arg',
+ '--unittest_message1=tempFile1!', '--unittest_number=54321',
+ '--nounittest_boolflag'
+ ]
test_results = self._read_flags_from_files(fake_argv, True)
self.assertListEqual(expected_results, test_results)
@@ -1959,18 +2025,16 @@ class LoadFromFlagFileTest(absltest.TestCase):
"""Tests that parsing repeated non-circular flagfiles works."""
tmp_files = self._setup_test_files()
# specify our temp files on the fake cmd line
- fake_cmd_line = ('fooScript --flagfile=%s --flagfile=%s'
- % (tmp_files[1], tmp_files[0]))
+ fake_cmd_line = ('fooScript --flagfile=%s --flagfile=%s' %
+ (tmp_files[1], tmp_files[0]))
fake_argv = fake_cmd_line.split(' ')
- expected_results = ['fooScript',
- '--unittest_message1=tempFile1!',
- '--unittest_number=54321',
- '--nounittest_boolflag',
- '--unittest_message2=setFromTempFile2',
- '--unittest_number=6789a',
- '--unittest_message1=tempFile1!',
- '--unittest_number=54321',
- '--nounittest_boolflag']
+ expected_results = [
+ 'fooScript', '--unittest_message1=tempFile1!',
+ '--unittest_number=54321', '--nounittest_boolflag',
+ '--unittest_message2=setFromTempFile2', '--unittest_number=6789a',
+ '--unittest_message1=tempFile1!', '--unittest_number=54321',
+ '--nounittest_boolflag'
+ ]
test_results = self._read_flags_from_files(fake_argv, False)
self.assertListEqual(expected_results, test_results)
@@ -1982,21 +2046,21 @@ class LoadFromFlagFileTest(absltest.TestCase):
"""Test that --flagfile raises except on file that is unreadable."""
tmp_files = self._setup_test_files()
# specify our temp file on the fake cmd line
- fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s'
- % tmp_files[3])
+ fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' %
+ tmp_files[3])
fake_argv = fake_cmd_line.split(' ')
- self.assertRaises(flags.CantOpenFlagFileError,
- self._read_flags_from_files, fake_argv, True)
+ self.assertRaises(flags.CantOpenFlagFileError, self._read_flags_from_files,
+ fake_argv, True)
def test_method_flagfiles_not_found(self):
"""Test that --flagfile raises except on file that does not exist."""
tmp_files = self._setup_test_files()
# specify our temp file on the fake cmd line
- fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%sNOTEXIST'
- % tmp_files[3])
+ fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%sNOTEXIST' %
+ tmp_files[3])
fake_argv = fake_cmd_line.split(' ')
- self.assertRaises(flags.CantOpenFlagFileError,
- self._read_flags_from_files, fake_argv, True)
+ self.assertRaises(flags.CantOpenFlagFileError, self._read_flags_from_files,
+ fake_argv, True)
def test_flagfiles_user_path_expansion(self):
"""Test that user directory referenced paths are correctly expanded.
@@ -2023,8 +2087,10 @@ class LoadFromFlagFileTest(absltest.TestCase):
The argumants are not supposed to be flags
"""
- fake_argv = ['fooScript', '--unittest_boolflag',
- 'command', '--command_arg1', '--UnitTestBoom', '--UnitTestB']
+ fake_argv = [
+ 'fooScript', '--unittest_boolflag', 'command', '--command_arg1',
+ '--UnitTestBoom', '--UnitTestB'
+ ]
with _use_gnu_getopt(self.flag_values, False):
argv = self.flag_values(fake_argv)
self.assertListEqual(argv, fake_argv[:1] + fake_argv[2:])
@@ -2032,8 +2098,9 @@ class LoadFromFlagFileTest(absltest.TestCase):
def test_parse_flags_after_args_if_using_gnugetopt(self):
"""Test that flags given after arguments are parsed if using gnu_getopt."""
self.flag_values.set_gnu_getopt()
- fake_argv = ['fooScript', '--unittest_boolflag',
- 'command', '--unittest_number=54321']
+ fake_argv = [
+ 'fooScript', '--unittest_boolflag', 'command', '--unittest_number=54321'
+ ]
argv = self.flag_values(fake_argv)
self.assertListEqual(argv, ['fooScript', 'command'])
@@ -2059,8 +2126,7 @@ class LoadFromFlagFileTest(absltest.TestCase):
self.flag_values.set_default('unittest_number', 0)
self.assertEqual(self.flag_values['unittest_number'].default, 0)
self.assertEqual(self.flag_values.unittest_number, 56)
- self.assertEqual(
- self.flag_values['unittest_number'].default_as_str, "'0'")
+ self.assertEqual(self.flag_values['unittest_number'].default_as_str, "'0'")
self.flag_values(['dummyscript', '--unittest_number=56'])
self.assertEqual(self.flag_values.unittest_number, 56)
@@ -2083,8 +2149,7 @@ class LoadFromFlagFileTest(absltest.TestCase):
# Test that setting a list default works correctly.
self.flag_values.set_default('UnitTestList', '4,5,6')
self.assertListEqual(self.flag_values.UnitTestList, ['4', '5', '6'])
- self.assertEqual(self.flag_values['UnitTestList'].default_as_str,
- "'4,5,6'")
+ self.assertEqual(self.flag_values['UnitTestList'].default_as_str, "'4,5,6'")
self.flag_values(['dummyscript', '--UnitTestList=7,8,9'])
self.assertListEqual(self.flag_values.UnitTestList, ['7', '8', '9'])
@@ -2102,25 +2167,21 @@ class FlagsParsingTest(absltest.TestCase):
self.flag_values = flags.FlagValues()
def test_two_dash_arg_first(self):
- flags.DEFINE_string('twodash_name', 'Bob', 'namehelp',
- flag_values=self.flag_values)
- flags.DEFINE_string('twodash_blame', 'Rob', 'blamehelp',
- flag_values=self.flag_values)
- argv = ('./program',
- '--',
- '--twodash_name=Harry')
+ flags.DEFINE_string(
+ 'twodash_name', 'Bob', 'namehelp', flag_values=self.flag_values)
+ flags.DEFINE_string(
+ 'twodash_blame', 'Rob', 'blamehelp', flag_values=self.flag_values)
+ argv = ('./program', '--', '--twodash_name=Harry')
argv = self.flag_values(argv)
self.assertEqual('Bob', self.flag_values.twodash_name)
self.assertEqual(argv[1], '--twodash_name=Harry')
def test_two_dash_arg_middle(self):
- flags.DEFINE_string('twodash2_name', 'Bob', 'namehelp',
- flag_values=self.flag_values)
- flags.DEFINE_string('twodash2_blame', 'Rob', 'blamehelp',
- flag_values=self.flag_values)
- argv = ('./program',
- '--twodash2_blame=Larry',
- '--',
+ flags.DEFINE_string(
+ 'twodash2_name', 'Bob', 'namehelp', flag_values=self.flag_values)
+ flags.DEFINE_string(
+ 'twodash2_blame', 'Rob', 'blamehelp', flag_values=self.flag_values)
+ argv = ('./program', '--twodash2_blame=Larry', '--',
'--twodash2_name=Harry')
argv = self.flag_values(argv)
self.assertEqual('Bob', self.flag_values.twodash2_name)
@@ -2128,19 +2189,42 @@ class FlagsParsingTest(absltest.TestCase):
self.assertEqual(argv[1], '--twodash2_name=Harry')
def test_one_dash_arg_first(self):
- flags.DEFINE_string('onedash_name', 'Bob', 'namehelp',
- flag_values=self.flag_values)
- flags.DEFINE_string('onedash_blame', 'Rob', 'blamehelp',
- flag_values=self.flag_values)
- argv = ('./program',
- '-',
- '--onedash_name=Harry')
+ flags.DEFINE_string(
+ 'onedash_name', 'Bob', 'namehelp', flag_values=self.flag_values)
+ flags.DEFINE_string(
+ 'onedash_blame', 'Rob', 'blamehelp', flag_values=self.flag_values)
+ argv = ('./program', '-', '--onedash_name=Harry')
with _use_gnu_getopt(self.flag_values, False):
argv = self.flag_values(argv)
self.assertEqual(len(argv), 3)
self.assertEqual(argv[1], '-')
self.assertEqual(argv[2], '--onedash_name=Harry')
+ def test_required_flag_not_specified(self):
+ flags.DEFINE_string(
+ 'str_flag',
+ default=None,
+ help='help',
+ required=True,
+ flag_values=self.flag_values)
+ argv = ('./program',)
+ with _use_gnu_getopt(self.flag_values, False):
+ with self.assertRaises(flags.IllegalFlagValueError):
+ self.flag_values(argv)
+
+ def test_required_arg_works_with_other_validators(self):
+ flags.DEFINE_integer(
+ 'int_flag',
+ default=None,
+ help='help',
+ required=True,
+ lower_bound=4,
+ flag_values=self.flag_values)
+ argv = ('./program', '--int_flag=2')
+ with _use_gnu_getopt(self.flag_values, False):
+ with self.assertRaises(flags.IllegalFlagValueError):
+ self.flag_values(argv)
+
def test_unrecognized_flags(self):
flags.DEFINE_string('name', 'Bob', 'namehelp', flag_values=self.flag_values)
# Unknown flag --nosuchflag
@@ -2171,16 +2255,16 @@ class FlagsParsingTest(absltest.TestCase):
self.assertEqual(e.flagvalue, '--nosuchflagwithparam=foo')
# Allow unknown flag --nosuchflag if specified with undefok
- argv = ('./program', '--nosuchflag', '--name=Bob',
- '--undefok=nosuchflag', 'extra')
+ argv = ('./program', '--nosuchflag', '--name=Bob', '--undefok=nosuchflag',
+ 'extra')
argv = self.flag_values(argv)
self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
# Allow unknown flag --noboolflag if undefok=boolflag is specified
- argv = ('./program', '--noboolflag', '--name=Bob',
- '--undefok=boolflag', 'extra')
+ argv = ('./program', '--noboolflag', '--name=Bob', '--undefok=boolflag',
+ 'extra')
argv = self.flag_values(argv)
self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
@@ -2188,8 +2272,8 @@ class FlagsParsingTest(absltest.TestCase):
# But not if the flagname is misspelled:
try:
- argv = ('./program', '--nosuchflag', '--name=Bob',
- '--undefok=nosuchfla', 'extra')
+ argv = ('./program', '--nosuchflag', '--name=Bob', '--undefok=nosuchfla',
+ 'extra')
self.flag_values(argv)
raise AssertionError('Unknown flag exception not raised')
except flags.UnrecognizedFlagError as e:
@@ -2221,9 +2305,7 @@ class FlagsParsingTest(absltest.TestCase):
# Even if undefok specifies multiple flags
argv = ('./program', '--nosuchflag', '-w', '--nosuchflagwithparam=foo',
- '--name=Bob',
- '--undefok=nosuchflag,w,nosuchflagwithparam',
- 'extra')
+ '--name=Bob', '--undefok=nosuchflag,w,nosuchflagwithparam', 'extra')
argv = self.flag_values(argv)
self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
@@ -2251,9 +2333,7 @@ class FlagsParsingTest(absltest.TestCase):
# Test --undefok <list>
argv = ('./program', '--nosuchflag', '-w', '--nosuchflagwithparam=foo',
- '--name=Bob',
- '--undefok',
- 'nosuchflag,w,nosuchflagwithparam',
+ '--name=Bob', '--undefok', 'nosuchflag,w,nosuchflagwithparam',
'extra')
argv = self.flag_values(argv)
self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
@@ -2272,9 +2352,7 @@ class NonGlobalFlagsTest(absltest.TestCase):
"""Test use of non-global FlagValues."""
nonglobal_flags = flags.FlagValues()
flags.DEFINE_string('nonglobal_flag', 'Bob', 'flaghelp', nonglobal_flags)
- argv = ('./program',
- '--nonglobal_flag=Mary',
- 'extra')
+ argv = ('./program', '--nonglobal_flag=Mary', 'extra')
argv = nonglobal_flags(argv)
self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
@@ -2284,17 +2362,14 @@ class NonGlobalFlagsTest(absltest.TestCase):
def test_unrecognized_nonglobal_flags(self):
"""Test unrecognized non-global flags."""
nonglobal_flags = flags.FlagValues()
- argv = ('./program',
- '--nosuchflag')
+ argv = ('./program', '--nosuchflag')
try:
argv = nonglobal_flags(argv)
raise AssertionError('Unknown flag exception not raised')
except flags.UnrecognizedFlagError as e:
self.assertEqual(e.flagname, 'nosuchflag')
- argv = ('./program',
- '--nosuchflag',
- '--undefok=nosuchflag')
+ argv = ('./program', '--nosuchflag', '--undefok=nosuchflag')
argv = nonglobal_flags(argv)
self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
@@ -2317,8 +2392,8 @@ class NonGlobalFlagsTest(absltest.TestCase):
default_value = 'default value for test_flag_values_del_attr'
# 1. Declare and delete a flag with no short name.
flag_values = flags.FlagValues()
- flags.DEFINE_string('delattr_foo', default_value, 'A simple flag.',
- flag_values=flag_values)
+ flags.DEFINE_string(
+ 'delattr_foo', default_value, 'A simple flag.', flag_values=flag_values)
flag_values.mark_as_parsed()
self.assertEqual(flag_values.delattr_foo, default_value)
@@ -2330,15 +2405,19 @@ class NonGlobalFlagsTest(absltest.TestCase):
self.assertFalse(flag_values._flag_is_registered(flag_obj))
# If the previous del FLAGS.delattr_foo did not work properly, the
# next definition will trigger a redefinition error.
- flags.DEFINE_integer('delattr_foo', 3, 'A simple flag.',
- flag_values=flag_values)
+ flags.DEFINE_integer(
+ 'delattr_foo', 3, 'A simple flag.', flag_values=flag_values)
del flag_values.delattr_foo
self.assertFalse('delattr_foo' in flag_values)
# 2. Declare and delete a flag with a short name.
- flags.DEFINE_string('delattr_bar', default_value, 'flag with short name',
- short_name='x5', flag_values=flag_values)
+ flags.DEFINE_string(
+ 'delattr_bar',
+ default_value,
+ 'flag with short name',
+ short_name='x5',
+ flag_values=flag_values)
flag_obj = flag_values['delattr_bar']
self.assertTrue(flag_values._flag_is_registered(flag_obj))
del flag_values.x5
@@ -2347,8 +2426,12 @@ class NonGlobalFlagsTest(absltest.TestCase):
self.assertFalse(flag_values._flag_is_registered(flag_obj))
# 3. Just like 2, but del flag_values.name last
- flags.DEFINE_string('delattr_bar', default_value, 'flag with short name',
- short_name='x5', flag_values=flag_values)
+ flags.DEFINE_string(
+ 'delattr_bar',
+ default_value,
+ 'flag with short name',
+ short_name='x5',
+ flag_values=flag_values)
flag_obj = flag_values['delattr_bar']
self.assertTrue(flag_values._flag_is_registered(flag_obj))
del flag_values.delattr_bar
@@ -2361,24 +2444,25 @@ class NonGlobalFlagsTest(absltest.TestCase):
def test_list_flag_format(self):
"""Tests for correctly-formatted list flags."""
- flags.DEFINE_list('listflag', '', 'A list of arguments')
+ fv = flags.FlagValues()
+ flags.DEFINE_list('listflag', '', 'A list of arguments', flag_values=fv)
def _check_parsing(listval):
"""Parse a particular value for our test flag, --listflag."""
- argv = FLAGS(['./program', '--listflag=' + listval, 'plain-arg'])
+ argv = fv(['./program', '--listflag=' + listval, 'plain-arg'])
self.assertEqual(['./program', 'plain-arg'], argv)
- return FLAGS.listflag
+ return fv.listflag
# Basic success case
self.assertEqual(_check_parsing('foo,bar'), ['foo', 'bar'])
# Success case: newline in argument is quoted.
self.assertEqual(_check_parsing('"foo","bar\nbar"'), ['foo', 'bar\nbar'])
# Failure case: newline in argument is unquoted.
- self.assertRaises(
- flags.IllegalFlagValueError, _check_parsing, '"foo",bar\nbar')
+ self.assertRaises(flags.IllegalFlagValueError, _check_parsing,
+ '"foo",bar\nbar')
# Failure case: unmatched ".
- self.assertRaises(
- flags.IllegalFlagValueError, _check_parsing, '"foo,barbar')
+ self.assertRaises(flags.IllegalFlagValueError, _check_parsing,
+ '"foo,barbar')
def test_flag_definition_via_setitem(self):
with self.assertRaises(flags.IllegalFlagValueError):
@@ -2432,16 +2516,14 @@ class KeyFlagsTest(absltest.TestCase):
flag_values = flags.FlagValues()
# Before starting any testing, make sure no flags are already
# defined for module_foo and module_bar.
- self.assertListEqual(self._get_names_of_key_flags(module_foo, flag_values),
- [])
- self.assertListEqual(self._get_names_of_key_flags(module_bar, flag_values),
- [])
- self.assertListEqual(self._get_names_of_defined_flags(module_foo,
- flag_values),
- [])
- self.assertListEqual(self._get_names_of_defined_flags(module_bar,
- flag_values),
- [])
+ self.assertListEqual(
+ self._get_names_of_key_flags(module_foo, flag_values), [])
+ self.assertListEqual(
+ self._get_names_of_key_flags(module_bar, flag_values), [])
+ self.assertListEqual(
+ self._get_names_of_defined_flags(module_foo, flag_values), [])
+ self.assertListEqual(
+ self._get_names_of_defined_flags(module_bar, flag_values), [])
# Defines a few flags in module_foo and module_bar.
module_foo.define_flags(flag_values=flag_values)
@@ -2501,12 +2583,8 @@ class KeyFlagsTest(absltest.TestCase):
# Before starting any testing, make sure no flags are already
# defined for module_foo and module_bar.
- self.assertListEqual(
- self._get_names_of_key_flags(module_bar, fv),
- [])
- self.assertListEqual(
- self._get_names_of_defined_flags(module_bar, fv),
- [])
+ self.assertListEqual(self._get_names_of_key_flags(module_bar, fv), [])
+ self.assertListEqual(self._get_names_of_defined_flags(module_bar, fv), [])
module_bar.define_flags(flag_values=fv)
@@ -2531,8 +2609,7 @@ class KeyFlagsTest(absltest.TestCase):
flags.declare_key_flag(flag_name_0, flag_values=fv)
self._assert_lists_have_same_elements(
- self._get_names_of_key_flags(main_module, fv),
- [flag_name_0])
+ self._get_names_of_key_flags(main_module, fv), [flag_name_0])
flags.declare_key_flag(flag_name_2, flag_values=fv)
self._assert_lists_have_same_elements(
@@ -2569,9 +2646,11 @@ class KeyFlagsTest(absltest.TestCase):
# Define one flag in this main module and some flags in modules
# a and b. Also declare one flag from module a and one flag
# from module b as key flags for the main module.
- flags.DEFINE_integer('main_module_int_fg', 1,
- 'Integer flag in the main module.',
- flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'main_module_int_fg',
+ 1,
+ 'Integer flag in the main module.',
+ flag_values=self.flag_values)
try:
main_module_int_fg_help = (
@@ -2611,9 +2690,7 @@ class KeyFlagsTest(absltest.TestCase):
# main_module_help, so we can't keep incrementally extending
# the expected_help string ...
expected_help = ('\n%s:\n%s\n%s\n%s' %
- (sys.argv[0],
- main_module_int_fg_help,
- tmod_bar_z_help,
+ (sys.argv[0], main_module_int_fg_help, tmod_bar_z_help,
tmod_foo_bool_help))
self.assertMultiLineEqual(self.flag_values.main_module_help(),
expected_help)
@@ -2626,9 +2703,7 @@ class KeyFlagsTest(absltest.TestCase):
def test_adoptmodule_key_flags(self):
# Check that adopt_module_key_flags raises an exception when
# called with a module name (as opposed to a module object).
- self.assertRaises(flags.Error,
- flags.adopt_module_key_flags,
- 'pyglib.app')
+ self.assertRaises(flags.Error, flags.adopt_module_key_flags, 'pyglib.app')
def test_disclaimkey_flags(self):
original_disclaim_module_ids = _helpers.disclaim_module_ids
@@ -2646,36 +2721,42 @@ class FindModuleTest(absltest.TestCase):
"""Testing methods that find a module that defines a given flag."""
def test_find_module_defining_flag(self):
- self.assertEqual('default', FLAGS.find_module_defining_flag(
- '__NON_EXISTENT_FLAG__', 'default'))
self.assertEqual(
- module_baz.__name__, FLAGS.find_module_defining_flag('tmod_baz_x'))
+ 'default',
+ FLAGS.find_module_defining_flag('__NON_EXISTENT_FLAG__', 'default'))
+ self.assertEqual(module_baz.__name__,
+ FLAGS.find_module_defining_flag('tmod_baz_x'))
def test_find_module_id_defining_flag(self):
- self.assertEqual('default', FLAGS.find_module_id_defining_flag(
- '__NON_EXISTENT_FLAG__', 'default'))
+ self.assertEqual(
+ 'default',
+ FLAGS.find_module_id_defining_flag('__NON_EXISTENT_FLAG__', 'default'))
self.assertEqual(
id(module_baz), FLAGS.find_module_id_defining_flag('tmod_baz_x'))
def test_find_module_defining_flag_passing_module_name(self):
my_flags = flags.FlagValues()
module_name = sys.__name__ # Must use an existing module.
- flags.DEFINE_boolean('flag_name', True,
- 'Flag with a different module name.',
- flag_values=my_flags,
- module_name=module_name)
+ flags.DEFINE_boolean(
+ 'flag_name',
+ True,
+ 'Flag with a different module name.',
+ flag_values=my_flags,
+ module_name=module_name)
self.assertEqual(module_name,
my_flags.find_module_defining_flag('flag_name'))
def test_find_module_id_defining_flag_passing_module_name(self):
my_flags = flags.FlagValues()
module_name = sys.__name__ # Must use an existing module.
- flags.DEFINE_boolean('flag_name', True,
- 'Flag with a different module name.',
- flag_values=my_flags,
- module_name=module_name)
- self.assertEqual(id(sys),
- my_flags.find_module_id_defining_flag('flag_name'))
+ flags.DEFINE_boolean(
+ 'flag_name',
+ True,
+ 'Flag with a different module name.',
+ flag_values=my_flags,
+ module_name=module_name)
+ self.assertEqual(
+ id(sys), my_flags.find_module_id_defining_flag('flag_name'))
class FlagsErrorMessagesTest(absltest.TestCase):
@@ -2686,22 +2767,56 @@ class FlagsErrorMessagesTest(absltest.TestCase):
def test_integer_error_text(self):
# Make sure we get proper error text
- flags.DEFINE_integer('positive', 4, 'non-negative flag', lower_bound=1,
- flag_values=self.flag_values)
- flags.DEFINE_integer('non_negative', 4, 'positive flag', lower_bound=0,
- flag_values=self.flag_values)
- flags.DEFINE_integer('negative', -4, 'negative flag', upper_bound=-1,
- flag_values=self.flag_values)
- flags.DEFINE_integer('non_positive', -4, 'non-positive flag', upper_bound=0,
- flag_values=self.flag_values)
- flags.DEFINE_integer('greater', 19, 'greater-than flag', lower_bound=4,
- flag_values=self.flag_values)
- flags.DEFINE_integer('smaller', -19, 'smaller-than flag', upper_bound=4,
- flag_values=self.flag_values)
- flags.DEFINE_integer('usual', 4, 'usual flag', lower_bound=0,
- upper_bound=10000, flag_values=self.flag_values)
- flags.DEFINE_integer('another_usual', 0, 'usual flag', lower_bound=-1,
- upper_bound=1, flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'positive',
+ 4,
+ 'non-negative flag',
+ lower_bound=1,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'non_negative',
+ 4,
+ 'positive flag',
+ lower_bound=0,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'negative',
+ -4,
+ 'negative flag',
+ upper_bound=-1,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'non_positive',
+ -4,
+ 'non-positive flag',
+ upper_bound=0,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'greater',
+ 19,
+ 'greater-than flag',
+ lower_bound=4,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'smaller',
+ -19,
+ 'smaller-than flag',
+ upper_bound=4,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'usual',
+ 4,
+ 'usual flag',
+ lower_bound=0,
+ upper_bound=10000,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'another_usual',
+ 0,
+ 'usual flag',
+ lower_bound=-1,
+ upper_bound=1,
+ flag_values=self.flag_values)
self._check_error_message('positive', -4, 'a positive integer')
self._check_error_message('non_negative', -4, 'a non-negative integer')
@@ -2714,22 +2829,56 @@ class FlagsErrorMessagesTest(absltest.TestCase):
self._check_error_message('smaller', 5, 'integer <= 4')
def test_float_error_text(self):
- flags.DEFINE_float('positive', 4, 'non-negative flag', lower_bound=1,
- flag_values=self.flag_values)
- flags.DEFINE_float('non_negative', 4, 'positive flag', lower_bound=0,
- flag_values=self.flag_values)
- flags.DEFINE_float('negative', -4, 'negative flag', upper_bound=-1,
- flag_values=self.flag_values)
- flags.DEFINE_float('non_positive', -4, 'non-positive flag', upper_bound=0,
- flag_values=self.flag_values)
- flags.DEFINE_float('greater', 19, 'greater-than flag', lower_bound=4,
- flag_values=self.flag_values)
- flags.DEFINE_float('smaller', -19, 'smaller-than flag', upper_bound=4,
- flag_values=self.flag_values)
- flags.DEFINE_float('usual', 4, 'usual flag', lower_bound=0,
- upper_bound=10000, flag_values=self.flag_values)
- flags.DEFINE_float('another_usual', 0, 'usual flag', lower_bound=-1,
- upper_bound=1, flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'positive',
+ 4,
+ 'non-negative flag',
+ lower_bound=1,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'non_negative',
+ 4,
+ 'positive flag',
+ lower_bound=0,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'negative',
+ -4,
+ 'negative flag',
+ upper_bound=-1,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'non_positive',
+ -4,
+ 'non-positive flag',
+ upper_bound=0,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'greater',
+ 19,
+ 'greater-than flag',
+ lower_bound=4,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'smaller',
+ -19,
+ 'smaller-than flag',
+ upper_bound=4,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'usual',
+ 4,
+ 'usual flag',
+ lower_bound=0,
+ upper_bound=10000,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'another_usual',
+ 0,
+ 'usual flag',
+ lower_bound=-1,
+ upper_bound=1,
+ flag_values=self.flag_values)
self._check_error_message('positive', 0.5, 'number >= 1')
self._check_error_message('non_negative', -4.0, 'a non-negative number')
@@ -2748,9 +2897,11 @@ class FlagsErrorMessagesTest(absltest.TestCase):
self.flag_values.__setattr__(flag_name, flag_value)
raise AssertionError('Bounds exception not raised!')
except flags.IllegalFlagValueError as e:
- expected = ('flag --%(name)s=%(value)s: %(value)s is not %(suffix)s' %
- {'name': flag_name, 'value': flag_value,
- 'suffix': expected_message_suffix})
+ expected = ('flag --%(name)s=%(value)s: %(value)s is not %(suffix)s' % {
+ 'name': flag_name,
+ 'value': flag_value,
+ 'suffix': expected_message_suffix
+ })
self.assertEqual(str(e), expected)