diff options
-rw-r--r-- | absl/flags/BUILD | 4 | ||||
-rw-r--r-- | absl/flags/__init__.py | 1 | ||||
-rw-r--r-- | absl/flags/_defines.py | 441 | ||||
-rw-r--r-- | absl/flags/_exceptions.py | 4 | ||||
-rw-r--r-- | absl/flags/_flagvalues.py | 85 | ||||
-rw-r--r-- | absl/flags/tests/_flagvalues_test.py | 47 |
6 files changed, 431 insertions, 151 deletions
diff --git a/absl/flags/BUILD b/absl/flags/BUILD index 1dabe14..47a4393 100644 --- a/absl/flags/BUILD +++ b/absl/flags/BUILD @@ -1,9 +1,9 @@ +load("//absl:_build_defs.bzl", "py2and3_test") + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//absl:_build_defs.bzl", "py2and3_test") - py_library( name = "flags", srcs = ["__init__.py"], diff --git a/absl/flags/__init__.py b/absl/flags/__init__.py index cbfc019..226f4f1 100644 --- a/absl/flags/__init__.py +++ b/absl/flags/__init__.py @@ -112,6 +112,7 @@ EnumFlag = _flag.EnumFlag EnumClassFlag = _flag.EnumClassFlag MultiFlag = _flag.MultiFlag MultiEnumClassFlag = _flag.MultiEnumClassFlag +FlagHolder = _flagvalues.FlagHolder FlagValues = _flagvalues.FlagValues ArgumentParser = _argument_parser.ArgumentParser BooleanParser = _argument_parser.BooleanParser diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py index b95c419..6c209a8 100644 --- a/absl/flags/_defines.py +++ b/absl/flags/_defines.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """This modules contains flags DEFINE functions. Do NOT import this module directly. Import the flags package and use the @@ -32,6 +31,17 @@ from absl.flags import _flagvalues from absl.flags import _helpers from absl.flags import _validators +# pylint: disable=unused-import +try: + from typing import Text, List, Any +except ImportError: + pass + +try: + import enum +except ImportError: + pass +# pylint: enable=unused-import _helpers.disclaim_module_ids.add(id(sys.modules[__name__])) @@ -41,7 +51,7 @@ def _register_bounds_validator_if_needed(parser, name, flag_values): Args: parser: NumericParser (either FloatParser or IntegerParser), provides lower - and upper bounds, and help text to display. + and upper bounds, and help text to display. name: str, name of the flag flag_values: FlagValues. """ @@ -56,8 +66,15 @@ def _register_bounds_validator_if_needed(parser, name, flag_values): _validators.register_validator(name, checker, flag_values=flag_values) -def DEFINE(parser, name, default, help, flag_values=_flagvalues.FLAGS, # pylint: disable=redefined-builtin,invalid-name - serializer=None, module_name=None, **args): +def DEFINE( # pylint: disable=invalid-name + parser, + name, + default, + help, # pylint: disable=redefined-builtin + flag_values=_flagvalues.FLAGS, + serializer=None, + module_name=None, + **args): """Registers a generic Flag object. NOTE: in the docstrings of all DEFINE* functions, "registers" is short @@ -71,15 +88,19 @@ def DEFINE(parser, name, default, help, flag_values=_flagvalues.FLAGS, # pylint name: str, the flag name. default: The default value of the flag. help: str, the help message. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. serializer: ArgumentSerializer, the flag serializer instance. - module_name: str, the name of the Python module declaring this flag. - If not provided, it will be computed using the stack trace of this call. + module_name: str, the name of the Python module declaring this flag. If not + provided, it will be computed using the stack trace of this call. **args: dict, the extra keyword args that are passed to Flag __init__. + + Returns: + a handle to defined flag. """ - DEFINE_flag(_flag.Flag(parser, serializer, name, default, help, **args), - flag_values, module_name) + return DEFINE_flag( + _flag.Flag(parser, serializer, name, default, help, **args), flag_values, + module_name) def DEFINE_flag(flag, flag_values=_flagvalues.FLAGS, module_name=None): # pylint: disable=invalid-name @@ -94,10 +115,13 @@ def DEFINE_flag(flag, flag_values=_flagvalues.FLAGS, module_name=None): # pylin Args: flag: Flag, a flag that is key to the module. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. - module_name: str, the name of the Python module declaring this flag. - If not provided, it will be computed using the stack trace of this call. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: str, the name of the Python module declaring this flag. If not + provided, it will be computed using the stack trace of this call. + + Returns: + a handle to defined flag. """ # Copying the reference to flag_values prevents pychecker warnings. fv = flag_values @@ -109,10 +133,12 @@ def DEFINE_flag(flag, flag_values=_flagvalues.FLAGS, module_name=None): # pylin module, module_name = _helpers.get_calling_module_object_and_name() flag_values.register_flag_by_module(module_name, flag) flag_values.register_flag_by_module_id(id(module), flag) + return _flagvalues.FlagHolder(fv, flag.name) -def _internal_declare_key_flags( - flag_names, flag_values=_flagvalues.FLAGS, key_flag_values=None): +def _internal_declare_key_flags(flag_names, + flag_values=_flagvalues.FLAGS, + key_flag_values=None): """Declares a flag as key for the calling module. Internal function. User code should call declare_key_flag or @@ -120,15 +146,15 @@ def _internal_declare_key_flags( Args: flag_names: [str], a list of strings that are names of already-registered - Flag objects. + Flag objects. flag_values: FlagValues, the FlagValues instance with which the flags listed - in flag_names have registered (the value of the flag_values - argument from the DEFINE_* calls that defined those flags). - This should almost never need to be overridden. + in flag_names have registered (the value of the flag_values argument from + the DEFINE_* calls that defined those flags). This should almost never + need to be overridden. key_flag_values: FlagValues, the FlagValues instance that (among possibly - many other things) keeps track of the key flags for each module. - Default None means "same as flag_values". This should almost - never need to be overridden. + many other things) keeps track of the key flags for each module. Default + None means "same as flag_values". This should almost never need to be + overridden. Raises: UnrecognizedFlagError: Raised when the flag is not defined. @@ -156,12 +182,11 @@ def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS): flags.declare_key_flag('flag_1') Args: - flag_name: str, the name of an already declared flag. - (Redeclaring flags as key, including flags implicitly key - because they were declared in this module, is a no-op.) - flag_values: FlagValues, the FlagValues instance in which the flag will - be declared as a key flag. This should almost never need to be - overridden. + flag_name: str, the name of an already declared flag. (Redeclaring flags as + key, including flags implicitly key because they were declared in this + module, is a no-op.) + flag_values: FlagValues, the FlagValues instance in which the flag will be + declared as a key flag. This should almost never need to be overridden. Raises: ValueError: Raised if flag_name not defined as a Python flag. @@ -187,10 +212,9 @@ def adopt_module_key_flags(module, flag_values=_flagvalues.FLAGS): Args: module: module, the module object from which all key flags will be declared - as key flags to the current module. - flag_values: FlagValues, the FlagValues instance in which the flags will - be declared as key flags. This should almost never need to be - overridden. + as key flags to the current module. + flag_values: FlagValues, the FlagValues instance in which the flags will be + declared as key flags. This should almost never need to be overridden. Raises: Error: Raised when given an argument that is a module name (a string), @@ -234,16 +258,24 @@ def disclaim_key_flags(): def DEFINE_string( # pylint: disable=invalid-name,redefined-builtin - name, default, help, flag_values=_flagvalues.FLAGS, **args): + name, + default, + help, + flag_values=_flagvalues.FLAGS, + **args): # type: (...) -> _flagvalues.FlagHolder[Text] """Registers a flag whose value can be any string.""" parser = _argument_parser.ArgumentParser() serializer = _argument_parser.ArgumentSerializer() - DEFINE(parser, name, default, help, flag_values, serializer, **args) + return DEFINE(parser, name, default, help, flag_values, serializer, **args) def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin - name, default, help, flag_values=_flagvalues.FLAGS, module_name=None, - **args): + name, + default, + help, + flag_values=_flagvalues.FLAGS, + module_name=None, + **args): # type: (...) -> _flagvalues.FlagHolder[bool] """Registers a boolean flag. Such a boolean flag does not take an argument. If a user wants to @@ -258,19 +290,27 @@ def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin name: str, the flag name. default: bool|str|None, the default value of the flag. help: str, the help message. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. - module_name: str, the name of the Python module declaring this flag. - If not provided, it will be computed using the stack trace of this call. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: str, the name of the Python module declaring this flag. If not + provided, it will be computed using the stack trace of this call. **args: dict, the extra keyword args that are passed to Flag __init__. + + Returns: + a handle to defined flag. """ - DEFINE_flag(_flag.BooleanFlag(name, default, help, **args), - flag_values, module_name) + return DEFINE_flag( + _flag.BooleanFlag(name, default, help, **args), flag_values, module_name) def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin - name, default, help, lower_bound=None, upper_bound=None, - flag_values=_flagvalues.FLAGS, **args): # pylint: disable=invalid-name + name, + default, + help, + lower_bound=None, + upper_bound=None, + flag_values=_flagvalues.FLAGS, + **args): # type: (...) -> _flagvalues.FlagHolder[Text] """Registers a flag whose value must be a float. If lower_bound or upper_bound are set, then this flag must be @@ -282,19 +322,28 @@ def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin help: str, the help message. lower_bound: float, min value of the flag. upper_bound: float, max value of the flag. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. **args: dict, the extra keyword args that are passed to DEFINE. + + Returns: + a handle to defined flag. """ parser = _argument_parser.FloatParser(lower_bound, upper_bound) serializer = _argument_parser.ArgumentSerializer() - DEFINE(parser, name, default, help, flag_values, serializer, **args) + result = DEFINE(parser, name, default, help, flag_values, serializer, **args) _register_bounds_validator_if_needed(parser, name, flag_values=flag_values) + return result def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin - name, default, help, lower_bound=None, upper_bound=None, - flag_values=_flagvalues.FLAGS, **args): + name, + default, + help, + lower_bound=None, + upper_bound=None, + flag_values=_flagvalues.FLAGS, + **args): # type: (...) -> _flagvalues.FlagHolder[int] """Registers a flag whose value must be an integer. If lower_bound, or upper_bound are set, then this flag must be @@ -306,19 +355,28 @@ def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin help: str, the help message. lower_bound: int, min value of the flag. upper_bound: int, max value of the flag. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. **args: dict, the extra keyword args that are passed to DEFINE. + + Returns: + a handle to defined flag. """ parser = _argument_parser.IntegerParser(lower_bound, upper_bound) serializer = _argument_parser.ArgumentSerializer() - DEFINE(parser, name, default, help, flag_values, serializer, **args) + result = DEFINE(parser, name, default, help, flag_values, serializer, **args) _register_bounds_validator_if_needed(parser, name, flag_values=flag_values) + return result def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin - name, default, enum_values, help, flag_values=_flagvalues.FLAGS, - module_name=None, **args): + name, + default, + enum_values, + help, + flag_values=_flagvalues.FLAGS, + module_name=None, + **args): # type: (...) -> _flagvalues.FlagHolder[Text] """Registers a flag whose value can be any string from enum_values. Instead of a string enum, prefer `DEFINE_enum_class`, which allows @@ -328,21 +386,30 @@ def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin name: str, the flag name. default: str|None, the default value of the flag. enum_values: [str], a non-empty list of strings with the possible values for - the flag. + the flag. help: str, the help message. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. - module_name: str, the name of the Python module declaring this flag. - If not provided, it will be computed using the stack trace of this call. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: str, the name of the Python module declaring this flag. If not + provided, it will be computed using the stack trace of this call. **args: dict, the extra keyword args that are passed to Flag __init__. + + Returns: + a handle to defined flag. """ - DEFINE_flag(_flag.EnumFlag(name, default, help, enum_values, **args), - flag_values, module_name) + return DEFINE_flag( + _flag.EnumFlag(name, default, help, enum_values, **args), flag_values, + module_name) def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin - name, default, enum_class, help, flag_values=_flagvalues.FLAGS, - module_name=None, **args): + name, + default, + enum_class, + help, + flag_values=_flagvalues.FLAGS, + module_name=None, + **args): # type: (...) -> _flagvalues.FlagHolder[enum.Enum] """Registers a flag whose value can be the name of enum members. Args: @@ -350,18 +417,26 @@ def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin default: Enum|str|None, the default value of the flag. enum_class: class, the Enum class with all the possible values for the flag. help: str, the help message. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. - module_name: str, the name of the Python module declaring this flag. - If not provided, it will be computed using the stack trace of this call. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: str, the name of the Python module declaring this flag. If not + provided, it will be computed using the stack trace of this call. **args: dict, the extra keyword args that are passed to Flag __init__. + + Returns: + a handle to defined flag. """ - DEFINE_flag(_flag.EnumClassFlag(name, default, help, enum_class, **args), - flag_values, module_name) + return DEFINE_flag( + _flag.EnumClassFlag(name, default, help, enum_class, **args), flag_values, + module_name) def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin - name, default, help, flag_values=_flagvalues.FLAGS, **args): + name, + default, + help, + flag_values=_flagvalues.FLAGS, + **args): # type: (...) -> _flagvalues.FlagHolder[List[Text]] """Registers a flag whose value is a comma-separated list of strings. The flag value is parsed with a CSV parser. @@ -370,19 +445,26 @@ def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin name: str, the flag name. default: list|str|None, the default value of the flag. help: str, the help message. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. - **args: Dictionary with extra keyword args that are passed to the - Flag __init__. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. """ parser = _argument_parser.ListParser() serializer = _argument_parser.CsvListSerializer(',') - DEFINE(parser, name, default, help, flag_values, serializer, **args) + return DEFINE(parser, name, default, help, flag_values, serializer, **args) def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin - name, default, help, comma_compat=False, flag_values=_flagvalues.FLAGS, - **args): + name, + default, + help, + comma_compat=False, + flag_values=_flagvalues.FLAGS, + **args): # type: (...) -> _flagvalues.FlagHolder[List[Text]] """Registers a flag whose value is a whitespace-separated list of strings. Any whitespace can be used as a separator. @@ -391,23 +473,32 @@ def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin name: str, the flag name. default: list|str|None, the default value of the flag. help: str, the help message. - comma_compat: bool - Whether to support comma as an additional separator. - If false then only whitespace is supported. This is intended only for - backwards compatibility with flags that used to be comma-separated. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. - **args: Dictionary with extra keyword args that are passed to the - Flag __init__. + comma_compat: bool - Whether to support comma as an additional separator. If + false then only whitespace is supported. This is intended only for + backwards compatibility with flags that used to be comma-separated. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. """ parser = _argument_parser.WhitespaceSeparatedListParser( comma_compat=comma_compat) serializer = _argument_parser.ListSerializer(' ') - DEFINE(parser, name, default, help, flag_values, serializer, **args) + return DEFINE(parser, name, default, help, flag_values, serializer, **args) def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin - parser, serializer, name, default, help, flag_values=_flagvalues.FLAGS, - module_name=None, **args): + parser, + serializer, + name, + default, + help, + flag_values=_flagvalues.FLAGS, + module_name=None, + **args): # type: (...) -> _flagvalues.FlagHolder[List] """Registers a generic MultiFlag that parses its args with a given parser. Auxiliary function. Normal users should NOT use it directly. @@ -420,25 +511,33 @@ def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin parser: ArgumentParser, used to parse the flag arguments. serializer: ArgumentSerializer, the flag serializer instance. name: str, the flag name. - default: Union[Iterable[T], Text, None], the default value of the flag. - If the value is text, it will be parsed as if it was provided from - the command line. If the value is a non-string iterable, it will be - iterated over to create a shallow copy of the values. If it is None, - it is left as-is. + default: Union[Iterable[T], Text, None], the default value of the flag. If + the value is text, it will be parsed as if it was provided from the + command line. If the value is a non-string iterable, it will be iterated + over to create a shallow copy of the values. If it is None, it is left + as-is. help: str, the help message. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. - module_name: A string, the name of the Python module declaring this flag. - If not provided, it will be computed using the stack trace of this call. - **args: Dictionary with extra keyword args that are passed to the - Flag __init__. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: A string, the name of the Python module declaring this flag. If + not provided, it will be computed using the stack trace of this call. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. """ - DEFINE_flag(_flag.MultiFlag(parser, serializer, name, default, help, **args), - flag_values, module_name) + return DEFINE_flag( + _flag.MultiFlag(parser, serializer, name, default, help, **args), + flag_values, module_name) def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin - name, default, help, flag_values=_flagvalues.FLAGS, **args): + name, + default, + help, + flag_values=_flagvalues.FLAGS, + **args): # type: (...) -> _flagvalues.FlagHolder[List[Text]] """Registers a flag whose value can be a list of any strings. Use the flag on the command line multiple times to place multiple @@ -450,21 +549,30 @@ def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin Args: name: str, the flag name. default: Union[Iterable[Text], Text, None], the default value of the flag; - see `DEFINE_multi`. + see `DEFINE_multi`. help: str, the help message. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. - **args: Dictionary with extra keyword args that are passed to the - Flag __init__. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. """ parser = _argument_parser.ArgumentParser() serializer = _argument_parser.ArgumentSerializer() - DEFINE_multi(parser, serializer, name, default, help, flag_values, **args) + return DEFINE_multi(parser, serializer, name, default, help, flag_values, + **args) def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin - name, default, help, lower_bound=None, upper_bound=None, - flag_values=_flagvalues.FLAGS, **args): + name, + default, + help, + lower_bound=None, + upper_bound=None, + flag_values=_flagvalues.FLAGS, + **args): # type: (...) -> _flagvalues.FlagHolder[List[int]] """Registers a flag whose value can be a list of arbitrary integers. Use the flag on the command line multiple times to place multiple @@ -475,23 +583,32 @@ def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin Args: name: str, the flag name. default: Union[Iterable[int], Text, None], the default value of the flag; - see `DEFINE_multi`. + see `DEFINE_multi`. help: str, the help message. lower_bound: int, min values of the flag. upper_bound: int, max values of the flag. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. - **args: Dictionary with extra keyword args that are passed to the - Flag __init__. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. """ parser = _argument_parser.IntegerParser(lower_bound, upper_bound) serializer = _argument_parser.ArgumentSerializer() - DEFINE_multi(parser, serializer, name, default, help, flag_values, **args) + return DEFINE_multi(parser, serializer, name, default, help, flag_values, + **args) def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin - name, default, help, lower_bound=None, upper_bound=None, - flag_values=_flagvalues.FLAGS, **args): + name, + default, + help, + lower_bound=None, + upper_bound=None, + flag_values=_flagvalues.FLAGS, + **args): # type: (...) -> _flagvalues.FlagHolder[List[float]] """Registers a flag whose value can be a list of arbitrary floats. Use the flag on the command line multiple times to place multiple @@ -502,23 +619,32 @@ def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin Args: name: str, the flag name. default: Union[Iterable[float], Text, None], the default value of the flag; - see `DEFINE_multi`. + see `DEFINE_multi`. help: str, the help message. lower_bound: float, min values of the flag. upper_bound: float, max values of the flag. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. - **args: Dictionary with extra keyword args that are passed to the - Flag __init__. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. """ parser = _argument_parser.FloatParser(lower_bound, upper_bound) serializer = _argument_parser.ArgumentSerializer() - DEFINE_multi(parser, serializer, name, default, help, flag_values, **args) + return DEFINE_multi(parser, serializer, name, default, help, flag_values, + **args) def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin - name, default, enum_values, help, flag_values=_flagvalues.FLAGS, - case_sensitive=True, **args): + name, + default, + enum_values, + help, + flag_values=_flagvalues.FLAGS, + case_sensitive=True, + **args): # type: (...) -> _flagvalues.FlagHolder[List[Text]] """Registers a flag whose value can be a list strings from enum_values. Use the flag on the command line multiple times to place multiple @@ -529,19 +655,23 @@ def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin Args: name: str, the flag name. default: Union[Iterable[Text], Text, None], the default value of the flag; - see `DEFINE_multi`. + see `DEFINE_multi`. enum_values: [str], a non-empty list of strings with the possible values for - the flag. + the flag. help: str, the help message. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. case_sensitive: Whether or not the enum is to be case-sensitive. - **args: Dictionary with extra keyword args that are passed to the - Flag __init__. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. """ parser = _argument_parser.EnumParser(enum_values, case_sensitive) serializer = _argument_parser.ArgumentSerializer() - DEFINE_multi(parser, serializer, name, default, help, flag_values, **args) + return DEFINE_multi(parser, serializer, name, default, help, flag_values, + **args) def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin @@ -551,7 +681,7 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin help, flag_values=_flagvalues.FLAGS, module_name=None, - **args): + **args): # type: (...) -> _flagvalues.FlagHolder[List[enum.Enum]] """Registers a flag whose value can be a list of enum members. Use the flag on the command line multiple times to place multiple @@ -560,11 +690,10 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin Args: name: str, the flag name. default: Union[Iterable[Enum], Iterable[Text], Enum, Text, None], the - default value of the flag; see - `DEFINE_multi`; only differences are documented here. If the value is - a single Enum, it is treated as a single-item list of that Enum value. - If it is an iterable, text values within the iterable will be converted - to the equivalent Enum objects. + default value of the flag; see `DEFINE_multi`; only differences are + documented here. If the value is a single Enum, it is treated as a + single-item list of that Enum value. If it is an iterable, text values + within the iterable will be converted to the equivalent Enum objects. enum_class: class, the Enum class with all the possible values for the flag. help: str, the help message. flag_values: FlagValues, the FlagValues instance with which the flag will be @@ -573,23 +702,32 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin not provided, it will be computed using the stack trace of this call. **args: Dictionary with extra keyword args that are passed to the Flag __init__. + + Returns: + a handle to defined flag. """ - DEFINE_flag( - _flag.MultiEnumClassFlag(name, default, help, enum_class), - flag_values, module_name, **args) + return DEFINE_flag( + _flag.MultiEnumClassFlag(name, default, help, enum_class), flag_values, + module_name, **args) -def DEFINE_alias(name, original_name, flag_values=_flagvalues.FLAGS, # pylint: disable=invalid-name - module_name=None): +def DEFINE_alias( # pylint: disable=invalid-name + name, + original_name, + flag_values=_flagvalues.FLAGS, + module_name=None): # type: (...) -> _flagvalues.FlagHolder[Any] """Defines an alias flag for an existing one. Args: name: str, the flag name. original_name: str, the original flag name. - flag_values: FlagValues, the FlagValues instance with which the flag will - be registered. This should almost never need to be overridden. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. module_name: A string, the name of the module that defines this flag. + Returns: + a handle to defined flag. + Raises: flags.FlagError: UnrecognizedFlagError: if the referenced flag doesn't exist. @@ -619,6 +757,11 @@ def DEFINE_alias(name, original_name, flag_values=_flagvalues.FLAGS, # pylint: help_msg = 'Alias for --%s.' % flag.name # If alias_name has been used, flags.DuplicatedFlag will be raised. - DEFINE_flag(_FlagAlias(_Parser(), flag.serializer, name, flag.default, - help_msg, boolean=flag.boolean), - flag_values, module_name) + return DEFINE_flag( + _FlagAlias( + _Parser(), + flag.serializer, + name, + flag.default, + help_msg, + boolean=flag.boolean), flag_values, module_name) diff --git a/absl/flags/_exceptions.py b/absl/flags/_exceptions.py index 254eb9b..e95a893 100644 --- a/absl/flags/_exceptions.py +++ b/absl/flags/_exceptions.py @@ -110,3 +110,7 @@ class ValidationError(Error): class FlagNameConflictsWithMethodError(Error): """Raised when a flag name conflicts with FlagValues methods.""" + + +class NoneValueError(Error): + """Raised when a flag unexpectedly has None value.""" diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index 3473cd0..1f93351 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -34,6 +34,14 @@ from absl.flags import _flag from absl.flags import _helpers import six +# pylint: disable=unused-import +try: + import typing + from typing import Text, Optional +except ImportError: + typing = None +# pylint: enable=unused-import + # Add flagvalues module to disclaimed module ids. _helpers.disclaim_module_ids.add(id(sys.modules[__name__])) @@ -1277,3 +1285,80 @@ class FlagValues(object): FLAGS = FlagValues() + +if typing: + _T = typing.TypeVar('_T') + _Base = typing.Generic[_T] +else: + _Base = object + + +class FlagHolder(_Base): + """Holds a defined flag. + + This facilitates a cleaner api around global state. Instead of + + ``` + flags.DEFINE_integer('foo', ...) + flags.DEFINE_integer('bar', ...) + ... + def method(): + # prints parsed value of 'bar' flag + print(flags.FLAGS.foo) + # runtime error due to typo or possibly bad coding style. + print(flags.FLAGS.baz) + ``` + + it encourages code like + + ``` + FOO_FLAG = flags.DEFINE_integer('foo', ...) + BAR_FLAG = flags.DEFINE_integer('bar', ...) + ... + def method(): + print(FOO_FLAG.value) + print(BAR_FLAG.value) + ``` + + since the name of the flag appears only once in the source code. + """ + + def __init__(self, flag_values, name): + # type: (FlagValues, Text) -> None + self._flagvalues = flag_values + self._name = name + + @property + def name(self): + # type: () -> Text + return self._name + + @property + def value(self): + # type: () -> Optional[T] + """Returns the value of the flag. + + Raises: + UnparsedFlagAccessError: if flag parsing has not finished. + """ + return getattr(self._flagvalues, self._name) + + @property + def non_none_value(self): + # type: () -> T + """Returns the value of the flag after checking it is not None. + + Raises: + UnparsedFlagAccessError: if flag parsing has not finished. + NoneValueError: if flag had a None value. + """ + value = self.value + if value is None: + raise _exceptions.NoneValueError('Flag %s is set to None' % self._name) + return value + + @property + def default(self): + # type: () -> Optional[T] + """Returns the default value of the flag.""" + return self._flagvalues[self._name].default diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py index c3d9549..978b8a5 100644 --- a/absl/flags/tests/_flagvalues_test.py +++ b/absl/flags/tests/_flagvalues_test.py @@ -835,5 +835,52 @@ class UnparsedFlagAccessTest(absltest.TestCase): _ = fv.a_str +class FlagHolderTest(absltest.TestCase): + + def setUp(self): + super(FlagHolderTest, self).setUp() + self.fv = _flagvalues.FlagValues() + self.name_flag = _defines.DEFINE_string( + 'name', 'default', 'help', flag_values=self.fv) + + def parse_flags(self, *argv): + self.fv.unparse_flags() + self.fv(['binary_name'] + list(argv)) + + def test_name(self): + self.assertEqual('name', self.name_flag.name) + + def test_value_before_flag_parsing(self): + with self.assertRaises(_exceptions.UnparsedFlagAccessError): + _ = self.name_flag.value + + def test_value_returns_default_value_if_not_explicitly_set(self): + self.parse_flags() + self.assertEqual('default', self.name_flag.value) + + def test_value_returns_explicitly_set_value(self): + self.parse_flags('--name=new_value') + self.assertEqual('new_value', self.name_flag.value) + + def test_non_none_value_fails_if_value_is_none(self): + self.parse_flags() + self.fv.name = None + with self.assertRaises(_exceptions.NoneValueError): + _ = self.name_flag.non_none_value + + def test_non_none_value(self): + self.parse_flags('--name=default') + self.assertEqual('default', self.name_flag.non_none_value) + + def test_allow_override(self): + first = _defines.DEFINE_integer( + 'int_flag', 1, 'help', flag_values=self.fv, allow_override=1) + second = _defines.DEFINE_integer( + 'int_flag', 2, 'help', flag_values=self.fv, allow_override=1) + self.parse_flags('--int_flag=3') + self.assertEqual(3, first.value) + self.assertEqual(3, second.value) + + if __name__ == '__main__': absltest.main() |