diff options
author | Abseil Team <absl-team@google.com> | 2020-12-23 11:44:14 -0800 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2020-12-23 11:44:36 -0800 |
commit | b03026ac859ed10f6b380d49308213cd2c436e0f (patch) | |
tree | 4b47e8eaf213acd0ae47d304c6e5b64d0a575f1b | |
parent | bcff9304a1be4d92b7b251b8ffdb7dd011b9951e (diff) | |
download | absl-py-b03026ac859ed10f6b380d49308213cd2c436e0f.tar.gz |
Add a required argument to DEFINE_* methods.
Setting it to true, is functionally equivalent to calling `flags.mark_flag_as_required(flag_name)`, but changes the type of returned flagholder.
```
_A : FlagHolder[Optional[str]] = flags.DEFINE_string(
name='a', default=None, help='help')
flags.mark_flag_as_required('a')
```
v/s
```
_A : FlagHolder[str] = flags.DEFINE_string(
name='a', default=None, help='help', required=True)
```
PiperOrigin-RevId: 348825600
Change-Id: Ia7610af1b5c4649c20aba5cbb620eae1359ad592
-rw-r--r-- | absl/flags/__init__.pyi | 102 | ||||
-rw-r--r-- | absl/flags/_defines.py | 154 | ||||
-rw-r--r-- | absl/flags/_defines.pyi | 257 | ||||
-rw-r--r-- | absl/flags/tests/flags_test.py | 1147 |
4 files changed, 1139 insertions, 521 deletions
diff --git a/absl/flags/__init__.pyi b/absl/flags/__init__.pyi new file mode 100644 index 0000000..e016d04 --- /dev/null +++ b/absl/flags/__init__.pyi @@ -0,0 +1,102 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.flags import _argument_parser +from absl.flags import _defines +from absl.flags import _exceptions +from absl.flags import _flag +from absl.flags import _flagvalues +from absl.flags import _helpers +from absl.flags import _validators + +# DEFINE functions. They are explained in more details in the module doc string. +# pylint: disable=invalid-name +DEFINE = _defines.DEFINE +DEFINE_flag = _defines.DEFINE_flag +DEFINE_string = _defines.DEFINE_string +DEFINE_boolean = _defines.DEFINE_boolean +DEFINE_bool = DEFINE_boolean # Match C++ API. +DEFINE_float = _defines.DEFINE_float +DEFINE_integer = _defines.DEFINE_integer +DEFINE_enum = _defines.DEFINE_enum +DEFINE_enum_class = _defines.DEFINE_enum_class +DEFINE_list = _defines.DEFINE_list +DEFINE_spaceseplist = _defines.DEFINE_spaceseplist +DEFINE_multi = _defines.DEFINE_multi +DEFINE_multi_string = _defines.DEFINE_multi_string +DEFINE_multi_integer = _defines.DEFINE_multi_integer +DEFINE_multi_float = _defines.DEFINE_multi_float +DEFINE_multi_enum = _defines.DEFINE_multi_enum +DEFINE_multi_enum_class = _defines.DEFINE_multi_enum_class +DEFINE_alias = _defines.DEFINE_alias +# pylint: enable=invalid-name + +# Flag validators. +register_validator = _validators.register_validator +validator = _validators.validator +register_multi_flags_validator = _validators.register_multi_flags_validator +multi_flags_validator = _validators.multi_flags_validator +mark_flag_as_required = _validators.mark_flag_as_required +mark_flags_as_required = _validators.mark_flags_as_required +mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive +mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive + +# Key flag related functions. +declare_key_flag = _defines.declare_key_flag +adopt_module_key_flags = _defines.adopt_module_key_flags +disclaim_key_flags = _defines.disclaim_key_flags + +# Module exceptions. +# pylint: disable=invalid-name +Error = _exceptions.Error +CantOpenFlagFileError = _exceptions.CantOpenFlagFileError +DuplicateFlagError = _exceptions.DuplicateFlagError +IllegalFlagValueError = _exceptions.IllegalFlagValueError +UnrecognizedFlagError = _exceptions.UnrecognizedFlagError +UnparsedFlagAccessError = _exceptions.UnparsedFlagAccessError +ValidationError = _exceptions.ValidationError +FlagNameConflictsWithMethodError = _exceptions.FlagNameConflictsWithMethodError + +# Public classes. +Flag = _flag.Flag +BooleanFlag = _flag.BooleanFlag +EnumFlag = _flag.EnumFlag +EnumClassFlag = _flag.EnumClassFlag +MultiFlag = _flag.MultiFlag +MultiEnumClassFlag = _flag.MultiEnumClassFlag +FlagHolder = _flagvalues.FlagHolder +FlagValues = _flagvalues.FlagValues +ArgumentParser = _argument_parser.ArgumentParser +BooleanParser = _argument_parser.BooleanParser +EnumParser = _argument_parser.EnumParser +EnumClassParser = _argument_parser.EnumClassParser +ArgumentSerializer = _argument_parser.ArgumentSerializer +FloatParser = _argument_parser.FloatParser +IntegerParser = _argument_parser.IntegerParser +BaseListParser = _argument_parser.BaseListParser +ListParser = _argument_parser.ListParser +ListSerializer = _argument_parser.ListSerializer +CsvListSerializer = _argument_parser.CsvListSerializer +WhitespaceSeparatedListParser = _argument_parser.WhitespaceSeparatedListParser +# pylint: enable=invalid-name + +# Helper functions. +get_help_width = _helpers.get_help_width +text_wrap = _helpers.text_wrap +flag_dict_to_args = _helpers.flag_dict_to_args +doc_to_help = _helpers.doc_to_help + +# The global FlagValues instance. +FLAGS = _flagvalues.FLAGS + diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py index a8d5470..10abb76 100644 --- a/absl/flags/_defines.py +++ b/absl/flags/_defines.py @@ -74,6 +74,7 @@ def DEFINE( # pylint: disable=invalid-name flag_values=_flagvalues.FLAGS, serializer=None, module_name=None, + required=False, **args): """Registers a generic Flag object. @@ -93,6 +94,7 @@ def DEFINE( # pylint: disable=invalid-name serializer: ArgumentSerializer, the flag serializer instance. module_name: str, the name of the Python module declaring this flag. If not provided, it will be computed using the stack trace of this call. + required: bool, is this a required flag. **args: dict, the extra keyword args that are passed to Flag __init__. Returns: @@ -100,10 +102,14 @@ def DEFINE( # pylint: disable=invalid-name """ return DEFINE_flag( _flag.Flag(parser, serializer, name, default, help, **args), flag_values, - module_name) + module_name, required) -def DEFINE_flag(flag, flag_values=_flagvalues.FLAGS, module_name=None): # pylint: disable=invalid-name +def DEFINE_flag( # pylint: disable=invalid-name + flag, + flag_values=_flagvalues.FLAGS, + module_name=None, + required=False): """Registers a 'Flag' object with a 'FlagValues' object. By default, the global FLAGS 'FlagValue' object is used. @@ -119,10 +125,14 @@ def DEFINE_flag(flag, flag_values=_flagvalues.FLAGS, module_name=None): # pylin registered. This should almost never need to be overridden. module_name: str, the name of the Python module declaring this flag. If not provided, it will be computed using the stack trace of this call. + required: bool, is this a required flag. Returns: a handle to defined flag. """ + if required and flag.default is not None: + raise ValueError('Required flag --%s cannot have a non-None default' % + flag.name) # Copying the reference to flag_values prevents pychecker warnings. fv = flag_values fv[flag.name] = flag @@ -133,8 +143,11 @@ def DEFINE_flag(flag, flag_values=_flagvalues.FLAGS, module_name=None): # pylin module, module_name = _helpers.get_calling_module_object_and_name() flag_values.register_flag_by_module(module_name, flag) flag_values.register_flag_by_module_id(id(module), flag) + if required: + _validators.mark_flag_as_required(flag.name, fv) + ensure_non_none_value = (flag.default is not None) or required return _flagvalues.FlagHolder( - fv, flag, ensure_non_none_value=flag.default is not None) + fv, flag, ensure_non_none_value=ensure_non_none_value) def _internal_declare_key_flags(flag_names, @@ -263,11 +276,20 @@ def DEFINE_string( # pylint: disable=invalid-name,redefined-builtin default, help, flag_values=_flagvalues.FLAGS, + required=False, **args): """Registers a flag whose value can be any string.""" parser = _argument_parser.ArgumentParser() serializer = _argument_parser.ArgumentSerializer() - return DEFINE(parser, name, default, help, flag_values, serializer, **args) + return DEFINE( + parser, + name, + default, + help, + flag_values, + serializer, + required=required, + **args) def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin @@ -276,6 +298,7 @@ def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin help, flag_values=_flagvalues.FLAGS, module_name=None, + required=False, **args): """Registers a boolean flag. @@ -295,13 +318,15 @@ def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin registered. This should almost never need to be overridden. module_name: str, the name of the Python module declaring this flag. If not provided, it will be computed using the stack trace of this call. + required: bool, is this a required flag. **args: dict, the extra keyword args that are passed to Flag __init__. Returns: a handle to defined flag. """ return DEFINE_flag( - _flag.BooleanFlag(name, default, help, **args), flag_values, module_name) + _flag.BooleanFlag(name, default, help, **args), flag_values, module_name, + required) def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin @@ -311,6 +336,7 @@ def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin lower_bound=None, upper_bound=None, flag_values=_flagvalues.FLAGS, + required=False, **args): """Registers a flag whose value must be a float. @@ -325,6 +351,7 @@ def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin upper_bound: float, max value of the flag. flag_values: FlagValues, the FlagValues instance with which the flag will be registered. This should almost never need to be overridden. + required: bool, is this a required flag. **args: dict, the extra keyword args that are passed to DEFINE. Returns: @@ -332,7 +359,15 @@ def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin """ parser = _argument_parser.FloatParser(lower_bound, upper_bound) serializer = _argument_parser.ArgumentSerializer() - result = DEFINE(parser, name, default, help, flag_values, serializer, **args) + result = DEFINE( + parser, + name, + default, + help, + flag_values, + serializer, + required=required, + **args) _register_bounds_validator_if_needed(parser, name, flag_values=flag_values) return result @@ -344,6 +379,7 @@ def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin lower_bound=None, upper_bound=None, flag_values=_flagvalues.FLAGS, + required=False, **args): """Registers a flag whose value must be an integer. @@ -358,6 +394,7 @@ def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin upper_bound: int, max value of the flag. flag_values: FlagValues, the FlagValues instance with which the flag will be registered. This should almost never need to be overridden. + required: bool, is this a required flag. **args: dict, the extra keyword args that are passed to DEFINE. Returns: @@ -365,7 +402,15 @@ def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin """ parser = _argument_parser.IntegerParser(lower_bound, upper_bound) serializer = _argument_parser.ArgumentSerializer() - result = DEFINE(parser, name, default, help, flag_values, serializer, **args) + result = DEFINE( + parser, + name, + default, + help, + flag_values, + serializer, + required=required, + **args) _register_bounds_validator_if_needed(parser, name, flag_values=flag_values) return result @@ -377,6 +422,7 @@ def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin help, flag_values=_flagvalues.FLAGS, module_name=None, + required=False, **args): """Registers a flag whose value can be any string from enum_values. @@ -393,6 +439,7 @@ def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin registered. This should almost never need to be overridden. module_name: str, the name of the Python module declaring this flag. If not provided, it will be computed using the stack trace of this call. + required: bool, is this a required flag. **args: dict, the extra keyword args that are passed to Flag __init__. Returns: @@ -400,7 +447,7 @@ def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin """ return DEFINE_flag( _flag.EnumFlag(name, default, help, enum_values, **args), flag_values, - module_name) + module_name, required) def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin @@ -411,6 +458,7 @@ def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin flag_values=_flagvalues.FLAGS, module_name=None, case_sensitive=False, + required=False, **args): """Registers a flag whose value can be the name of enum members. @@ -425,6 +473,7 @@ def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin provided, it will be computed using the stack trace of this call. case_sensitive: bool, whether to map strings to members of the enum_class without considering case. + required: bool, is this a required flag. **args: dict, the extra keyword args that are passed to Flag __init__. Returns: @@ -437,7 +486,7 @@ def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin help, enum_class, case_sensitive=case_sensitive, - **args), flag_values, module_name) + **args), flag_values, module_name, required) def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin @@ -445,6 +494,7 @@ def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin default, help, flag_values=_flagvalues.FLAGS, + required=False, **args): """Registers a flag whose value is a comma-separated list of strings. @@ -456,6 +506,7 @@ def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin help: str, the help message. flag_values: FlagValues, the FlagValues instance with which the flag will be registered. This should almost never need to be overridden. + required: bool, is this a required flag. **args: Dictionary with extra keyword args that are passed to the Flag __init__. @@ -464,7 +515,15 @@ def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin """ parser = _argument_parser.ListParser() serializer = _argument_parser.CsvListSerializer(',') - return DEFINE(parser, name, default, help, flag_values, serializer, **args) + return DEFINE( + parser, + name, + default, + help, + flag_values, + serializer, + required=required, + **args) def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin @@ -473,6 +532,7 @@ def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin help, comma_compat=False, flag_values=_flagvalues.FLAGS, + required=False, **args): """Registers a flag whose value is a whitespace-separated list of strings. @@ -487,6 +547,7 @@ def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin backwards compatibility with flags that used to be comma-separated. flag_values: FlagValues, the FlagValues instance with which the flag will be registered. This should almost never need to be overridden. + required: bool, is this a required flag. **args: Dictionary with extra keyword args that are passed to the Flag __init__. @@ -496,7 +557,15 @@ def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin parser = _argument_parser.WhitespaceSeparatedListParser( comma_compat=comma_compat) serializer = _argument_parser.ListSerializer(' ') - return DEFINE(parser, name, default, help, flag_values, serializer, **args) + return DEFINE( + parser, + name, + default, + help, + flag_values, + serializer, + required=required, + **args) def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin @@ -507,6 +576,7 @@ def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin help, flag_values=_flagvalues.FLAGS, module_name=None, + required=False, **args): """Registers a generic MultiFlag that parses its args with a given parser. @@ -530,6 +600,7 @@ def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin registered. This should almost never need to be overridden. module_name: A string, the name of the Python module declaring this flag. If not provided, it will be computed using the stack trace of this call. + required: bool, is this a required flag. **args: Dictionary with extra keyword args that are passed to the Flag __init__. @@ -538,7 +609,7 @@ def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin """ return DEFINE_flag( _flag.MultiFlag(parser, serializer, name, default, help, **args), - flag_values, module_name) + flag_values, module_name, required) def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin @@ -546,6 +617,7 @@ def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin default, help, flag_values=_flagvalues.FLAGS, + required=False, **args): """Registers a flag whose value can be a list of any strings. @@ -562,6 +634,7 @@ def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin help: str, the help message. flag_values: FlagValues, the FlagValues instance with which the flag will be registered. This should almost never need to be overridden. + required: bool, is this a required flag. **args: Dictionary with extra keyword args that are passed to the Flag __init__. @@ -570,8 +643,15 @@ def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin """ parser = _argument_parser.ArgumentParser() serializer = _argument_parser.ArgumentSerializer() - return DEFINE_multi(parser, serializer, name, default, help, flag_values, - **args) + return DEFINE_multi( + parser, + serializer, + name, + default, + help, + flag_values, + required=required, + **args) def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin @@ -581,6 +661,7 @@ def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin lower_bound=None, upper_bound=None, flag_values=_flagvalues.FLAGS, + required=False, **args): """Registers a flag whose value can be a list of arbitrary integers. @@ -598,6 +679,7 @@ def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin upper_bound: int, max values of the flag. flag_values: FlagValues, the FlagValues instance with which the flag will be registered. This should almost never need to be overridden. + required: bool, is this a required flag. **args: Dictionary with extra keyword args that are passed to the Flag __init__. @@ -606,8 +688,15 @@ def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin """ parser = _argument_parser.IntegerParser(lower_bound, upper_bound) serializer = _argument_parser.ArgumentSerializer() - return DEFINE_multi(parser, serializer, name, default, help, flag_values, - **args) + return DEFINE_multi( + parser, + serializer, + name, + default, + help, + flag_values, + required=required, + **args) def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin @@ -617,6 +706,7 @@ def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin lower_bound=None, upper_bound=None, flag_values=_flagvalues.FLAGS, + required=False, **args): """Registers a flag whose value can be a list of arbitrary floats. @@ -634,6 +724,7 @@ def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin upper_bound: float, max values of the flag. flag_values: FlagValues, the FlagValues instance with which the flag will be registered. This should almost never need to be overridden. + required: bool, is this a required flag. **args: Dictionary with extra keyword args that are passed to the Flag __init__. @@ -642,8 +733,15 @@ def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin """ parser = _argument_parser.FloatParser(lower_bound, upper_bound) serializer = _argument_parser.ArgumentSerializer() - return DEFINE_multi(parser, serializer, name, default, help, flag_values, - **args) + return DEFINE_multi( + parser, + serializer, + name, + default, + help, + flag_values, + required=required, + **args) def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin @@ -653,6 +751,7 @@ def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin help, flag_values=_flagvalues.FLAGS, case_sensitive=True, + required=False, **args): """Registers a flag whose value can be a list strings from enum_values. @@ -671,6 +770,7 @@ def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin flag_values: FlagValues, the FlagValues instance with which the flag will be registered. This should almost never need to be overridden. case_sensitive: Whether or not the enum is to be case-sensitive. + required: bool, is this a required flag. **args: Dictionary with extra keyword args that are passed to the Flag __init__. @@ -679,8 +779,15 @@ def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin """ parser = _argument_parser.EnumParser(enum_values, case_sensitive) serializer = _argument_parser.ArgumentSerializer() - return DEFINE_multi(parser, serializer, name, default, help, flag_values, - **args) + return DEFINE_multi( + parser, + serializer, + name, + default, + help, + flag_values, + required=required, + **args) def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin @@ -691,6 +798,7 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin flag_values=_flagvalues.FLAGS, module_name=None, case_sensitive=False, + required=False, **args): """Registers a flag whose value can be a list of enum members. @@ -712,6 +820,7 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin not provided, it will be computed using the stack trace of this call. case_sensitive: bool, whether to map strings to members of the enum_class without considering case. + required: bool, is this a required flag. **args: Dictionary with extra keyword args that are passed to the Flag __init__. @@ -721,7 +830,10 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin return DEFINE_flag( _flag.MultiEnumClassFlag( name, default, help, enum_class, case_sensitive=case_sensitive), - flag_values, module_name, **args) + flag_values, + module_name, + required=required, + **args) def DEFINE_alias( # pylint: disable=invalid-name diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi index 5949f0a..2d8a201 100644 --- a/absl/flags/_defines.pyi +++ b/absl/flags/_defines.pyi @@ -20,12 +20,13 @@ from absl.flags import _flagvalues import enum -from typing import Text, List, Any, TypeVar, Optional, Union, Type, Iterable, overload +from typing import Text, List, Any, TypeVar, Optional, Union, Type, Iterable, overload, Literal _T = TypeVar('_T') _ET = TypeVar('_ET', bound=enum.Enum) +@overload def DEFINE( parser: _argument_parser.ArgumentParser[_T], name: Text, @@ -36,15 +37,88 @@ def DEFINE( flag_values : _flagvalues.FlagValues = ..., serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = None, module_name: Optional[Text] = None, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[_T]: + ... + + +@overload +def DEFINE( + parser: _argument_parser.ArgumentParser[_T], + name: Text, + default: Any, + help: Optional[Text], + # Explicitly replacing ... with _flagvalues.FLAGS causes pytype to + # not like the syntax. + flag_values : _flagvalues.FlagValues = ..., + serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = None, + module_name: Optional[Text] = None, + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[_T]]: ... +@overload +def DEFINE_flag( + flag: _flag.Flag[_T], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = None, + required: Literal[True]) -> _flagvalues.FlagHolder[_T]: + ... +@overload def DEFINE_flag( flag: _flag.Flag[_T], flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = None) -> _flagvalues.FlagHolder[Optional[_T]]: + module_name: Optional[Text] = None, + required: bool = False) -> _flagvalues.FlagHolder[Optional[_T]]: + ... + +# typing overloads for DEFINE_* methods... +# +# - DEFINE_* method return FlagHolder[Optional[T]] or FlagHolder[T] depending +# on the arguments. +# - If the flag value is guaranteed to be not None, the return type is +# FlagHolder[T]. +# - If the flag is required OR has a non-None default, the flag value i +# guaranteed to be not None after flag parsing has finished. +# The information above is captured with three overloads as follows. +# +# (if required=True, return type is FlagHolder[Y]) +# @overload +# def DEFINE_xxx( +# ... arguments... +# default : Union[None, X], +# required: Literal[True]) -> _flagvalues.FlagHolder[Y]: +# ... +# +# (if default=None, return type is FlagHolder[Optional[Y]]) +# @overload +# def DEFINE_xxx( +# ... arguments... +# default : None, +# required: bool = False) -> _flagvalues.FlagHolder[Optional[Y]]: +# ... +# +# (if default!=None, return type is FlagHolder[Y])# +# @overload +# def DEFINE_xxx( +# ... arguments... +# default: X, +# required: bool = False) -> _flagvalues.FlagHolder[Y]: +# ... +# +# where X = type of non-None default values for the flag +# and Y = non-None type for flag value + +@overload +def DEFINE_string( + name: Text, + default: Optional[Text], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[Text]: ... @overload @@ -53,6 +127,7 @@ def DEFINE_string( default: None, help: Optional[Text], flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[Text]]: ... @@ -62,16 +137,29 @@ def DEFINE_string( default: Text, help: Optional[Text], flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Text]: ... @overload def DEFINE_boolean( name : Text, + default: Union[None, Text, bool, int], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = None, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[bool]: + ... + +@overload +def DEFINE_boolean( + name : Text, default: None, help: Optional[Text], flag_values: _flagvalues.FlagValues = ..., module_name: Optional[Text] = None, + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[bool]]: ... @@ -82,17 +170,31 @@ def DEFINE_boolean( help: Optional[Text], flag_values: _flagvalues.FlagValues = ..., module_name: Optional[Text] = None, + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[bool]: ... @overload def DEFINE_float( name: Text, + default: Union[None, float, Text], + help: Optional[Text], + lower_bound: Optional[float] = None, + upper_bound: Optional[float] = None, + flag_values: _flagvalues.FlagValues = ..., + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[float]: + ... + +@overload +def DEFINE_float( + name: Text, default: None, help: Optional[Text], lower_bound: Optional[float] = None, upper_bound: Optional[float] = None, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[float]]: ... @@ -104,6 +206,7 @@ def DEFINE_float( lower_bound: Optional[float] = None, upper_bound: Optional[float] = None, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[float]: ... @@ -111,11 +214,24 @@ def DEFINE_float( @overload def DEFINE_integer( name: Text, + default: Union[None, int, Text], + help: Optional[Text], + lower_bound: Optional[int] = None, + upper_bound: Optional[int] = None, + flag_values: _flagvalues.FlagValues = ..., + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[int]: + ... + +@overload +def DEFINE_integer( + name: Text, default: None, help: Optional[Text], lower_bound: Optional[int] = None, upper_bound: Optional[int] = None, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[int]]: ... @@ -127,17 +243,31 @@ def DEFINE_integer( lower_bound: Optional[int] = None, upper_bound: Optional[int] = None, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[int]: ... @overload def DEFINE_enum( name : Text, + default: Optional[Text], + enum_values: Iterable[Text], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = None, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[Text]: + ... + +@overload +def DEFINE_enum( + name : Text, default: None, enum_values: Iterable[Text], help: Optional[Text], flag_values: _flagvalues.FlagValues = ..., module_name: Optional[Text] = None, + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[Text]]: ... @@ -149,18 +279,33 @@ def DEFINE_enum( help: Optional[Text], flag_values: _flagvalues.FlagValues = ..., module_name: Optional[Text] = None, + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Text]: ... @overload def DEFINE_enum_class( name: Text, + default: Union[None, _ET, Text], + enum_class: Type[_ET], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = None, + case_sensitive: bool = ..., + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[_ET]: + ... + +@overload +def DEFINE_enum_class( + name: Text, default: None, enum_class: Type[_ET], help: Optional[Text], flag_values: _flagvalues.FlagValues = ..., module_name: Optional[Text] = None, case_sensitive: bool = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[_ET]]: ... @@ -173,6 +318,7 @@ def DEFINE_enum_class( flag_values: _flagvalues.FlagValues = ..., module_name: Optional[Text] = None, case_sensitive: bool = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[_ET]: ... @@ -180,9 +326,20 @@ def DEFINE_enum_class( @overload def DEFINE_list( name: Text, + default: Union[None, Iterable[Text], Text], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_list( + name: Text, default: None, help: Text, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: ... @@ -192,6 +349,18 @@ def DEFINE_list( default: Union[Iterable[Text], Text], help: Text, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_spaceseplist( + name: Text, + default: Union[None, Iterable[Text], Text], + help: Text, + comma_compat: bool = False, + flag_values: _flagvalues.FlagValues = ..., + required: Literal[True], **args: Any) -> _flagvalues.FlagHolder[List[Text]]: ... @@ -202,6 +371,7 @@ def DEFINE_spaceseplist( help: Text, comma_compat: bool = False, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: ... @@ -212,6 +382,7 @@ def DEFINE_spaceseplist( help: Text, comma_compat: bool = False, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[List[Text]]: ... @@ -220,10 +391,24 @@ def DEFINE_multi( parser : _argument_parser.ArgumentParser[_T], serializer: _argument_parser.ArgumentSerializer[_T], name: Text, + default: Union[None, Iterable[_T], _T, Text], + help: Text, + flag_values:_flagvalues.FlagValues = ..., + module_name: Optional[Text] = None, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[_T]]: + ... + +@overload +def DEFINE_multi( + parser : _argument_parser.ArgumentParser[_T], + serializer: _argument_parser.ArgumentSerializer[_T], + name: Text, default: None, help: Text, flag_values:_flagvalues.FlagValues = ..., module_name: Optional[Text] = None, + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[List[_T]]]: ... @@ -236,15 +421,27 @@ def DEFINE_multi( help: Text, flag_values:_flagvalues.FlagValues = ..., module_name: Optional[Text] = None, + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[List[_T]]: ... @overload def DEFINE_multi_string( name: Text, + default: Union[None, Iterable[Text], Text], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_multi_string( + name: Text, default: None, help: Text, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: ... @@ -254,17 +451,31 @@ def DEFINE_multi_string( default: Union[Iterable[Text], Text], help: Text, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[List[Text]]: ... @overload def DEFINE_multi_integer( name: Text, + default: Union[None, Iterable[int], int, Text], + help: Text, + lower_bound: Optional[int] = None, + upper_bound: Optional[int] = None, + flag_values: _flagvalues.FlagValues = ..., + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[int]]: + ... + +@overload +def DEFINE_multi_integer( + name: Text, default: None, help: Text, lower_bound: Optional[int] = None, upper_bound: Optional[int] = None, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[List[int]]]: ... @@ -276,17 +487,31 @@ def DEFINE_multi_integer( lower_bound: Optional[int] = None, upper_bound: Optional[int] = None, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[List[int]]: ... @overload def DEFINE_multi_float( name: Text, + default: Union[None, Iterable[float], float, Text], + help: Text, + lower_bound: Optional[float] = None, + upper_bound: Optional[float] = None, + flag_values: _flagvalues.FlagValues = ..., + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[float]]: + ... + +@overload +def DEFINE_multi_float( + name: Text, default: None, help: Text, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[List[float]]]: ... @@ -298,6 +523,7 @@ def DEFINE_multi_float( lower_bound: Optional[float] = None, upper_bound: Optional[float] = None, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[List[float]]: ... @@ -305,10 +531,22 @@ def DEFINE_multi_float( @overload def DEFINE_multi_enum( name: Text, + default: Union[None, Iterable[Text], Text], + enum_values: Iterable[Text], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_multi_enum( + name: Text, default: None, enum_values: Iterable[Text], help: Text, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: ... @@ -319,17 +557,31 @@ def DEFINE_multi_enum( enum_values: Iterable[Text], help: Text, flag_values: _flagvalues.FlagValues = ..., + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[List[Text]]: ... @overload def DEFINE_multi_enum_class( name: Text, + default: Union[None, Iterable[_ET], _ET, Text], + enum_class: Type[_ET], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = None, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[_ET]]: + ... + +@overload +def DEFINE_multi_enum_class( + name: Text, default: None, enum_class: Type[_ET], help: Text, flag_values: _flagvalues.FlagValues = ..., module_name: Optional[Text] = None, + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[Optional[List[_ET]]]: ... @@ -341,6 +593,7 @@ def DEFINE_multi_enum_class( help: Text, flag_values: _flagvalues.FlagValues = ..., module_name: Optional[Text] = None, + required: bool = False, **args: Any) -> _flagvalues.FlagHolder[List[_ET]]: ... diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py index 8608f2f..2d63c40 100644 --- a/absl/flags/tests/flags_test.py +++ b/absl/flags/tests/flags_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tests for absl.flags used as a package.""" from __future__ import absolute_import @@ -36,7 +35,6 @@ from absl.flags.tests import module_foo from absl.testing import absltest import six - FLAGS = flags.FLAGS @@ -61,10 +59,8 @@ class FlagDictToArgsTest(absltest.TestCase): 'loadthatstuff': [42, 'hello', 'goodbye'], } self.assertSameElements( - ( - '--week-end', '--noestudia', '--notrabaja', - '--party', '--monday=party', '--score=42', - '--loadthatstuff=42,hello,goodbye'), + ('--week-end', '--noestudia', '--notrabaja', '--party', + '--monday=party', '--score=42', '--loadthatstuff=42,hello,goodbye'), flags.flag_dict_to_args(arg_dict)) def test_flatten_google_flag_map_with_multi_flag(self): @@ -73,9 +69,8 @@ class FlagDictToArgsTest(absltest.TestCase): 'some_multi_string': ['value3', 'value4'], } self.assertSameElements( - ( - '--some_list=value1,value2', '--some_multi_string=value3', - '--some_multi_string=value4'), + ('--some_list=value1,value2', '--some_multi_string=value3', + '--some_multi_string=value4'), flags.flag_dict_to_args(arg_dict, multi_flags={'some_multi_string'})) @@ -162,8 +157,7 @@ class AliasFlagsTest(absltest.TestCase): self.assertEqual('--alias=[0, 1]', actual) def test_allow_overwrite_false(self): - self.define_integer('aliased', None, 'help', - allow_overwrite=False) + self.define_integer('aliased', None, 'help', allow_overwrite=False) self.define_alias('alias', 'aliased') with self.assertRaisesRegex(flags.IllegalFlagValueError, 'already defined'): @@ -173,6 +167,7 @@ class AliasFlagsTest(absltest.TestCase): self.assertEqual(1, self.aliased.value) def test_aliasing_multi_no_default(self): + def define_flags(): self.flags = flags.FlagValues() self.define_multi_integer('aliased', None, 'help') @@ -202,6 +197,7 @@ class AliasFlagsTest(absltest.TestCase): self.assert_alias_mirrors_aliased(self.alias, self.aliased) def test_aliasing_multi_with_default(self): + def define_flags(): self.flags = flags.FlagValues() self.define_multi_integer('aliased', [0], 'help') @@ -243,6 +239,7 @@ class AliasFlagsTest(absltest.TestCase): self.assertEqual(0, self.aliased.present) def test_aliasing_regular(self): + def define_flags(): self.flags = flags.FlagValues() self.define_string('aliased', '', 'help') @@ -296,8 +293,8 @@ class FlagsUnitTest(absltest.TestCase): # Define flags number_test_framework_flags = len(FLAGS) repeat_help = 'how many times to repeat (0-5)' - flags.DEFINE_integer('repeat', 4, repeat_help, - lower_bound=0, short_name='r') + flags.DEFINE_integer( + 'repeat', 4, repeat_help, lower_bound=0, short_name='r') flags.DEFINE_string('name', 'Bob', 'namehelp') flags.DEFINE_boolean('debug', 0, 'debughelp') flags.DEFINE_boolean('q', 1, 'quiet mode') @@ -314,24 +311,35 @@ class FlagsUnitTest(absltest.TestCase): flags.DEFINE_list('numbers', [1, 2, 3], 'a list of numbers') flags.DEFINE_enum('kwery', None, ['who', 'what', 'Why', 'where', 'when'], '?') - flags.DEFINE_enum('sense', None, ['Case', 'case', 'CASE'], - '?', case_sensitive=True) - flags.DEFINE_enum('cases', None, ['UPPER', 'lower', 'Initial', 'Ot_HeR'], - '?', case_sensitive=False) - flags.DEFINE_enum('funny', None, ['Joke', 'ha', 'ha', 'ha', 'ha'], - '?', case_sensitive=True) - flags.DEFINE_enum('blah', None, ['bla', 'Blah', 'BLAH', 'blah'], - '?', case_sensitive=False) - flags.DEFINE_string('only_once', None, 'test only sets this once', - allow_overwrite=False) - flags.DEFINE_string('universe', None, 'test tries to set this three times', - allow_overwrite=False) + flags.DEFINE_enum( + 'sense', None, ['Case', 'case', 'CASE'], '?', case_sensitive=True) + flags.DEFINE_enum( + 'cases', + None, ['UPPER', 'lower', 'Initial', 'Ot_HeR'], + '?', + case_sensitive=False) + flags.DEFINE_enum( + 'funny', + None, ['Joke', 'ha', 'ha', 'ha', 'ha'], + '?', + case_sensitive=True) + flags.DEFINE_enum( + 'blah', + None, ['bla', 'Blah', 'BLAH', 'blah'], + '?', + case_sensitive=False) + flags.DEFINE_string( + 'only_once', None, 'test only sets this once', allow_overwrite=False) + flags.DEFINE_string( + 'universe', + None, + 'test tries to set this three times', + allow_overwrite=False) # Specify number of flags defined above. The short_name defined # for 'repeat' counts as an extra flag. number_defined_flags = 22 + 1 - self.assertEqual(len(FLAGS), - number_defined_flags + number_test_framework_flags) + self.assertLen(FLAGS, number_defined_flags + number_test_framework_flags) self.assertEqual(FLAGS.repeat, 4) self.assertEqual(FLAGS.name, 'Bob') @@ -393,13 +401,13 @@ class FlagsUnitTest(absltest.TestCase): # .. empty command line argv = ('./program',) argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') # .. non-empty command line argv = ('./program', '--debug', '--name=Bob', '-q', '--x=8') argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(FLAGS['debug'].present, 1) FLAGS['debug'].present = 0 # Reset @@ -411,8 +419,7 @@ class FlagsUnitTest(absltest.TestCase): FLAGS['x'].present = 0 # Reset # Flags list. - self.assertEqual(len(FLAGS), - number_defined_flags + number_test_framework_flags) + self.assertLen(FLAGS, number_defined_flags + number_test_framework_flags) self.assertIn('name', FLAGS) self.assertIn('debug', FLAGS) self.assertIn('repeat', FLAGS) @@ -431,14 +438,14 @@ class FlagsUnitTest(absltest.TestCase): # try deleting a flag del FLAGS.r - self.assertEqual(len(FLAGS), - number_defined_flags - 1 + number_test_framework_flags) + self.assertLen(FLAGS, + number_defined_flags - 1 + number_test_framework_flags) self.assertNotIn('r', FLAGS) # .. command line with extra stuff argv = ('./program', '--debug', '--name=Bob', 'extra') argv = FLAGS(argv) - self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') + self.assertLen(argv, 2, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(argv[1], 'extra', 'extra argument not preserved') self.assertEqual(FLAGS['debug'].present, 1) @@ -449,7 +456,7 @@ class FlagsUnitTest(absltest.TestCase): # Test reset argv = ('./program', '--debug') argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(FLAGS['debug'].present, 1) self.assertTrue(FLAGS['debug'].value) @@ -460,14 +467,14 @@ class FlagsUnitTest(absltest.TestCase): # Test that reset restores default value when default value is None. argv = ('./program', '--kwery=who') argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(FLAGS['kwery'].present, 1) self.assertEqual(FLAGS['kwery'].value, 'who') FLAGS.unparse_flags() argv = ('./program', '--kwery=Why') argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(FLAGS['kwery'].present, 1) self.assertEqual(FLAGS['kwery'].value, 'Why') @@ -478,14 +485,14 @@ class FlagsUnitTest(absltest.TestCase): # Test case sensitive enum. argv = ('./program', '--sense=CASE') argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(FLAGS['sense'].present, 1) self.assertEqual(FLAGS['sense'].value, 'CASE') FLAGS.unparse_flags() argv = ('./program', '--sense=Case') argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(FLAGS['sense'].present, 1) self.assertEqual(FLAGS['sense'].value, 'Case') @@ -494,7 +501,7 @@ class FlagsUnitTest(absltest.TestCase): # Test case insensitive enum. argv = ('./program', '--cases=upper') argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(FLAGS['cases'].present, 1) self.assertEqual(FLAGS['cases'].value, 'UPPER') @@ -503,7 +510,7 @@ class FlagsUnitTest(absltest.TestCase): # Test case sensitive enum with duplicates. argv = ('./program', '--funny=ha') argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(FLAGS['funny'].present, 1) self.assertEqual(FLAGS['funny'].value, 'ha') @@ -512,14 +519,14 @@ class FlagsUnitTest(absltest.TestCase): # Test case insensitive enum with duplicates. argv = ('./program', '--blah=bLah') argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(FLAGS['blah'].present, 1) self.assertEqual(FLAGS['blah'].value, 'Blah') FLAGS.unparse_flags() argv = ('./program', '--blah=BLAH') argv = FLAGS(argv) - self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertLen(argv, 1, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(FLAGS['blah'].present, 1) self.assertEqual(FLAGS['blah'].value, 'Blah') @@ -617,22 +624,21 @@ class FlagsUnitTest(absltest.TestCase): self.assertEqual(FLAGS.get_flag_value('testget4', 'foo'), 'foo') # test list code - lists = [['hello', 'moo', 'boo', '1'], - []] + lists = [['hello', 'moo', 'boo', '1'], []] flags.DEFINE_list('testcomma_list', '', 'test comma list parsing') flags.DEFINE_spaceseplist('testspace_list', '', 'tests space list parsing') flags.DEFINE_spaceseplist( - 'testspace_or_comma_list', '', - 'tests space list parsing with comma compatibility', comma_compat=True) - - for name, sep in ( - ('testcomma_list', ','), - ('testspace_list', ' '), - ('testspace_list', '\n'), - ('testspace_or_comma_list', ' '), - ('testspace_or_comma_list', '\n'), - ('testspace_or_comma_list', ',')): + 'testspace_or_comma_list', + '', + 'tests space list parsing with comma compatibility', + comma_compat=True) + + for name, sep in (('testcomma_list', ','), ('testspace_list', + ' '), ('testspace_list', '\n'), + ('testspace_or_comma_list', + ' '), ('testspace_or_comma_list', + '\n'), ('testspace_or_comma_list', ',')): for lst in lists: argv = ('./program', '--%s=%s' % (name, sep.join(lst))) argv = FLAGS(argv) @@ -640,10 +646,10 @@ class FlagsUnitTest(absltest.TestCase): # Test help text flags_help = str(FLAGS) - self.assertNotEqual(flags_help.find('repeat'), -1, - 'cannot find flag in help') - self.assertNotEqual(flags_help.find(repeat_help), -1, - 'cannot find help string in help') + self.assertNotEqual( + flags_help.find('repeat'), -1, 'cannot find flag in help') + self.assertNotEqual( + flags_help.find(repeat_help), -1, 'cannot find help string in help') # Test flag specified twice argv = ('./program', '--repeat=4', '--repeat=2', '--debug', '--nodebug') @@ -652,16 +658,20 @@ class FlagsUnitTest(absltest.TestCase): self.assertEqual(FLAGS.get_flag_value('debug', None), 0) # Test MultiFlag with single default value - flags.DEFINE_multi_string('s_str', 'sing1', - 'string option that can occur multiple times', - short_name='s') + flags.DEFINE_multi_string( + 's_str', + 'sing1', + 'string option that can occur multiple times', + short_name='s') self.assertEqual(FLAGS.get_flag_value('s_str', None), ['sing1']) # Test MultiFlag with list of default values multi_string_defs = ['def1', 'def2'] - flags.DEFINE_multi_string('m_str', multi_string_defs, - 'string option that can occur multiple times', - short_name='m') + flags.DEFINE_multi_string( + 'm_str', + multi_string_defs, + 'string option that can occur multiple times', + short_name='m') self.assertEqual(FLAGS.get_flag_value('m_str', None), multi_string_defs) # Test flag specified multiple times with a MultiFlag @@ -677,12 +687,11 @@ class FlagsUnitTest(absltest.TestCase): # A flag with allow_overwrite set to False should complain when it is # specified more than once - argv = ('./program', '--universe=ptolemaic', - '--universe=copernicean', '--universe=euclidean') + argv = ('./program', '--universe=ptolemaic', '--universe=copernicean', + '--universe=euclidean') self.assertRaisesWithLiteralMatch( flags.IllegalFlagValueError, - 'flag --universe=copernicean: already defined as ptolemaic', - FLAGS, + 'flag --universe=copernicean: already defined as ptolemaic', FLAGS, argv) # Test single-letter flags; should support both single and double dash @@ -709,9 +718,7 @@ class FlagsUnitTest(absltest.TestCase): old_testspace_list = FLAGS.testspace_list old_testspace_or_comma_list = FLAGS.testspace_or_comma_list - argv = ('./program', - FLAGS['test0'].serialize(), - FLAGS['test1'].serialize(), + argv = ('./program', FLAGS['test0'].serialize(), FLAGS['test1'].serialize(), FLAGS['s_str'].serialize()) argv = FLAGS(argv) @@ -727,8 +734,7 @@ class FlagsUnitTest(absltest.TestCase): FLAGS.testcomma_list = list(testcomma_list1) FLAGS.testspace_list = list(testspace_list1) FLAGS.testspace_or_comma_list = list(testspace_or_comma_list1) - argv = ('./program', - FLAGS['testcomma_list'].serialize(), + argv = ('./program', FLAGS['testcomma_list'].serialize(), FLAGS['testspace_list'].serialize(), FLAGS['testspace_or_comma_list'].serialize()) argv = FLAGS(argv) @@ -742,8 +748,7 @@ class FlagsUnitTest(absltest.TestCase): FLAGS.testcomma_list = list(testcomma_list1) FLAGS.testspace_list = list(testspace_list1) FLAGS.testspace_or_comma_list = list(testspace_or_comma_list1) - argv = ('./program', - FLAGS['testcomma_list'].serialize(), + argv = ('./program', FLAGS['testcomma_list'].serialize(), FLAGS['testspace_list'].serialize(), FLAGS['testspace_or_comma_list'].serialize()) argv = FLAGS(argv) @@ -767,18 +772,17 @@ class FlagsUnitTest(absltest.TestCase): flags_to_exclude = {'log_dir', 'test_srcdir', 'test_tmpdir'} flagnames = set(FLAGS) - flags_to_exclude - nonbool_flags = ['--%s %s' % (name, FLAGS.get_flag_value(name, None)) - for name in flagnames - if not isinstance(FLAGS[name], flags.BooleanFlag)] - - truebool_flags = ['--%s' % (name) - for name in flagnames - if isinstance(FLAGS[name], flags.BooleanFlag) and - FLAGS.get_flag_value(name, None)] - falsebool_flags = ['--no%s' % (name) - for name in flagnames - if isinstance(FLAGS[name], flags.BooleanFlag) and - not FLAGS.get_flag_value(name, None)] + nonbool_flags = [] + truebool_flags = [] + falsebool_flags = [] + for name in flagnames: + flag_value = FLAGS.get_flag_value(name, None) + if not isinstance(FLAGS[name], flags.BooleanFlag): + nonbool_flags.append('--%s %s' % (name, flag_value)) + elif flag_value: + truebool_flags.append('--%s' % name) + else: + falsebool_flags.append('--no%s' % name) all_flags = nonbool_flags + truebool_flags + falsebool_flags all_flags.sort() return all_flags @@ -953,8 +957,11 @@ class FlagsUnitTest(absltest.TestCase): # correctly. flagnames = ['repeated'] original_flags = flags.FlagValues() - flags.DEFINE_boolean(flagnames[0], False, 'Flag about to be repeated.', - flag_values=original_flags) + flags.DEFINE_boolean( + flagnames[0], + False, + 'Flag about to be repeated.', + flag_values=original_flags) duplicate_flags = module_foo.duplicate_flags(flagnames) with self.assertRaisesRegex(flags.DuplicateFlagError, 'flags_test.*module_foo'): @@ -962,13 +969,13 @@ class FlagsUnitTest(absltest.TestCase): # Make sure allow_override works try: - flags.DEFINE_boolean('dup1', 0, 'runhelp d11', short_name='u', - allow_override=0) + flags.DEFINE_boolean( + 'dup1', 0, 'runhelp d11', short_name='u', allow_override=0) flag = FLAGS._flags()['dup1'] self.assertEqual(flag.default, 0) - flags.DEFINE_boolean('dup1', 1, 'runhelp d12', short_name='u', - allow_override=1) + flags.DEFINE_boolean( + 'dup1', 1, 'runhelp d12', short_name='u', allow_override=1) flag = FLAGS._flags()['dup1'] self.assertEqual(flag.default, 1) except flags.DuplicateFlagError: @@ -976,13 +983,13 @@ class FlagsUnitTest(absltest.TestCase): # Make sure allow_override works try: - flags.DEFINE_boolean('dup2', 0, 'runhelp d21', short_name='u', - allow_override=1) + flags.DEFINE_boolean( + 'dup2', 0, 'runhelp d21', short_name='u', allow_override=1) flag = FLAGS._flags()['dup2'] self.assertEqual(flag.default, 0) - flags.DEFINE_boolean('dup2', 1, 'runhelp d22', short_name='u', - allow_override=0) + flags.DEFINE_boolean( + 'dup2', 1, 'runhelp d22', short_name='u', allow_override=0) flag = FLAGS._flags()['dup2'] self.assertEqual(flag.default, 1) except flags.DuplicateFlagError: @@ -991,18 +998,17 @@ class FlagsUnitTest(absltest.TestCase): # Make sure that re-importing a module does not cause a DuplicateFlagError # to be raised. try: - sys.modules.pop( - 'absl.flags.tests.module_baz') + sys.modules.pop('absl.flags.tests.module_baz') import absl.flags.tests.module_baz del absl except flags.DuplicateFlagError: raise AssertionError('Module reimport caused flag duplication error') # Make sure that when we override, the help string gets updated correctly - flags.DEFINE_boolean('dup3', 0, 'runhelp d31', short_name='u', - allow_override=1) - flags.DEFINE_boolean('dup3', 1, 'runhelp d32', short_name='u', - allow_override=1) + flags.DEFINE_boolean( + 'dup3', 0, 'runhelp d31', short_name='u', allow_override=1) + flags.DEFINE_boolean( + 'dup3', 1, 'runhelp d32', short_name='u', allow_override=1) self.assertEqual(str(FLAGS).find('runhelp d31'), -1) self.assertNotEqual(str(FLAGS).find('runhelp d32'), -1) @@ -1026,8 +1032,8 @@ class FlagsUnitTest(absltest.TestCase): # Make sure append_flag_values works with flags with shortnames. new_flags = flags.FlagValues() flags.DEFINE_boolean('new3', 0, 'runhelp n3', flag_values=new_flags) - flags.DEFINE_boolean('new4', 0, 'runhelp n4', flag_values=new_flags, - short_name='n4') + flags.DEFINE_boolean( + 'new4', 0, 'runhelp n4', flag_values=new_flags, short_name='n4') self.assertEqual(len(new_flags._flags()), 3) old_len = len(FLAGS._flags()) FLAGS.append_flag_values(new_flags) @@ -1101,18 +1107,12 @@ class FlagsUnitTest(absltest.TestCase): flags.DEFINE_alias('alias_letters', 'letters') self.assertEqual(FLAGS['name'].default, FLAGS.alias_name) self.assertEqual(FLAGS['debug'].default, FLAGS.alias_debug) - self.assertEqual( - int(FLAGS['decimal'].default), FLAGS.alias_decimal) - self.assertEqual( - float(FLAGS['float'].default), FLAGS.alias_float) - self.assertSameElements( - FLAGS['letters'].default, FLAGS.alias_letters) + self.assertEqual(int(FLAGS['decimal'].default), FLAGS.alias_decimal) + self.assertEqual(float(FLAGS['float'].default), FLAGS.alias_float) + self.assertSameElements(FLAGS['letters'].default, FLAGS.alias_letters) # Original flags set on comand line - argv = ('./program', - '--name=Martin', - '--debug=True', - '--decimal=777', + argv = ('./program', '--name=Martin', '--debug=True', '--decimal=777', '--letters=x,y,z') FLAGS(argv) self.assertEqual('Martin', FLAGS.name) @@ -1125,11 +1125,8 @@ class FlagsUnitTest(absltest.TestCase): self.assertSameElements(['x', 'y', 'z'], FLAGS.alias_letters) # Alias flags set on command line - argv = ('./program', - '--alias_name=Auston', - '--alias_debug=False', - '--alias_decimal=888', - '--alias_letters=l,m,n') + argv = ('./program', '--alias_name=Auston', '--alias_debug=False', + '--alias_decimal=888', '--alias_letters=l,m,n') FLAGS(argv) self.assertEqual('Auston', FLAGS.name) self.assertEqual('Auston', FLAGS.alias_name) @@ -1142,36 +1139,36 @@ class FlagsUnitTest(absltest.TestCase): # Make sure importing a module does not change flag value parsed # from commandline. - flags.DEFINE_integer('dup5', 1, 'runhelp d51', short_name='u5', - allow_override=0) + flags.DEFINE_integer( + 'dup5', 1, 'runhelp d51', short_name='u5', allow_override=0) self.assertEqual(FLAGS.dup5, 1) self.assertEqual(FLAGS.dup5, 1) argv = ('./program', '--dup5=3') FLAGS(argv) self.assertEqual(FLAGS.dup5, 3) - flags.DEFINE_integer('dup5', 2, 'runhelp d52', short_name='u5', - allow_override=1) + flags.DEFINE_integer( + 'dup5', 2, 'runhelp d52', short_name='u5', allow_override=1) self.assertEqual(FLAGS.dup5, 3) # Make sure importing a module does not change user defined flag value. - flags.DEFINE_integer('dup6', 1, 'runhelp d61', short_name='u6', - allow_override=0) + flags.DEFINE_integer( + 'dup6', 1, 'runhelp d61', short_name='u6', allow_override=0) self.assertEqual(FLAGS.dup6, 1) FLAGS.dup6 = 3 self.assertEqual(FLAGS.dup6, 3) - flags.DEFINE_integer('dup6', 2, 'runhelp d62', short_name='u6', - allow_override=1) + flags.DEFINE_integer( + 'dup6', 2, 'runhelp d62', short_name='u6', allow_override=1) self.assertEqual(FLAGS.dup6, 3) # Make sure importing a module does not change user defined flag value # even if it is the 'default' value. - flags.DEFINE_integer('dup7', 1, 'runhelp d71', short_name='u7', - allow_override=0) + flags.DEFINE_integer( + 'dup7', 1, 'runhelp d71', short_name='u7', allow_override=0) self.assertEqual(FLAGS.dup7, 1) FLAGS.dup7 = 1 self.assertEqual(FLAGS.dup7, 1) - flags.DEFINE_integer('dup7', 2, 'runhelp d72', short_name='u7', - allow_override=1) + flags.DEFINE_integer( + 'dup7', 2, 'runhelp d72', short_name='u7', allow_override=1) self.assertEqual(FLAGS.dup7, 1) # Test module_help(). @@ -1360,55 +1357,80 @@ class FlagsUnitTest(absltest.TestCase): def test_enum_class_flag_with_wrong_default_value_type(self): fv = flags.FlagValues() with self.assertRaises(_exceptions.IllegalFlagValueError): - flags.DEFINE_enum_class('fruit', 1, Fruit, 'help', - flag_values=fv) + flags.DEFINE_enum_class('fruit', 1, Fruit, 'help', flag_values=fv) def test_enum_class_flag_requires_enum_class(self): fv = flags.FlagValues() with self.assertRaises(TypeError): - flags.DEFINE_enum_class('fruit', None, ['apple', 'orange'], 'help', - flag_values=fv) + flags.DEFINE_enum_class( + 'fruit', None, ['apple', 'orange'], 'help', flag_values=fv) def test_enum_class_flag_requires_non_empty_enum_class(self): fv = flags.FlagValues() with self.assertRaises(ValueError): - flags.DEFINE_enum_class('empty', None, EmptyEnum, 'help', - flag_values=fv) + flags.DEFINE_enum_class('empty', None, EmptyEnum, 'help', flag_values=fv) + + def test_required_flag(self): + fv = flags.FlagValues() + fl = flags.DEFINE_integer( + name='int_flag', + default=None, + help='help', + required=True, + flag_values=fv) + # Since the flag is required, the FlagHolder should ensure value returned + # is not None. + self.assertTrue(fl._ensure_non_none_value) + + def test_illegal_required_flag(self): + fv = flags.FlagValues() + with self.assertRaises(ValueError): + flags.DEFINE_integer( + name='int_flag', + default=3, + help='help', + required=True, + flag_values=fv) class MultiNumericalFlagsTest(absltest.TestCase): def test_multi_numerical_flags(self): """Test multi_int and multi_float flags.""" - + fv = flags.FlagValues() int_defaults = [77, 88] - flags.DEFINE_multi_integer('m_int', int_defaults, - 'integer option that can occur multiple times', - short_name='mi') - self.assertListEqual(FLAGS.get_flag_value('m_int', None), int_defaults) + flags.DEFINE_multi_integer( + 'm_int', + int_defaults, + 'integer option that can occur multiple times', + short_name='mi', + flag_values=fv) + self.assertListEqual(fv['m_int'].default, int_defaults) argv = ('./program', '--m_int=-99', '--mi=101') - FLAGS(argv) - self.assertListEqual(FLAGS.get_flag_value('m_int', None), [-99, 101]) + fv(argv) + self.assertListEqual(fv.get_flag_value('m_int', None), [-99, 101]) float_defaults = [2.2, 3] - flags.DEFINE_multi_float('m_float', float_defaults, - 'float option that can occur multiple times', - short_name='mf') - for (expected, actual) in zip( - float_defaults, FLAGS.get_flag_value('m_float', None)): + flags.DEFINE_multi_float( + 'm_float', + float_defaults, + 'float option that can occur multiple times', + short_name='mf', + flag_values=fv) + for (expected, actual) in zip(float_defaults, + fv.get_flag_value('m_float', None)): self.assertAlmostEqual(expected, actual) argv = ('./program', '--m_float=-17', '--mf=2.78e9') - FLAGS(argv) + fv(argv) expected_floats = [-17.0, 2.78e9] - for (expected, actual) in zip( - expected_floats, FLAGS.get_flag_value('m_float', None)): + for (expected, actual) in zip(expected_floats, + fv.get_flag_value('m_float', None)): self.assertAlmostEqual(expected, actual) def test_multi_numerical_with_tuples(self): """Verify multi_int/float accept tuples as default values.""" flags.DEFINE_multi_integer( - 'm_int_tuple', - (77, 88), + 'm_int_tuple', (77, 88), 'integer option that can occur multiple times', short_name='mi_tuple') self.assertListEqual(FLAGS.get_flag_value('m_int_tuple', None), [77, 88]) @@ -1448,79 +1470,90 @@ class MultiNumericalFlagsTest(absltest.TestCase): flags.DEFINE_multi_integer, 'm_int2', ['abc'], 'desc') self.assertRaisesRegex( - flags.IllegalFlagValueError, - r'flag --m_float2=abc: ' + flags.IllegalFlagValueError, r'flag --m_float2=abc: ' r'(invalid literal for float\(\)|could not convert string to float): ' - r"'?abc'?", - flags.DEFINE_multi_float, 'm_float2', ['abc'], 'desc') + r"'?abc'?", flags.DEFINE_multi_float, 'm_float2', ['abc'], 'desc') # Test non-parseable command line values. - flags.DEFINE_multi_integer('m_int2', '77', - 'integer option that can occur multiple times') + fv = flags.FlagValues() + flags.DEFINE_multi_integer( + 'm_int2', + '77', + 'integer option that can occur multiple times', + flag_values=fv) argv = ('./program', '--m_int2=def') self.assertRaisesRegex( flags.IllegalFlagValueError, r"flag --m_int2=def: invalid literal for int\(\) with base 10: 'def'", - FLAGS, argv) + fv, argv) - flags.DEFINE_multi_float('m_float2', 2.2, - 'float option that can occur multiple times') + flags.DEFINE_multi_float( + 'm_float2', + 2.2, + 'float option that can occur multiple times', + flag_values=fv) argv = ('./program', '--m_float2=def') self.assertRaisesRegex( - flags.IllegalFlagValueError, - r'flag --m_float2=def: ' + flags.IllegalFlagValueError, r'flag --m_float2=def: ' r'(invalid literal for float\(\)|could not convert string to float): ' - r"'?def'?", - FLAGS, argv) + r"'?def'?", fv, argv) class MultiEnumFlagsTest(absltest.TestCase): def test_multi_enum_flags(self): """Test multi_enum flags.""" + fv = flags.FlagValues() enum_defaults = ['FOO', 'BAZ'] - flags.DEFINE_multi_enum('m_enum', enum_defaults, - ['FOO', 'BAR', 'BAZ', 'WHOOSH'], - 'Enum option that can occur multiple times', - short_name='me') - self.assertListEqual(FLAGS.get_flag_value('m_enum', None), enum_defaults) + flags.DEFINE_multi_enum( + 'm_enum', + enum_defaults, ['FOO', 'BAR', 'BAZ', 'WHOOSH'], + 'Enum option that can occur multiple times', + short_name='me', + flag_values=fv) + self.assertListEqual(fv['m_enum'].default, enum_defaults) argv = ('./program', '--m_enum=WHOOSH', '--me=FOO') - FLAGS(argv) - self.assertListEqual( - FLAGS.get_flag_value('m_enum', None), ['WHOOSH', 'FOO']) + fv(argv) + self.assertListEqual(fv.get_flag_value('m_enum', None), ['WHOOSH', 'FOO']) def test_single_value_default(self): """Test multi_enum flags with a single default value.""" + fv = flags.FlagValues() enum_default = 'FOO' - flags.DEFINE_multi_enum('m_enum1', enum_default, - ['FOO', 'BAR', 'BAZ', 'WHOOSH'], - 'enum option that can occur multiple times') - self.assertListEqual(FLAGS.get_flag_value('m_enum1', None), [enum_default]) + flags.DEFINE_multi_enum( + 'm_enum1', + enum_default, ['FOO', 'BAR', 'BAZ', 'WHOOSH'], + 'enum option that can occur multiple times', + flag_values=fv) + self.assertListEqual(fv['m_enum1'].default, [enum_default]) def test_case_sensitivity(self): """Test case sensitivity of multi_enum flag.""" + fv = flags.FlagValues() # Test case insensitive enum. - flags.DEFINE_multi_enum('m_enum2', ['whoosh'], - ['FOO', 'BAR', 'BAZ', 'WHOOSH'], - 'Enum option that can occur multiple times', - short_name='me2', - case_sensitive=False) + flags.DEFINE_multi_enum( + 'm_enum2', ['whoosh'], ['FOO', 'BAR', 'BAZ', 'WHOOSH'], + 'Enum option that can occur multiple times', + short_name='me2', + case_sensitive=False, + flag_values=fv) argv = ('./program', '--m_enum2=bar', '--me2=fOo') - FLAGS(argv) - self.assertListEqual(FLAGS.get_flag_value('m_enum2', None), ['BAR', 'FOO']) + fv(argv) + self.assertListEqual(fv.get_flag_value('m_enum2', None), ['BAR', 'FOO']) # Test case sensitive enum. - flags.DEFINE_multi_enum('m_enum3', ['BAR'], - ['FOO', 'BAR', 'BAZ', 'WHOOSH'], - 'Enum option that can occur multiple times', - short_name='me3', - case_sensitive=True) + flags.DEFINE_multi_enum( + 'm_enum3', ['BAR'], ['FOO', 'BAR', 'BAZ', 'WHOOSH'], + 'Enum option that can occur multiple times', + short_name='me3', + case_sensitive=True, + flag_values=fv) argv = ('./program', '--m_enum3=bar', '--me3=fOo') self.assertRaisesRegex( flags.IllegalFlagValueError, r'flag --m_enum3=invalid: value should be one of <FOO|BAR|BAZ|WHOOSH>', - FLAGS, argv) + fv, argv) def test_bad_multi_enum_flags(self): """Test multi_enum with invalid values.""" @@ -1529,24 +1562,23 @@ class MultiEnumFlagsTest(absltest.TestCase): self.assertRaisesRegex( flags.IllegalFlagValueError, r'flag --m_enum=INVALID: value should be one of <FOO|BAR|BAZ>', - flags.DEFINE_multi_enum, 'm_enum', ['INVALID'], - ['FOO', 'BAR', 'BAZ'], 'desc') + flags.DEFINE_multi_enum, 'm_enum', ['INVALID'], ['FOO', 'BAR', 'BAZ'], + 'desc') self.assertRaisesRegex( flags.IllegalFlagValueError, r'flag --m_enum=1234: value should be one of <FOO|BAR|BAZ>', - flags.DEFINE_multi_enum, 'm_enum2', [1234], - ['FOO', 'BAR', 'BAZ'], 'desc') + flags.DEFINE_multi_enum, 'm_enum2', [1234], ['FOO', 'BAR', 'BAZ'], + 'desc') # Test command-line values that are not in the permitted list of enums. - flags.DEFINE_multi_enum('m_enum4', 'FOO', - ['FOO', 'BAR', 'BAZ'], + flags.DEFINE_multi_enum('m_enum4', 'FOO', ['FOO', 'BAR', 'BAZ'], 'enum option that can occur multiple times') argv = ('./program', '--m_enum4=INVALID') self.assertRaisesRegex( flags.IllegalFlagValueError, - r'flag --m_enum4=invalid: value should be one of <FOO|BAR|BAZ>', - FLAGS, argv) + r'flag --m_enum4=invalid: value should be one of <FOO|BAR|BAZ>', FLAGS, + argv) class MultiEnumClassFlagsTest(absltest.TestCase): @@ -1554,10 +1586,12 @@ class MultiEnumClassFlagsTest(absltest.TestCase): def test_define_results_in_registered_flag_with_none(self): fv = flags.FlagValues() enum_defaults = None - flags.DEFINE_multi_enum_class('fruit', - enum_defaults, Fruit, - 'Enum option that can occur multiple times', - flag_values=fv) + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) fv.mark_as_parsed() self.assertIsNone(fv.fruit) @@ -1565,10 +1599,12 @@ class MultiEnumClassFlagsTest(absltest.TestCase): def test_define_results_in_registered_flag_with_string(self): fv = flags.FlagValues() enum_defaults = 'apple' - flags.DEFINE_multi_enum_class('fruit', - enum_defaults, Fruit, - 'Enum option that can occur multiple times', - flag_values=fv) + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) fv.mark_as_parsed() self.assertListEqual(fv.fruit, [Fruit.APPLE]) @@ -1576,10 +1612,12 @@ class MultiEnumClassFlagsTest(absltest.TestCase): def test_define_results_in_registered_flag_with_enum(self): fv = flags.FlagValues() enum_defaults = Fruit.APPLE - flags.DEFINE_multi_enum_class('fruit', - enum_defaults, Fruit, - 'Enum option that can occur multiple times', - flag_values=fv) + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) fv.mark_as_parsed() self.assertListEqual(fv.fruit, [Fruit.APPLE]) @@ -1602,10 +1640,12 @@ class MultiEnumClassFlagsTest(absltest.TestCase): def test_define_results_in_registered_flag_with_enum_list(self): fv = flags.FlagValues() enum_defaults = [Fruit.APPLE, Fruit.ORANGE] - flags.DEFINE_multi_enum_class('fruit', - enum_defaults, Fruit, - 'Enum option that can occur multiple times', - flag_values=fv) + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) fv.mark_as_parsed() self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.ORANGE]) @@ -1613,10 +1653,12 @@ class MultiEnumClassFlagsTest(absltest.TestCase): def test_from_command_line_returns_multiple(self): fv = flags.FlagValues() enum_defaults = [Fruit.APPLE] - flags.DEFINE_multi_enum_class('fruit', - enum_defaults, Fruit, - 'Enum option that can occur multiple times', - flag_values=fv) + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) argv = ('./program', '--fruit=Apple', '--fruit=orange') fv(argv) self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.ORANGE]) @@ -1630,8 +1672,8 @@ class MultiEnumClassFlagsTest(absltest.TestCase): def test_bad_multi_enum_class_flags_from_commandline(self): fv = flags.FlagValues() enum_defaults = [Fruit.APPLE] - flags.DEFINE_multi_enum_class('fruit', enum_defaults, Fruit, 'desc', - flag_values=fv) + flags.DEFINE_multi_enum_class( + 'fruit', enum_defaults, Fruit, 'desc', flag_values=fv) argv = ('./program', '--fruit=INVALID') with self.assertRaisesRegex( flags.IllegalFlagValueError, @@ -1643,65 +1685,81 @@ class UnicodeFlagsTest(absltest.TestCase): """Testing proper unicode support for flags.""" def test_unicode_default_and_helpstring(self): - flags.DEFINE_string('unicode_str', b'\xC3\x80\xC3\xBD'.decode('utf-8'), - b'help:\xC3\xAA'.decode('utf-8')) + fv = flags.FlagValues() + flags.DEFINE_string( + 'unicode_str', + b'\xC3\x80\xC3\xBD'.decode('utf-8'), + b'help:\xC3\xAA'.decode('utf-8'), + flag_values=fv) argv = ('./program',) - FLAGS(argv) # should not raise any exceptions + fv(argv) # should not raise any exceptions argv = ('./program', '--unicode_str=foo') - FLAGS(argv) # should not raise any exceptions + fv(argv) # should not raise any exceptions def test_unicode_in_list(self): - flags.DEFINE_list('unicode_list', ['abc', b'\xC3\x80'.decode('utf-8'), - b'\xC3\xBD'.decode('utf-8')], - b'help:\xC3\xAB'.decode('utf-8')) + fv = flags.FlagValues() + flags.DEFINE_list( + 'unicode_list', + ['abc', b'\xC3\x80'.decode('utf-8'), b'\xC3\xBD'.decode('utf-8')], + b'help:\xC3\xAB'.decode('utf-8'), + flag_values=fv) argv = ('./program',) - FLAGS(argv) # should not raise any exceptions + fv(argv) # should not raise any exceptions argv = ('./program', '--unicode_list=hello,there') - FLAGS(argv) # should not raise any exceptions + fv(argv) # should not raise any exceptions def test_xmloutput(self): - flags.DEFINE_string('unicode1', b'\xC3\x80\xC3\xBD'.decode('utf-8'), - b'help:\xC3\xAC'.decode('utf-8')) - flags.DEFINE_list('unicode2', ['abc', b'\xC3\x80'.decode('utf-8'), - b'\xC3\xBD'.decode('utf-8')], - b'help:\xC3\xAD'.decode('utf-8')) - flags.DEFINE_list('non_unicode', ['abc', 'def', 'ghi'], - b'help:\xC3\xAD'.decode('utf-8')) + fv = flags.FlagValues() + flags.DEFINE_string( + 'unicode1', + b'\xC3\x80\xC3\xBD'.decode('utf-8'), + b'help:\xC3\xAC'.decode('utf-8'), + flag_values=fv) + flags.DEFINE_list( + 'unicode2', + ['abc', b'\xC3\x80'.decode('utf-8'), b'\xC3\xBD'.decode('utf-8')], + b'help:\xC3\xAD'.decode('utf-8'), + flag_values=fv) + flags.DEFINE_list( + 'non_unicode', ['abc', 'def', 'ghi'], + b'help:\xC3\xAD'.decode('utf-8'), + flag_values=fv) outfile = io.StringIO() if six.PY3 else io.BytesIO() - FLAGS.write_help_in_xml_format(outfile) + fv.write_help_in_xml_format(outfile) actual_output = outfile.getvalue() if six.PY2: actual_output = actual_output.decode('utf-8') # The xml output is large, so we just check parts of it. - self.assertIn(b'<name>unicode1</name>\n' - b' <meaning>help:\xc3\xac</meaning>\n' - b' <default>\xc3\x80\xc3\xbd</default>\n' - b' <current>\xc3\x80\xc3\xbd</current>'.decode('utf-8'), - actual_output) + self.assertIn( + b'<name>unicode1</name>\n' + b' <meaning>help:\xc3\xac</meaning>\n' + b' <default>\xc3\x80\xc3\xbd</default>\n' + b' <current>\xc3\x80\xc3\xbd</current>'.decode('utf-8'), + actual_output) if six.PY2: - self.assertIn(b'<name>unicode2</name>\n' - b' <meaning>help:\xc3\xad</meaning>\n' - b' <default>abc,\xc3\x80,\xc3\xbd</default>\n' - b" <current>['abc', u'\\xc0', u'\\xfd']" - b'</current>'.decode('utf-8'), - actual_output) + self.assertIn( + b'<name>unicode2</name>\n' + b' <meaning>help:\xc3\xad</meaning>\n' + b' <default>abc,\xc3\x80,\xc3\xbd</default>\n' + b" <current>['abc', u'\\xc0', u'\\xfd']" + b'</current>'.decode('utf-8'), actual_output) else: - self.assertIn(b'<name>unicode2</name>\n' - b' <meaning>help:\xc3\xad</meaning>\n' - b' <default>abc,\xc3\x80,\xc3\xbd</default>\n' - b" <current>['abc', '\xc3\x80', '\xc3\xbd']" - b'</current>'.decode('utf-8'), - actual_output) - self.assertIn(b'<name>non_unicode</name>\n' - b' <meaning>help:\xc3\xad</meaning>\n' - b' <default>abc,def,ghi</default>\n' - b" <current>['abc', 'def', 'ghi']" - b'</current>'.decode('utf-8'), - actual_output) + self.assertIn( + b'<name>unicode2</name>\n' + b' <meaning>help:\xc3\xad</meaning>\n' + b' <default>abc,\xc3\x80,\xc3\xbd</default>\n' + b" <current>['abc', '\xc3\x80', '\xc3\xbd']" + b'</current>'.decode('utf-8'), actual_output) + self.assertIn( + b'<name>non_unicode</name>\n' + b' <meaning>help:\xc3\xad</meaning>\n' + b' <default>abc,def,ghi</default>\n' + b" <current>['abc', 'def', 'ghi']" + b'</current>'.decode('utf-8'), actual_output) class LoadFromFlagFileTest(absltest.TestCase): @@ -1709,16 +1767,29 @@ class LoadFromFlagFileTest(absltest.TestCase): def setUp(self): self.flag_values = flags.FlagValues() - flags.DEFINE_string('unittest_message1', 'Foo!', 'You Add Here.', - flag_values=self.flag_values) - flags.DEFINE_string('unittest_message2', 'Bar!', 'Hello, Sailor!', - flag_values=self.flag_values) - flags.DEFINE_boolean('unittest_boolflag', 0, 'Some Boolean thing', - flag_values=self.flag_values) - flags.DEFINE_integer('unittest_number', 12345, 'Some integer', - lower_bound=0, flag_values=self.flag_values) - flags.DEFINE_list('UnitTestList', '1,2,3', 'Some list', - flag_values=self.flag_values) + flags.DEFINE_string( + 'unittest_message1', + 'Foo!', + 'You Add Here.', + flag_values=self.flag_values) + flags.DEFINE_string( + 'unittest_message2', + 'Bar!', + 'Hello, Sailor!', + flag_values=self.flag_values) + flags.DEFINE_boolean( + 'unittest_boolflag', + 0, + 'Some Boolean thing', + flag_values=self.flag_values) + flags.DEFINE_integer( + 'unittest_number', + 12345, + 'Some integer', + lower_bound=0, + flag_values=self.flag_values) + flags.DEFINE_list( + 'UnitTestList', '1,2,3', 'Some list', flag_values=self.flag_values) self.tmp_path = None self.flag_values.mark_as_parsed() @@ -1788,8 +1859,8 @@ class LoadFromFlagFileTest(absltest.TestCase): fake_argv = fake_cmd_line.split(' ') self.flag_values(fake_argv) self.assertEqual(self.flag_values.unittest_boolflag, 1) - self.assertListEqual( - fake_argv, self._read_flags_from_files(fake_argv, False)) + self.assertListEqual(fake_argv, + self._read_flags_from_files(fake_argv, False)) def test_method_flagfiles_2(self): """Tests parsing one file + arguments off simulated argv.""" @@ -1801,32 +1872,31 @@ class LoadFromFlagFileTest(absltest.TestCase): # We should see the original cmd line with the file's contents spliced in. # Flags from the file will appear in the order order they are sepcified # in the file, in the same position as the flagfile argument. - expected_results = ['fooScript', - '--q', - '--unittest_message1=tempFile1!', - '--unittest_number=54321', - '--nounittest_boolflag'] + expected_results = [ + 'fooScript', '--q', '--unittest_message1=tempFile1!', + '--unittest_number=54321', '--nounittest_boolflag' + ] test_results = self._read_flags_from_files(fake_argv, False) self.assertListEqual(expected_results, test_results) + # end testTwo def def test_method_flagfiles_3(self): """Tests parsing nested files + arguments of simulated argv.""" tmp_files = self._setup_test_files() # specify our temp file on the fake cmd line - fake_cmd_line = ('fooScript --unittest_number=77 --flagfile=%s' - % tmp_files[1]) + fake_cmd_line = ('fooScript --unittest_number=77 --flagfile=%s' % + tmp_files[1]) fake_argv = fake_cmd_line.split(' ') - expected_results = ['fooScript', - '--unittest_number=77', - '--unittest_message1=tempFile1!', - '--unittest_number=54321', - '--nounittest_boolflag', - '--unittest_message2=setFromTempFile2', - '--unittest_number=6789a'] + expected_results = [ + 'fooScript', '--unittest_number=77', '--unittest_message1=tempFile1!', + '--unittest_number=54321', '--nounittest_boolflag', + '--unittest_message2=setFromTempFile2', '--unittest_number=6789a' + ] test_results = self._read_flags_from_files(fake_argv, False) self.assertListEqual(expected_results, test_results) + # end testThree def def test_method_flagfiles_3_spaces(self): @@ -1837,18 +1907,16 @@ class LoadFromFlagFileTest(absltest.TestCase): """ tmp_files = self._setup_test_files() # specify our temp file on the fake cmd line - fake_cmd_line = ('fooScript --unittest_number 77 --flagfile=%s' - % tmp_files[1]) + fake_cmd_line = ('fooScript --unittest_number 77 --flagfile=%s' % + tmp_files[1]) fake_argv = fake_cmd_line.split(' ') - expected_results = ['fooScript', - '--unittest_number', - '77', - '--unittest_message1=tempFile1!', - '--unittest_number=54321', - '--nounittest_boolflag', - '--unittest_message2=setFromTempFile2', - '--unittest_number=6789a'] + expected_results = [ + 'fooScript', '--unittest_number', '77', + '--unittest_message1=tempFile1!', '--unittest_number=54321', + '--nounittest_boolflag', '--unittest_message2=setFromTempFile2', + '--unittest_number=6789a' + ] test_results = self._read_flags_from_files(fake_argv, False) self.assertListEqual(expected_results, test_results) @@ -1860,14 +1928,14 @@ class LoadFromFlagFileTest(absltest.TestCase): """ tmp_files = self._setup_test_files() # specify our temp file on the fake cmd line - fake_cmd_line = ('fooScript --unittest_boolflag 77 --flagfile=%s' - % tmp_files[1]) + fake_cmd_line = ('fooScript --unittest_boolflag 77 --flagfile=%s' % + tmp_files[1]) fake_argv = fake_cmd_line.split(' ') - expected_results = ['fooScript', - '--unittest_boolflag', - '77', - '--flagfile=%s' % tmp_files[1]] + expected_results = [ + 'fooScript', '--unittest_boolflag', '77', + '--flagfile=%s' % tmp_files[1] + ] with _use_gnu_getopt(self.flag_values, False): test_results = self._read_flags_from_files(fake_argv, False) self.assertListEqual(expected_results, test_results) @@ -1879,13 +1947,13 @@ class LoadFromFlagFileTest(absltest.TestCase): """ tmp_files = self._setup_test_files() # specify our temp file on the fake cmd line - fake_cmd_line = ('fooScript --flagfile=%s --nounittest_boolflag' - % tmp_files[2]) + fake_cmd_line = ('fooScript --flagfile=%s --nounittest_boolflag' % + tmp_files[2]) fake_argv = fake_cmd_line.split(' ') - expected_results = ['fooScript', - '--unittest_message1=setFromTempFile3', - '--unittest_boolflag', - '--nounittest_boolflag'] + expected_results = [ + 'fooScript', '--unittest_message1=setFromTempFile3', + '--unittest_boolflag', '--nounittest_boolflag' + ] test_results = self._read_flags_from_files(fake_argv, False) self.assertListEqual(expected_results, test_results) @@ -1896,10 +1964,10 @@ class LoadFromFlagFileTest(absltest.TestCase): # specify our temp file on the fake cmd line fake_cmd_line = 'fooScript --some_flag -- --flagfile=%s' % tmp_files[0] fake_argv = fake_cmd_line.split(' ') - expected_results = ['fooScript', - '--some_flag', - '--', - '--flagfile=%s' % tmp_files[0]] + expected_results = [ + 'fooScript', '--some_flag', '--', + '--flagfile=%s' % tmp_files[0] + ] test_results = self._read_flags_from_files(fake_argv, False) self.assertListEqual(expected_results, test_results) @@ -1908,13 +1976,13 @@ class LoadFromFlagFileTest(absltest.TestCase): """Test that --flagfile parsing stops at non-options (non-GNU behavior).""" tmp_files = self._setup_test_files() # specify our temp file on the fake cmd line - fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' - % tmp_files[0]) + fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' % + tmp_files[0]) fake_argv = fake_cmd_line.split(' ') - expected_results = ['fooScript', - '--some_flag', - 'some_arg', - '--flagfile=%s' % tmp_files[0]] + expected_results = [ + 'fooScript', '--some_flag', 'some_arg', + '--flagfile=%s' % tmp_files[0] + ] with _use_gnu_getopt(self.flag_values, False): test_results = self._read_flags_from_files(fake_argv, False) @@ -1925,15 +1993,14 @@ class LoadFromFlagFileTest(absltest.TestCase): self.flag_values.set_gnu_getopt() tmp_files = self._setup_test_files() # specify our temp file on the fake cmd line - fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' - % tmp_files[0]) + fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' % + tmp_files[0]) fake_argv = fake_cmd_line.split(' ') - expected_results = ['fooScript', - '--some_flag', - 'some_arg', - '--unittest_message1=tempFile1!', - '--unittest_number=54321', - '--nounittest_boolflag'] + expected_results = [ + 'fooScript', '--some_flag', 'some_arg', + '--unittest_message1=tempFile1!', '--unittest_number=54321', + '--nounittest_boolflag' + ] test_results = self._read_flags_from_files(fake_argv, False) self.assertListEqual(expected_results, test_results) @@ -1942,15 +2009,14 @@ class LoadFromFlagFileTest(absltest.TestCase): """Test that --flagfile parsing respects force_gnu=True.""" tmp_files = self._setup_test_files() # specify our temp file on the fake cmd line - fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' - % tmp_files[0]) + fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' % + tmp_files[0]) fake_argv = fake_cmd_line.split(' ') - expected_results = ['fooScript', - '--some_flag', - 'some_arg', - '--unittest_message1=tempFile1!', - '--unittest_number=54321', - '--nounittest_boolflag'] + expected_results = [ + 'fooScript', '--some_flag', 'some_arg', + '--unittest_message1=tempFile1!', '--unittest_number=54321', + '--nounittest_boolflag' + ] test_results = self._read_flags_from_files(fake_argv, True) self.assertListEqual(expected_results, test_results) @@ -1959,18 +2025,16 @@ class LoadFromFlagFileTest(absltest.TestCase): """Tests that parsing repeated non-circular flagfiles works.""" tmp_files = self._setup_test_files() # specify our temp files on the fake cmd line - fake_cmd_line = ('fooScript --flagfile=%s --flagfile=%s' - % (tmp_files[1], tmp_files[0])) + fake_cmd_line = ('fooScript --flagfile=%s --flagfile=%s' % + (tmp_files[1], tmp_files[0])) fake_argv = fake_cmd_line.split(' ') - expected_results = ['fooScript', - '--unittest_message1=tempFile1!', - '--unittest_number=54321', - '--nounittest_boolflag', - '--unittest_message2=setFromTempFile2', - '--unittest_number=6789a', - '--unittest_message1=tempFile1!', - '--unittest_number=54321', - '--nounittest_boolflag'] + expected_results = [ + 'fooScript', '--unittest_message1=tempFile1!', + '--unittest_number=54321', '--nounittest_boolflag', + '--unittest_message2=setFromTempFile2', '--unittest_number=6789a', + '--unittest_message1=tempFile1!', '--unittest_number=54321', + '--nounittest_boolflag' + ] test_results = self._read_flags_from_files(fake_argv, False) self.assertListEqual(expected_results, test_results) @@ -1982,21 +2046,21 @@ class LoadFromFlagFileTest(absltest.TestCase): """Test that --flagfile raises except on file that is unreadable.""" tmp_files = self._setup_test_files() # specify our temp file on the fake cmd line - fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' - % tmp_files[3]) + fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' % + tmp_files[3]) fake_argv = fake_cmd_line.split(' ') - self.assertRaises(flags.CantOpenFlagFileError, - self._read_flags_from_files, fake_argv, True) + self.assertRaises(flags.CantOpenFlagFileError, self._read_flags_from_files, + fake_argv, True) def test_method_flagfiles_not_found(self): """Test that --flagfile raises except on file that does not exist.""" tmp_files = self._setup_test_files() # specify our temp file on the fake cmd line - fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%sNOTEXIST' - % tmp_files[3]) + fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%sNOTEXIST' % + tmp_files[3]) fake_argv = fake_cmd_line.split(' ') - self.assertRaises(flags.CantOpenFlagFileError, - self._read_flags_from_files, fake_argv, True) + self.assertRaises(flags.CantOpenFlagFileError, self._read_flags_from_files, + fake_argv, True) def test_flagfiles_user_path_expansion(self): """Test that user directory referenced paths are correctly expanded. @@ -2023,8 +2087,10 @@ class LoadFromFlagFileTest(absltest.TestCase): The argumants are not supposed to be flags """ - fake_argv = ['fooScript', '--unittest_boolflag', - 'command', '--command_arg1', '--UnitTestBoom', '--UnitTestB'] + fake_argv = [ + 'fooScript', '--unittest_boolflag', 'command', '--command_arg1', + '--UnitTestBoom', '--UnitTestB' + ] with _use_gnu_getopt(self.flag_values, False): argv = self.flag_values(fake_argv) self.assertListEqual(argv, fake_argv[:1] + fake_argv[2:]) @@ -2032,8 +2098,9 @@ class LoadFromFlagFileTest(absltest.TestCase): def test_parse_flags_after_args_if_using_gnugetopt(self): """Test that flags given after arguments are parsed if using gnu_getopt.""" self.flag_values.set_gnu_getopt() - fake_argv = ['fooScript', '--unittest_boolflag', - 'command', '--unittest_number=54321'] + fake_argv = [ + 'fooScript', '--unittest_boolflag', 'command', '--unittest_number=54321' + ] argv = self.flag_values(fake_argv) self.assertListEqual(argv, ['fooScript', 'command']) @@ -2059,8 +2126,7 @@ class LoadFromFlagFileTest(absltest.TestCase): self.flag_values.set_default('unittest_number', 0) self.assertEqual(self.flag_values['unittest_number'].default, 0) self.assertEqual(self.flag_values.unittest_number, 56) - self.assertEqual( - self.flag_values['unittest_number'].default_as_str, "'0'") + self.assertEqual(self.flag_values['unittest_number'].default_as_str, "'0'") self.flag_values(['dummyscript', '--unittest_number=56']) self.assertEqual(self.flag_values.unittest_number, 56) @@ -2083,8 +2149,7 @@ class LoadFromFlagFileTest(absltest.TestCase): # Test that setting a list default works correctly. self.flag_values.set_default('UnitTestList', '4,5,6') self.assertListEqual(self.flag_values.UnitTestList, ['4', '5', '6']) - self.assertEqual(self.flag_values['UnitTestList'].default_as_str, - "'4,5,6'") + self.assertEqual(self.flag_values['UnitTestList'].default_as_str, "'4,5,6'") self.flag_values(['dummyscript', '--UnitTestList=7,8,9']) self.assertListEqual(self.flag_values.UnitTestList, ['7', '8', '9']) @@ -2102,25 +2167,21 @@ class FlagsParsingTest(absltest.TestCase): self.flag_values = flags.FlagValues() def test_two_dash_arg_first(self): - flags.DEFINE_string('twodash_name', 'Bob', 'namehelp', - flag_values=self.flag_values) - flags.DEFINE_string('twodash_blame', 'Rob', 'blamehelp', - flag_values=self.flag_values) - argv = ('./program', - '--', - '--twodash_name=Harry') + flags.DEFINE_string( + 'twodash_name', 'Bob', 'namehelp', flag_values=self.flag_values) + flags.DEFINE_string( + 'twodash_blame', 'Rob', 'blamehelp', flag_values=self.flag_values) + argv = ('./program', '--', '--twodash_name=Harry') argv = self.flag_values(argv) self.assertEqual('Bob', self.flag_values.twodash_name) self.assertEqual(argv[1], '--twodash_name=Harry') def test_two_dash_arg_middle(self): - flags.DEFINE_string('twodash2_name', 'Bob', 'namehelp', - flag_values=self.flag_values) - flags.DEFINE_string('twodash2_blame', 'Rob', 'blamehelp', - flag_values=self.flag_values) - argv = ('./program', - '--twodash2_blame=Larry', - '--', + flags.DEFINE_string( + 'twodash2_name', 'Bob', 'namehelp', flag_values=self.flag_values) + flags.DEFINE_string( + 'twodash2_blame', 'Rob', 'blamehelp', flag_values=self.flag_values) + argv = ('./program', '--twodash2_blame=Larry', '--', '--twodash2_name=Harry') argv = self.flag_values(argv) self.assertEqual('Bob', self.flag_values.twodash2_name) @@ -2128,19 +2189,42 @@ class FlagsParsingTest(absltest.TestCase): self.assertEqual(argv[1], '--twodash2_name=Harry') def test_one_dash_arg_first(self): - flags.DEFINE_string('onedash_name', 'Bob', 'namehelp', - flag_values=self.flag_values) - flags.DEFINE_string('onedash_blame', 'Rob', 'blamehelp', - flag_values=self.flag_values) - argv = ('./program', - '-', - '--onedash_name=Harry') + flags.DEFINE_string( + 'onedash_name', 'Bob', 'namehelp', flag_values=self.flag_values) + flags.DEFINE_string( + 'onedash_blame', 'Rob', 'blamehelp', flag_values=self.flag_values) + argv = ('./program', '-', '--onedash_name=Harry') with _use_gnu_getopt(self.flag_values, False): argv = self.flag_values(argv) self.assertEqual(len(argv), 3) self.assertEqual(argv[1], '-') self.assertEqual(argv[2], '--onedash_name=Harry') + def test_required_flag_not_specified(self): + flags.DEFINE_string( + 'str_flag', + default=None, + help='help', + required=True, + flag_values=self.flag_values) + argv = ('./program',) + with _use_gnu_getopt(self.flag_values, False): + with self.assertRaises(flags.IllegalFlagValueError): + self.flag_values(argv) + + def test_required_arg_works_with_other_validators(self): + flags.DEFINE_integer( + 'int_flag', + default=None, + help='help', + required=True, + lower_bound=4, + flag_values=self.flag_values) + argv = ('./program', '--int_flag=2') + with _use_gnu_getopt(self.flag_values, False): + with self.assertRaises(flags.IllegalFlagValueError): + self.flag_values(argv) + def test_unrecognized_flags(self): flags.DEFINE_string('name', 'Bob', 'namehelp', flag_values=self.flag_values) # Unknown flag --nosuchflag @@ -2171,16 +2255,16 @@ class FlagsParsingTest(absltest.TestCase): self.assertEqual(e.flagvalue, '--nosuchflagwithparam=foo') # Allow unknown flag --nosuchflag if specified with undefok - argv = ('./program', '--nosuchflag', '--name=Bob', - '--undefok=nosuchflag', 'extra') + argv = ('./program', '--nosuchflag', '--name=Bob', '--undefok=nosuchflag', + 'extra') argv = self.flag_values(argv) self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') self.assertEqual(argv[1], 'extra', 'extra argument not preserved') # Allow unknown flag --noboolflag if undefok=boolflag is specified - argv = ('./program', '--noboolflag', '--name=Bob', - '--undefok=boolflag', 'extra') + argv = ('./program', '--noboolflag', '--name=Bob', '--undefok=boolflag', + 'extra') argv = self.flag_values(argv) self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') @@ -2188,8 +2272,8 @@ class FlagsParsingTest(absltest.TestCase): # But not if the flagname is misspelled: try: - argv = ('./program', '--nosuchflag', '--name=Bob', - '--undefok=nosuchfla', 'extra') + argv = ('./program', '--nosuchflag', '--name=Bob', '--undefok=nosuchfla', + 'extra') self.flag_values(argv) raise AssertionError('Unknown flag exception not raised') except flags.UnrecognizedFlagError as e: @@ -2221,9 +2305,7 @@ class FlagsParsingTest(absltest.TestCase): # Even if undefok specifies multiple flags argv = ('./program', '--nosuchflag', '-w', '--nosuchflagwithparam=foo', - '--name=Bob', - '--undefok=nosuchflag,w,nosuchflagwithparam', - 'extra') + '--name=Bob', '--undefok=nosuchflag,w,nosuchflagwithparam', 'extra') argv = self.flag_values(argv) self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') @@ -2251,9 +2333,7 @@ class FlagsParsingTest(absltest.TestCase): # Test --undefok <list> argv = ('./program', '--nosuchflag', '-w', '--nosuchflagwithparam=foo', - '--name=Bob', - '--undefok', - 'nosuchflag,w,nosuchflagwithparam', + '--name=Bob', '--undefok', 'nosuchflag,w,nosuchflagwithparam', 'extra') argv = self.flag_values(argv) self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') @@ -2272,9 +2352,7 @@ class NonGlobalFlagsTest(absltest.TestCase): """Test use of non-global FlagValues.""" nonglobal_flags = flags.FlagValues() flags.DEFINE_string('nonglobal_flag', 'Bob', 'flaghelp', nonglobal_flags) - argv = ('./program', - '--nonglobal_flag=Mary', - 'extra') + argv = ('./program', '--nonglobal_flag=Mary', 'extra') argv = nonglobal_flags(argv) self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') self.assertEqual(argv[0], './program', 'program name not preserved') @@ -2284,17 +2362,14 @@ class NonGlobalFlagsTest(absltest.TestCase): def test_unrecognized_nonglobal_flags(self): """Test unrecognized non-global flags.""" nonglobal_flags = flags.FlagValues() - argv = ('./program', - '--nosuchflag') + argv = ('./program', '--nosuchflag') try: argv = nonglobal_flags(argv) raise AssertionError('Unknown flag exception not raised') except flags.UnrecognizedFlagError as e: self.assertEqual(e.flagname, 'nosuchflag') - argv = ('./program', - '--nosuchflag', - '--undefok=nosuchflag') + argv = ('./program', '--nosuchflag', '--undefok=nosuchflag') argv = nonglobal_flags(argv) self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') @@ -2317,8 +2392,8 @@ class NonGlobalFlagsTest(absltest.TestCase): default_value = 'default value for test_flag_values_del_attr' # 1. Declare and delete a flag with no short name. flag_values = flags.FlagValues() - flags.DEFINE_string('delattr_foo', default_value, 'A simple flag.', - flag_values=flag_values) + flags.DEFINE_string( + 'delattr_foo', default_value, 'A simple flag.', flag_values=flag_values) flag_values.mark_as_parsed() self.assertEqual(flag_values.delattr_foo, default_value) @@ -2330,15 +2405,19 @@ class NonGlobalFlagsTest(absltest.TestCase): self.assertFalse(flag_values._flag_is_registered(flag_obj)) # If the previous del FLAGS.delattr_foo did not work properly, the # next definition will trigger a redefinition error. - flags.DEFINE_integer('delattr_foo', 3, 'A simple flag.', - flag_values=flag_values) + flags.DEFINE_integer( + 'delattr_foo', 3, 'A simple flag.', flag_values=flag_values) del flag_values.delattr_foo self.assertFalse('delattr_foo' in flag_values) # 2. Declare and delete a flag with a short name. - flags.DEFINE_string('delattr_bar', default_value, 'flag with short name', - short_name='x5', flag_values=flag_values) + flags.DEFINE_string( + 'delattr_bar', + default_value, + 'flag with short name', + short_name='x5', + flag_values=flag_values) flag_obj = flag_values['delattr_bar'] self.assertTrue(flag_values._flag_is_registered(flag_obj)) del flag_values.x5 @@ -2347,8 +2426,12 @@ class NonGlobalFlagsTest(absltest.TestCase): self.assertFalse(flag_values._flag_is_registered(flag_obj)) # 3. Just like 2, but del flag_values.name last - flags.DEFINE_string('delattr_bar', default_value, 'flag with short name', - short_name='x5', flag_values=flag_values) + flags.DEFINE_string( + 'delattr_bar', + default_value, + 'flag with short name', + short_name='x5', + flag_values=flag_values) flag_obj = flag_values['delattr_bar'] self.assertTrue(flag_values._flag_is_registered(flag_obj)) del flag_values.delattr_bar @@ -2361,24 +2444,25 @@ class NonGlobalFlagsTest(absltest.TestCase): def test_list_flag_format(self): """Tests for correctly-formatted list flags.""" - flags.DEFINE_list('listflag', '', 'A list of arguments') + fv = flags.FlagValues() + flags.DEFINE_list('listflag', '', 'A list of arguments', flag_values=fv) def _check_parsing(listval): """Parse a particular value for our test flag, --listflag.""" - argv = FLAGS(['./program', '--listflag=' + listval, 'plain-arg']) + argv = fv(['./program', '--listflag=' + listval, 'plain-arg']) self.assertEqual(['./program', 'plain-arg'], argv) - return FLAGS.listflag + return fv.listflag # Basic success case self.assertEqual(_check_parsing('foo,bar'), ['foo', 'bar']) # Success case: newline in argument is quoted. self.assertEqual(_check_parsing('"foo","bar\nbar"'), ['foo', 'bar\nbar']) # Failure case: newline in argument is unquoted. - self.assertRaises( - flags.IllegalFlagValueError, _check_parsing, '"foo",bar\nbar') + self.assertRaises(flags.IllegalFlagValueError, _check_parsing, + '"foo",bar\nbar') # Failure case: unmatched ". - self.assertRaises( - flags.IllegalFlagValueError, _check_parsing, '"foo,barbar') + self.assertRaises(flags.IllegalFlagValueError, _check_parsing, + '"foo,barbar') def test_flag_definition_via_setitem(self): with self.assertRaises(flags.IllegalFlagValueError): @@ -2432,16 +2516,14 @@ class KeyFlagsTest(absltest.TestCase): flag_values = flags.FlagValues() # Before starting any testing, make sure no flags are already # defined for module_foo and module_bar. - self.assertListEqual(self._get_names_of_key_flags(module_foo, flag_values), - []) - self.assertListEqual(self._get_names_of_key_flags(module_bar, flag_values), - []) - self.assertListEqual(self._get_names_of_defined_flags(module_foo, - flag_values), - []) - self.assertListEqual(self._get_names_of_defined_flags(module_bar, - flag_values), - []) + self.assertListEqual( + self._get_names_of_key_flags(module_foo, flag_values), []) + self.assertListEqual( + self._get_names_of_key_flags(module_bar, flag_values), []) + self.assertListEqual( + self._get_names_of_defined_flags(module_foo, flag_values), []) + self.assertListEqual( + self._get_names_of_defined_flags(module_bar, flag_values), []) # Defines a few flags in module_foo and module_bar. module_foo.define_flags(flag_values=flag_values) @@ -2501,12 +2583,8 @@ class KeyFlagsTest(absltest.TestCase): # Before starting any testing, make sure no flags are already # defined for module_foo and module_bar. - self.assertListEqual( - self._get_names_of_key_flags(module_bar, fv), - []) - self.assertListEqual( - self._get_names_of_defined_flags(module_bar, fv), - []) + self.assertListEqual(self._get_names_of_key_flags(module_bar, fv), []) + self.assertListEqual(self._get_names_of_defined_flags(module_bar, fv), []) module_bar.define_flags(flag_values=fv) @@ -2531,8 +2609,7 @@ class KeyFlagsTest(absltest.TestCase): flags.declare_key_flag(flag_name_0, flag_values=fv) self._assert_lists_have_same_elements( - self._get_names_of_key_flags(main_module, fv), - [flag_name_0]) + self._get_names_of_key_flags(main_module, fv), [flag_name_0]) flags.declare_key_flag(flag_name_2, flag_values=fv) self._assert_lists_have_same_elements( @@ -2569,9 +2646,11 @@ class KeyFlagsTest(absltest.TestCase): # Define one flag in this main module and some flags in modules # a and b. Also declare one flag from module a and one flag # from module b as key flags for the main module. - flags.DEFINE_integer('main_module_int_fg', 1, - 'Integer flag in the main module.', - flag_values=self.flag_values) + flags.DEFINE_integer( + 'main_module_int_fg', + 1, + 'Integer flag in the main module.', + flag_values=self.flag_values) try: main_module_int_fg_help = ( @@ -2611,9 +2690,7 @@ class KeyFlagsTest(absltest.TestCase): # main_module_help, so we can't keep incrementally extending # the expected_help string ... expected_help = ('\n%s:\n%s\n%s\n%s' % - (sys.argv[0], - main_module_int_fg_help, - tmod_bar_z_help, + (sys.argv[0], main_module_int_fg_help, tmod_bar_z_help, tmod_foo_bool_help)) self.assertMultiLineEqual(self.flag_values.main_module_help(), expected_help) @@ -2626,9 +2703,7 @@ class KeyFlagsTest(absltest.TestCase): def test_adoptmodule_key_flags(self): # Check that adopt_module_key_flags raises an exception when # called with a module name (as opposed to a module object). - self.assertRaises(flags.Error, - flags.adopt_module_key_flags, - 'pyglib.app') + self.assertRaises(flags.Error, flags.adopt_module_key_flags, 'pyglib.app') def test_disclaimkey_flags(self): original_disclaim_module_ids = _helpers.disclaim_module_ids @@ -2646,36 +2721,42 @@ class FindModuleTest(absltest.TestCase): """Testing methods that find a module that defines a given flag.""" def test_find_module_defining_flag(self): - self.assertEqual('default', FLAGS.find_module_defining_flag( - '__NON_EXISTENT_FLAG__', 'default')) self.assertEqual( - module_baz.__name__, FLAGS.find_module_defining_flag('tmod_baz_x')) + 'default', + FLAGS.find_module_defining_flag('__NON_EXISTENT_FLAG__', 'default')) + self.assertEqual(module_baz.__name__, + FLAGS.find_module_defining_flag('tmod_baz_x')) def test_find_module_id_defining_flag(self): - self.assertEqual('default', FLAGS.find_module_id_defining_flag( - '__NON_EXISTENT_FLAG__', 'default')) + self.assertEqual( + 'default', + FLAGS.find_module_id_defining_flag('__NON_EXISTENT_FLAG__', 'default')) self.assertEqual( id(module_baz), FLAGS.find_module_id_defining_flag('tmod_baz_x')) def test_find_module_defining_flag_passing_module_name(self): my_flags = flags.FlagValues() module_name = sys.__name__ # Must use an existing module. - flags.DEFINE_boolean('flag_name', True, - 'Flag with a different module name.', - flag_values=my_flags, - module_name=module_name) + flags.DEFINE_boolean( + 'flag_name', + True, + 'Flag with a different module name.', + flag_values=my_flags, + module_name=module_name) self.assertEqual(module_name, my_flags.find_module_defining_flag('flag_name')) def test_find_module_id_defining_flag_passing_module_name(self): my_flags = flags.FlagValues() module_name = sys.__name__ # Must use an existing module. - flags.DEFINE_boolean('flag_name', True, - 'Flag with a different module name.', - flag_values=my_flags, - module_name=module_name) - self.assertEqual(id(sys), - my_flags.find_module_id_defining_flag('flag_name')) + flags.DEFINE_boolean( + 'flag_name', + True, + 'Flag with a different module name.', + flag_values=my_flags, + module_name=module_name) + self.assertEqual( + id(sys), my_flags.find_module_id_defining_flag('flag_name')) class FlagsErrorMessagesTest(absltest.TestCase): @@ -2686,22 +2767,56 @@ class FlagsErrorMessagesTest(absltest.TestCase): def test_integer_error_text(self): # Make sure we get proper error text - flags.DEFINE_integer('positive', 4, 'non-negative flag', lower_bound=1, - flag_values=self.flag_values) - flags.DEFINE_integer('non_negative', 4, 'positive flag', lower_bound=0, - flag_values=self.flag_values) - flags.DEFINE_integer('negative', -4, 'negative flag', upper_bound=-1, - flag_values=self.flag_values) - flags.DEFINE_integer('non_positive', -4, 'non-positive flag', upper_bound=0, - flag_values=self.flag_values) - flags.DEFINE_integer('greater', 19, 'greater-than flag', lower_bound=4, - flag_values=self.flag_values) - flags.DEFINE_integer('smaller', -19, 'smaller-than flag', upper_bound=4, - flag_values=self.flag_values) - flags.DEFINE_integer('usual', 4, 'usual flag', lower_bound=0, - upper_bound=10000, flag_values=self.flag_values) - flags.DEFINE_integer('another_usual', 0, 'usual flag', lower_bound=-1, - upper_bound=1, flag_values=self.flag_values) + flags.DEFINE_integer( + 'positive', + 4, + 'non-negative flag', + lower_bound=1, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'non_negative', + 4, + 'positive flag', + lower_bound=0, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'negative', + -4, + 'negative flag', + upper_bound=-1, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'non_positive', + -4, + 'non-positive flag', + upper_bound=0, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'greater', + 19, + 'greater-than flag', + lower_bound=4, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'smaller', + -19, + 'smaller-than flag', + upper_bound=4, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'usual', + 4, + 'usual flag', + lower_bound=0, + upper_bound=10000, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'another_usual', + 0, + 'usual flag', + lower_bound=-1, + upper_bound=1, + flag_values=self.flag_values) self._check_error_message('positive', -4, 'a positive integer') self._check_error_message('non_negative', -4, 'a non-negative integer') @@ -2714,22 +2829,56 @@ class FlagsErrorMessagesTest(absltest.TestCase): self._check_error_message('smaller', 5, 'integer <= 4') def test_float_error_text(self): - flags.DEFINE_float('positive', 4, 'non-negative flag', lower_bound=1, - flag_values=self.flag_values) - flags.DEFINE_float('non_negative', 4, 'positive flag', lower_bound=0, - flag_values=self.flag_values) - flags.DEFINE_float('negative', -4, 'negative flag', upper_bound=-1, - flag_values=self.flag_values) - flags.DEFINE_float('non_positive', -4, 'non-positive flag', upper_bound=0, - flag_values=self.flag_values) - flags.DEFINE_float('greater', 19, 'greater-than flag', lower_bound=4, - flag_values=self.flag_values) - flags.DEFINE_float('smaller', -19, 'smaller-than flag', upper_bound=4, - flag_values=self.flag_values) - flags.DEFINE_float('usual', 4, 'usual flag', lower_bound=0, - upper_bound=10000, flag_values=self.flag_values) - flags.DEFINE_float('another_usual', 0, 'usual flag', lower_bound=-1, - upper_bound=1, flag_values=self.flag_values) + flags.DEFINE_float( + 'positive', + 4, + 'non-negative flag', + lower_bound=1, + flag_values=self.flag_values) + flags.DEFINE_float( + 'non_negative', + 4, + 'positive flag', + lower_bound=0, + flag_values=self.flag_values) + flags.DEFINE_float( + 'negative', + -4, + 'negative flag', + upper_bound=-1, + flag_values=self.flag_values) + flags.DEFINE_float( + 'non_positive', + -4, + 'non-positive flag', + upper_bound=0, + flag_values=self.flag_values) + flags.DEFINE_float( + 'greater', + 19, + 'greater-than flag', + lower_bound=4, + flag_values=self.flag_values) + flags.DEFINE_float( + 'smaller', + -19, + 'smaller-than flag', + upper_bound=4, + flag_values=self.flag_values) + flags.DEFINE_float( + 'usual', + 4, + 'usual flag', + lower_bound=0, + upper_bound=10000, + flag_values=self.flag_values) + flags.DEFINE_float( + 'another_usual', + 0, + 'usual flag', + lower_bound=-1, + upper_bound=1, + flag_values=self.flag_values) self._check_error_message('positive', 0.5, 'number >= 1') self._check_error_message('non_negative', -4.0, 'a non-negative number') @@ -2748,9 +2897,11 @@ class FlagsErrorMessagesTest(absltest.TestCase): self.flag_values.__setattr__(flag_name, flag_value) raise AssertionError('Bounds exception not raised!') except flags.IllegalFlagValueError as e: - expected = ('flag --%(name)s=%(value)s: %(value)s is not %(suffix)s' % - {'name': flag_name, 'value': flag_value, - 'suffix': expected_message_suffix}) + expected = ('flag --%(name)s=%(value)s: %(value)s is not %(suffix)s' % { + 'name': flag_name, + 'value': flag_value, + 'suffix': expected_message_suffix + }) self.assertEqual(str(e), expected) |