diff options
Diffstat (limited to 'absl')
-rw-r--r-- | absl/CHANGELOG.md | 21 | ||||
-rw-r--r-- | absl/app.py | 11 | ||||
-rw-r--r-- | absl/app.pyi | 99 | ||||
-rw-r--r-- | absl/flags/BUILD | 13 | ||||
-rw-r--r-- | absl/flags/_argument_parser.pyi | 3 | ||||
-rw-r--r-- | absl/flags/_flag.pyi | 2 | ||||
-rw-r--r-- | absl/flags/_flagvalues.py | 17 | ||||
-rw-r--r-- | absl/flags/_validators.py | 156 | ||||
-rw-r--r-- | absl/flags/_validators_classes.py | 176 | ||||
-rw-r--r-- | absl/flags/argparse_flags.py | 1 | ||||
-rw-r--r-- | absl/flags/tests/_validators_test.py | 142 | ||||
-rw-r--r-- | absl/logging/__init__.py | 6 | ||||
-rw-r--r-- | absl/testing/absltest.py | 147 | ||||
-rw-r--r-- | absl/testing/tests/absltest_test.py | 239 |
14 files changed, 816 insertions, 217 deletions
diff --git a/absl/CHANGELOG.md b/absl/CHANGELOG.md index b35c612..aff590d 100644 --- a/absl/CHANGELOG.md +++ b/absl/CHANGELOG.md @@ -8,6 +8,23 @@ The format is based on [Keep a Changelog](https://keepachangelog.com). Nothing notable unreleased. +## 0.13.0 (2021-06-14) + +### Added + +* (app) Type annotations for public `app` interfaces. +* (testing) Added new decorator `@absltest.skipThisClass` to indicate a class + contains shared functionality to be used as a base class for other + TestCases, and therefore should be skipped. + +### Changed + +* (app) Annotated the `flag_parser` paramteter of `run` as keyword-only. This + keyword-only constraint will be enforced at runtime in a future release. +* (app, flags) Flag validations now include all errors from disjoint flag + sets, instead of fail fast upon first error from all validators. Multiple + validators on the same flag still fails fast. + ## 0.12.0 (2021-03-08) ### Added @@ -62,8 +79,8 @@ Nothing notable unreleased. * (testing) Failed tests output a copy/pastable test id to make it easier to copy the failing test to the command line. -* (testing) `@parameterized.parameters` now treats a single `abc.Mapping` as - a single test case, consistent with `named_parameters`. Previously the +* (testing) `@parameterized.parameters` now treats a single `abc.Mapping` as a + single test case, consistent with `named_parameters`. Previously the `abc.Mapping` is treated as if only its keys are passed as a list of test cases. If you were relying on the old inconsistent behavior, explicitly convert the `abc.Mapping` to a `list`. diff --git a/absl/app.py b/absl/app.py index 6b5e91d..fbdfccd 100644 --- a/absl/app.py +++ b/absl/app.py @@ -34,6 +34,7 @@ import errno import os import pdb import sys +import textwrap import traceback from absl import command_name @@ -158,7 +159,13 @@ def parse_flags_with_usage(args): try: return FLAGS(args) except flags.Error as error: - sys.stderr.write('FATAL Flags parsing error: %s\n' % error) + message = str(error) + if '\n' in message: + final_message = 'FATAL Flags parsing error:\n%s\n' % textwrap.indent( + message, ' ') + else: + final_message = 'FATAL Flags parsing error: %s\n' % message + sys.stderr.write(final_message) sys.stderr.write('Pass --helpshort or --helpfull to see help on flags.\n') sys.exit(1) @@ -286,6 +293,8 @@ def run( flags_parser: Callable[[List[Text]], Any], the function used to parse flags. The return value of this function is passed to `main` untouched. It must guarantee FLAGS is parsed after this function is called. + Should be passed as a keyword-only arg which will become mandatory in a + future release. - Parses command line flags with the flag module. - If there are any errors, prints usage(). - Calls main() with the remaining arguments. diff --git a/absl/app.pyi b/absl/app.pyi new file mode 100644 index 0000000..fe5e448 --- /dev/null +++ b/absl/app.pyi @@ -0,0 +1,99 @@ + +from typing import Any, Callable, Collection, Iterable, List, NoReturn, Optional, Text, TypeVar, Union, overload + +from absl.flags import _flag + + +_MainArgs = TypeVar('_MainArgs') +_Exc = TypeVar('_Exc', bound=Exception) + + +class ExceptionHandler(): + + def wants(self, exc: _Exc) -> bool: + ... + + def handle(self, exc: _Exc): + ... + + +EXCEPTION_HANDLERS: List[ExceptionHandler] = ... + + +class HelpFlag(_flag.BooleanFlag): + def __init__(self): + ... + + +class HelpshortFlag(HelpFlag): + ... + + +class HelpfullFlag(_flag.BooleanFlag): + def __init__(self): + ... + + +class HelpXMLFlag(_flag.BooleanFlag): + def __init__(self): + ... + + +def define_help_flags() -> None: + ... + + +@overload +def usage(shorthelp: Union[bool, int] = ..., + writeto_stdout: Union[bool, int] = ..., + detailed_error: Optional[Any] = ..., + exitcode: None = ...) -> None: + ... + + +@overload +def usage(shorthelp: Union[bool, int] = ..., + writeto_stdout: Union[bool, int] = ..., + detailed_error: Optional[Any] = ..., + exitcode: int = ...) -> NoReturn: + ... + + +def install_exception_handler(handler: ExceptionHandler) -> None: + ... + + +class Error(Exception): + ... + + +class UsageError(Error): + exitcode: int + + +def parse_flags_with_usage(args: List[Text]) -> List[Text]: + ... + + +def call_after_init(callback: Callable[[], Any]) -> None: + ... + + +# Without the flag_parser argument, `main` should require a List[Text]. +@overload +def run( + main: Callable[[List[Text]], Any], + argv: Optional[List[Text]] = ..., + *, +) -> NoReturn: + ... + + +@overload +def run( + main: Callable[[_MainArgs], Any], + argv: Optional[List[Text]] = ..., + *, + flags_parser: Callable[[List[Text]], _MainArgs], +) -> NoReturn: + ... diff --git a/absl/flags/BUILD b/absl/flags/BUILD index 50bdb00..340ae93 100644 --- a/absl/flags/BUILD +++ b/absl/flags/BUILD @@ -83,6 +83,7 @@ py_library( ":_exceptions", ":_flag", ":_helpers", + ":_validators_classes", "@six_archive//:six", ], ) @@ -103,6 +104,18 @@ py_library( deps = [ ":_exceptions", ":_flagvalues", + ":_validators_classes", + ], +) + +py_library( + name = "_validators_classes", + srcs = [ + "_validators_classes.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":_exceptions", ], ) diff --git a/absl/flags/_argument_parser.pyi b/absl/flags/_argument_parser.pyi index db5f41b..7e78d7d 100644 --- a/absl/flags/_argument_parser.pyi +++ b/absl/flags/_argument_parser.pyi @@ -30,6 +30,9 @@ class ArgumentSerializer(Generic[_T]): # The metaclass of ArgumentParser is not reflected here, because it does not # affect the provided API. class ArgumentParser(Generic[_T]): + + syntactic_help: Text + def parse(self, argument: Text) -> Optional[_T]: ... def flag_type(self) -> Text: ... diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi index 2e7400b..f28bbf3 100644 --- a/absl/flags/_flag.pyi +++ b/absl/flags/_flag.pyi @@ -32,6 +32,8 @@ class Flag(Generic[_T]): name = ... # type: Text default = ... # type: Any + default_unparsed = ... # type: Any + default_as_str = ... # type: Optional[Text] help = ... # type: Text short_name = ... # type: Text boolean = ... # type: bool diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index 2a13927..8b9d662 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -31,6 +31,7 @@ from xml.dom import minidom from absl.flags import _exceptions from absl.flags import _flag from absl.flags import _helpers +from absl.flags import _validators_classes import six # pylint: disable=unused-import @@ -544,13 +545,27 @@ class FlagValues(object): IllegalFlagValueError: Raised if validation fails for at least one validator. """ + messages = [] + bad_flags = set() for validator in sorted( validators, key=lambda validator: validator.insertion_index): try: + if isinstance(validator, _validators_classes.SingleFlagValidator): + if validator.flag_name in bad_flags: + continue + elif isinstance(validator, _validators_classes.MultiFlagsValidator): + if bad_flags & set(validator.flag_names): + continue validator.verify(self) except _exceptions.ValidationError as e: + if isinstance(validator, _validators_classes.SingleFlagValidator): + bad_flags.add(validator.flag_name) + elif isinstance(validator, _validators_classes.MultiFlagsValidator): + bad_flags.update(set(validator.flag_names)) message = validator.print_flags_with_values(self) - raise _exceptions.IllegalFlagValueError('%s: %s' % (message, str(e))) + messages.append('%s: %s' % (message, str(e))) + if messages: + raise _exceptions.IllegalFlagValueError('\n'.join(messages)) def __delattr__(self, flag_name): """Deletes a previously-defined flag from a flag object. diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py index fdb6ea2..47cc430 100644 --- a/absl/flags/_validators.py +++ b/absl/flags/_validators.py @@ -40,157 +40,7 @@ import warnings from absl.flags import _exceptions from absl.flags import _flagvalues - - -class Validator(object): - """Base class for flags validators. - - Users should NOT overload these classes, and use flags.Register... - methods instead. - """ - - # Used to assign each validator an unique insertion_index - validators_count = 0 - - def __init__(self, checker, message): - """Constructor to create all validators. - - Args: - checker: function to verify the constraint. - Input of this method varies, see SingleFlagValidator and - multi_flags_validator for a detailed description. - message: str, error message to be shown to the user. - """ - self.checker = checker - self.message = message - Validator.validators_count += 1 - # Used to assert validators in the order they were registered. - self.insertion_index = Validator.validators_count - - def verify(self, flag_values): - """Verifies that constraint is satisfied. - - flags library calls this method to verify Validator's constraint. - - Args: - flag_values: flags.FlagValues, the FlagValues instance to get flags from. - Raises: - Error: Raised if constraint is not satisfied. - """ - param = self._get_input_to_checker_function(flag_values) - if not self.checker(param): - raise _exceptions.ValidationError(self.message) - - def get_flags_names(self): - """Returns the names of the flags checked by this validator. - - Returns: - [string], names of the flags. - """ - raise NotImplementedError('This method should be overloaded') - - def print_flags_with_values(self, flag_values): - raise NotImplementedError('This method should be overloaded') - - def _get_input_to_checker_function(self, flag_values): - """Given flag values, returns the input to be given to checker. - - Args: - flag_values: flags.FlagValues, containing all flags. - Returns: - The input to be given to checker. The return type depends on the specific - validator. - """ - raise NotImplementedError('This method should be overloaded') - - -class SingleFlagValidator(Validator): - """Validator behind register_validator() method. - - Validates that a single flag passes its checker function. The checker function - takes the flag value and returns True (if value looks fine) or, if flag value - is not valid, either returns False or raises an Exception. - """ - - def __init__(self, flag_name, checker, message): - """Constructor. - - Args: - flag_name: string, name of the flag. - checker: function to verify the validator. - input - value of the corresponding flag (string, boolean, etc). - output - bool, True if validator constraint is satisfied. - If constraint is not satisfied, it should either return False or - raise flags.ValidationError(desired_error_message). - message: str, error message to be shown to the user if validator's - condition is not satisfied. - """ - super(SingleFlagValidator, self).__init__(checker, message) - self.flag_name = flag_name - - def get_flags_names(self): - return [self.flag_name] - - def print_flags_with_values(self, flag_values): - return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value) - - def _get_input_to_checker_function(self, flag_values): - """Given flag values, returns the input to be given to checker. - - Args: - flag_values: flags.FlagValues, the FlagValues instance to get flags from. - Returns: - object, the input to be given to checker. - """ - return flag_values[self.flag_name].value - - -class MultiFlagsValidator(Validator): - """Validator behind register_multi_flags_validator method. - - Validates that flag values pass their common checker function. The checker - function takes flag values and returns True (if values look fine) or, - if values are not valid, either returns False or raises an Exception. - """ - - def __init__(self, flag_names, checker, message): - """Constructor. - - Args: - flag_names: [str], containing names of the flags used by checker. - checker: function to verify the validator. - input - dict, with keys() being flag_names, and value for each - key being the value of the corresponding flag (string, boolean, - etc). - output - bool, True if validator constraint is satisfied. - If constraint is not satisfied, it should either return False or - raise flags.ValidationError(desired_error_message). - message: str, error message to be shown to the user if validator's - condition is not satisfied - """ - super(MultiFlagsValidator, self).__init__(checker, message) - self.flag_names = flag_names - - def _get_input_to_checker_function(self, flag_values): - """Given flag values, returns the input to be given to checker. - - Args: - flag_values: flags.FlagValues, the FlagValues instance to get flags from. - Returns: - dict, with keys() being self.lag_names, and value for each key - being the value of the corresponding flag (string, boolean, etc). - """ - return dict([key, flag_values[key].value] for key in self.flag_names) - - def print_flags_with_values(self, flag_values): - prefix = 'flags ' - flags_with_values = [] - for key in self.flag_names: - flags_with_values.append('%s=%s' % (key, flag_values[key].value)) - return prefix + ', '.join(flags_with_values) - - def get_flags_names(self): - return self.flag_names +from absl.flags import _validators_classes def register_validator(flag_name, @@ -219,7 +69,7 @@ def register_validator(flag_name, AttributeError: Raised when flag_name is not registered as a valid flag name. """ - v = SingleFlagValidator(flag_name, checker, message) + v = _validators_classes.SingleFlagValidator(flag_name, checker, message) _add_validator(flag_values, v) @@ -283,7 +133,7 @@ def register_multi_flags_validator(flag_names, Raises: AttributeError: Raised when a flag is not registered as a valid flag name. """ - v = MultiFlagsValidator( + v = _validators_classes.MultiFlagsValidator( flag_names, multi_flags_checker, message) _add_validator(flag_values, v) diff --git a/absl/flags/_validators_classes.py b/absl/flags/_validators_classes.py new file mode 100644 index 0000000..d8996e0 --- /dev/null +++ b/absl/flags/_validators_classes.py @@ -0,0 +1,176 @@ +# Copyright 2021 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines *private* classes used for flag validators. + +Do NOT import this module. DO NOT use anything from this module. They are +private APIs. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.flags import _exceptions + + +class Validator(object): + """Base class for flags validators. + + Users should NOT overload these classes, and use flags.Register... + methods instead. + """ + + # Used to assign each validator an unique insertion_index + validators_count = 0 + + def __init__(self, checker, message): + """Constructor to create all validators. + + Args: + checker: function to verify the constraint. + Input of this method varies, see SingleFlagValidator and + multi_flags_validator for a detailed description. + message: str, error message to be shown to the user. + """ + self.checker = checker + self.message = message + Validator.validators_count += 1 + # Used to assert validators in the order they were registered. + self.insertion_index = Validator.validators_count + + def verify(self, flag_values): + """Verifies that constraint is satisfied. + + flags library calls this method to verify Validator's constraint. + + Args: + flag_values: flags.FlagValues, the FlagValues instance to get flags from. + Raises: + Error: Raised if constraint is not satisfied. + """ + param = self._get_input_to_checker_function(flag_values) + if not self.checker(param): + raise _exceptions.ValidationError(self.message) + + def get_flags_names(self): + """Returns the names of the flags checked by this validator. + + Returns: + [string], names of the flags. + """ + raise NotImplementedError('This method should be overloaded') + + def print_flags_with_values(self, flag_values): + raise NotImplementedError('This method should be overloaded') + + def _get_input_to_checker_function(self, flag_values): + """Given flag values, returns the input to be given to checker. + + Args: + flag_values: flags.FlagValues, containing all flags. + Returns: + The input to be given to checker. The return type depends on the specific + validator. + """ + raise NotImplementedError('This method should be overloaded') + + +class SingleFlagValidator(Validator): + """Validator behind register_validator() method. + + Validates that a single flag passes its checker function. The checker function + takes the flag value and returns True (if value looks fine) or, if flag value + is not valid, either returns False or raises an Exception. + """ + + def __init__(self, flag_name, checker, message): + """Constructor. + + Args: + flag_name: string, name of the flag. + checker: function to verify the validator. + input - value of the corresponding flag (string, boolean, etc). + output - bool, True if validator constraint is satisfied. + If constraint is not satisfied, it should either return False or + raise flags.ValidationError(desired_error_message). + message: str, error message to be shown to the user if validator's + condition is not satisfied. + """ + super(SingleFlagValidator, self).__init__(checker, message) + self.flag_name = flag_name + + def get_flags_names(self): + return [self.flag_name] + + def print_flags_with_values(self, flag_values): + return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value) + + def _get_input_to_checker_function(self, flag_values): + """Given flag values, returns the input to be given to checker. + + Args: + flag_values: flags.FlagValues, the FlagValues instance to get flags from. + Returns: + object, the input to be given to checker. + """ + return flag_values[self.flag_name].value + + +class MultiFlagsValidator(Validator): + """Validator behind register_multi_flags_validator method. + + Validates that flag values pass their common checker function. The checker + function takes flag values and returns True (if values look fine) or, + if values are not valid, either returns False or raises an Exception. + """ + + def __init__(self, flag_names, checker, message): + """Constructor. + + Args: + flag_names: [str], containing names of the flags used by checker. + checker: function to verify the validator. + input - dict, with keys() being flag_names, and value for each + key being the value of the corresponding flag (string, boolean, + etc). + output - bool, True if validator constraint is satisfied. + If constraint is not satisfied, it should either return False or + raise flags.ValidationError(desired_error_message). + message: str, error message to be shown to the user if validator's + condition is not satisfied + """ + super(MultiFlagsValidator, self).__init__(checker, message) + self.flag_names = flag_names + + def _get_input_to_checker_function(self, flag_values): + """Given flag values, returns the input to be given to checker. + + Args: + flag_values: flags.FlagValues, the FlagValues instance to get flags from. + Returns: + dict, with keys() being self.lag_names, and value for each key + being the value of the corresponding flag (string, boolean, etc). + """ + return dict([key, flag_values[key].value] for key in self.flag_names) + + def print_flags_with_values(self, flag_values): + prefix = 'flags ' + flags_with_values = [] + for key in self.flag_names: + flags_with_values.append('%s=%s' % (key, flag_values[key].value)) + return prefix + ', '.join(flags_with_values) + + def get_flags_names(self): + return self.flag_names diff --git a/absl/flags/argparse_flags.py b/absl/flags/argparse_flags.py index bcd40f5..ea4cdfa 100644 --- a/absl/flags/argparse_flags.py +++ b/absl/flags/argparse_flags.py @@ -18,7 +18,6 @@ argparse_flags.ArgumentParser is a drop-in replacement for argparse.ArgumentParser. It takes care of collecting and defining absl flags in argparse. - Here is a simple example: # Assume the following absl.flags is defined in another module: diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py index a5dec45..f724813 100644 --- a/absl/flags/tests/_validators_test.py +++ b/absl/flags/tests/_validators_test.py @@ -54,7 +54,7 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.test_flag) + self.assertIsNone(self.flag_values.test_flag) self.flag_values.test_flag = 2 self.assertEqual(2, self.flag_values.test_flag) self.assertEqual([None, 2], self.call_args) @@ -110,11 +110,9 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program', '--test_flag=1') self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.test_flag = 2 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=2: Errors happen', str(e)) + self.assertEqual('flag --test_flag=2: Errors happen', str(cm.exception)) self.assertEqual([1, 2], self.call_args) def test_exception_raised_if_checker_raises_exception(self): @@ -134,11 +132,9 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program', '--test_flag=1') self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.test_flag = 2 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=2: Specific message', str(e)) + self.assertEqual('flag --test_flag=2: Specific message', str(cm.exception)) self.assertEqual([1, 2], self.call_args) def test_error_message_when_checker_returns_false_on_start(self): @@ -154,11 +150,9 @@ class SingleFlagValidatorTest(absltest.TestCase): flag_values=self.flag_values) argv = ('./program', '--test_flag=1') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values(argv) - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=1: Errors happen', str(e)) + self.assertEqual('flag --test_flag=1: Errors happen', str(cm.exception)) self.assertEqual([1], self.call_args) def test_error_message_when_checker_raises_exception_on_start(self): @@ -175,11 +169,9 @@ class SingleFlagValidatorTest(absltest.TestCase): flag_values=self.flag_values) argv = ('./program', '--test_flag=1') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values(argv) - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flag --test_flag=1: Specific message', str(e)) + self.assertEqual('flag --test_flag=1: Specific message', str(cm.exception)) self.assertEqual([1], self.call_args) def test_validators_checked_in_order(self): @@ -222,7 +214,7 @@ class SingleFlagValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.test_flag) + self.assertIsNone(self.flag_values.test_flag) self.flag_values.test_flag = 2 self.assertEqual(2, self.flag_values.test_flag) self.assertEqual([None, 2], self.call_args) @@ -288,11 +280,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.bar = 1 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flags foo=1, bar=1: Errors happen', str(e)) + self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception)) self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) @@ -313,11 +303,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.bar = 1 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flags foo=1, bar=1: Specific message', str(e)) + self.assertEqual('flags foo=1, bar=1: Specific message', str(cm.exception)) self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) @@ -332,11 +320,9 @@ class MultiFlagsValidatorTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.bar = 1 - raise AssertionError('IllegalFlagValueError expected') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual('flags foo=1, bar=1: Errors happen', str(e)) + self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception)) self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) @@ -373,8 +359,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.flag_one) - self.assertEqual(None, self.flag_values.flag_two) + self.assertIsNone(self.flag_values.flag_one) + self.assertIsNone(self.flag_values.flag_two) def test_no_flags_present_required(self): self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True) @@ -452,8 +438,8 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): ['multi_flag_one', 'multi_flag_two'], False) argv = ('./program',) self.flag_values(argv) - self.assertEqual(None, self.flag_values.multi_flag_one) - self.assertEqual(None, self.flag_values.multi_flag_two) + self.assertIsNone(self.flag_values.multi_flag_one) + self.assertIsNone(self.flag_values.multi_flag_two) def test_no_multistring_flags_present_required(self): self._mark_flags_as_mutually_exclusive( @@ -558,8 +544,8 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False) argv = ('./program', '--false_1', '--false_2') expected = ( - 'flags false_1=True, false_2=True: At most one of (false_1, false_2) ' - 'must be True.') + 'flags false_1=True, false_2=True: At most one of (false_1, ' + 'false_2) must be True.') self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, expected, self.flag_values, argv) @@ -610,11 +596,9 @@ class MarkFlagAsRequiredTest(absltest.TestCase): self.assertEqual('value', self.flag_values.string_flag) expected = ('flag --string_flag=None: Flag --string_flag must have a value ' 'other than None.') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.string_flag = None - raise AssertionError('Failed to detect non-set required flag.') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual(expected, str(e)) + self.assertEqual(expected, str(cm.exception)) def test_flag_default_not_none_warning(self): _defines.DEFINE_string( @@ -680,11 +664,81 @@ class MarkFlagsAsRequiredTest(absltest.TestCase): expected = ( 'flag --string_flag_1=None: Flag --string_flag_1 must have a value ' 'other than None.') - try: + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: self.flag_values.string_flag_1 = None - raise AssertionError('Failed to detect non-set required flag.') - except _exceptions.IllegalFlagValueError as e: - self.assertEqual(expected, str(e)) + self.assertEqual(expected, str(cm.exception)) + + def test_catch_multiple_flags_as_none_at_program_start(self): + _defines.DEFINE_float( + 'float_flag_1', + None, + 'string flag 1', + flag_values=self.flag_values) + _defines.DEFINE_float( + 'float_flag_2', + None, + 'string flag 2', + flag_values=self.flag_values) + _validators.mark_flags_as_required( + ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values) + argv = ('./program', '') + expected = ( + 'flag --float_flag_1=None: Flag --float_flag_1 must have a value ' + 'other than None.\n' + 'flag --float_flag_2=None: Flag --float_flag_2 must have a value ' + 'other than None.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + + def test_fail_fast_single_flag_and_skip_remaining_validators(self): + def raise_unexpected_error(x): + del x + raise _exceptions.ValidationError('Should not be raised.') + _defines.DEFINE_float( + 'flag_1', None, 'flag 1', flag_values=self.flag_values) + _defines.DEFINE_float( + 'flag_2', 4.2, 'flag 2', flag_values=self.flag_values) + _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values) + _validators.register_validator( + 'flag_1', raise_unexpected_error, flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_2', 'flag_1'], + raise_unexpected_error, + flag_values=self.flag_values) + argv = ('./program', '') + expected = ( + 'flag --flag_1=None: Flag --flag_1 must have a value other than None.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + + def test_fail_fast_multi_flag_and_skip_remaining_validators(self): + def raise_expected_error(x): + del x + raise _exceptions.ValidationError('Expected error.') + def raise_unexpected_error(x): + del x + raise _exceptions.ValidationError('Got unexpected error.') + _defines.DEFINE_float( + 'flag_1', 5.1, 'flag 1', flag_values=self.flag_values) + _defines.DEFINE_float( + 'flag_2', 10.0, 'flag 2', flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_1', 'flag_2'], + raise_expected_error, + flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_2', 'flag_1'], + raise_unexpected_error, + flag_values=self.flag_values) + _validators.register_validator( + 'flag_1', raise_unexpected_error, flag_values=self.flag_values) + _validators.register_validator( + 'flag_2', raise_unexpected_error, flag_values=self.flag_values) + argv = ('./program', '') + expected = ('flags flag_1=5.1, flag_2=10.0: Expected error.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + if __name__ == '__main__': absltest.main() diff --git a/absl/logging/__init__.py b/absl/logging/__init__.py index 0b1814d..0165923 100644 --- a/absl/logging/__init__.py +++ b/absl/logging/__init__.py @@ -100,6 +100,11 @@ if six.PY2: else: import threading as _thread_lib # For .get_ident(). +try: + from typing import NoReturn +except ImportError: + pass + FLAGS = flags.FLAGS @@ -379,6 +384,7 @@ def set_stderrthreshold(s): def fatal(msg, *args, **kwargs): + # type: (Any, Any, Any) -> NoReturn """Logs a fatal message.""" log(FATAL, msg, *args, **kwargs) diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py index fbce512..0708875 100644 --- a/absl/testing/absltest.py +++ b/absl/testing/absltest.py @@ -560,6 +560,39 @@ class _TempFile(object): yield fp +class _method(object): + """A decorator that supports both instance and classmethod invocations. + + Using similar semantics to the @property builtin, this decorator can augment + an instance method to support conditional logic when invoked on a class + object. This breaks support for invoking an instance method via the class + (e.g. Cls.method(self, ...)) but is still situationally useful. + """ + + def __init__(self, finstancemethod): + # type: (Callable[..., Any]) -> None + self._finstancemethod = finstancemethod + self._fclassmethod = None + + def classmethod(self, fclassmethod): + # type: (Callable[..., Any]) -> _method + self._fclassmethod = classmethod(fclassmethod) + return self + + def __doc__(self): + # type: () -> str + if getattr(self._finstancemethod, '__doc__'): + return self._finstancemethod.__doc__ + elif getattr(self._fclassmethod, '__doc__'): + return self._fclassmethod.__doc__ + return '' + + def __get__(self, obj, type_): + # type: (Optional[Any], Optional[Type[Any]]) -> Callable[..., Any] + func = self._fclassmethod if obj is None else self._finstancemethod + return func.__get__(obj, type_) # pytype: disable=attribute-error + + class TestCase(unittest3_backport.TestCase): """Extension of unittest.TestCase providing more power.""" @@ -576,20 +609,31 @@ class TestCase(unittest3_backport.TestCase): maxDiff = 80 * 20 longMessage = True + # Exit stacks for per-test and per-class scopes. + _exit_stack = None + _cls_exit_stack = None + def __init__(self, *args, **kwargs): super(TestCase, self).__init__(*args, **kwargs) # This is to work around missing type stubs in unittest.pyi self._outcome = getattr(self, '_outcome') # type: Optional[_OutcomeType] - # This is re-initialized by setUp(). - self._exit_stack = None def setUp(self): super(TestCase, self).setUp() - # NOTE: Only Py3 contextlib has ExitStack + # NOTE: Only Python 3 contextlib has ExitStack if hasattr(contextlib, 'ExitStack'): self._exit_stack = contextlib.ExitStack() self.addCleanup(self._exit_stack.close) + @classmethod + def setUpClass(cls): + super(TestCase, cls).setUpClass() + # NOTE: Only Python 3 contextlib has ExitStack and only Python 3.8+ has + # addClassCleanup. + if hasattr(contextlib, 'ExitStack') and hasattr(cls, 'addClassCleanup'): + cls._cls_exit_stack = contextlib.ExitStack() + cls.addClassCleanup(cls._cls_exit_stack.close) + def create_tempdir(self, name=None, cleanup=None): # type: (Optional[Text], Optional[TempFileCleanup]) -> _TempDir """Create a temporary directory specific to the test. @@ -700,14 +744,19 @@ class TestCase(unittest3_backport.TestCase): self._maybe_add_temp_path_cleanup(cleanup_path, cleanup) return tf + @_method def enter_context(self, manager): # type: (ContextManager[_T]) -> _T """Returns the CM's value after registering it with the exit stack. - Entering a context pushes it onto a stack of contexts. The context is exited - when the test completes. Contexts are are exited in the reverse order of - entering. They will always be exited, regardless of test failure/success. - The context stack is specific to the test being run. + Entering a context pushes it onto a stack of contexts. When `enter_context` + is called on the test instance (e.g. `self.enter_context`), the context is + exited after the test case's tearDown call. When called on the test class + (e.g. `TestCase.enter_context`), the context is exited after the test + class's tearDownClass call. + + Contexts are are exited in the reverse order of entering. They will always + be exited, regardless of test failure/success. This is useful to eliminate per-test boilerplate when context managers are used. For example, instead of decorating every test with `@mock.patch`, @@ -726,6 +775,15 @@ class TestCase(unittest3_backport.TestCase): 'sure that AbslTest.setUp() is called.') return self._exit_stack.enter_context(manager) + @enter_context.classmethod + def enter_context(cls, manager): # pylint: disable=no-self-argument + # type: (ContextManager[_T]) -> _T + if not cls._cls_exit_stack: + raise AssertionError( + 'cls._cls_exit_stack is not set: cls.enter_context requires ' + 'Python 3.8+; also make sure that AbslTest.setUpClass() is called.') + return cls._cls_exit_stack.enter_context(manager) + @classmethod def _get_tempdir_path_cls(cls): # type: () -> Text @@ -2142,6 +2200,81 @@ def _is_suspicious_attribute(testCaseClass, name): return False +def skipThisClass(reason): + # type: (Text) -> Callable[[_T], _T] + """Skip tests in the decorated TestCase, but not any of its subclasses. + + This decorator indicates that this class should skip all its tests, but not + any of its subclasses. Useful for if you want to share testMethod or setUp + implementations between a number of concrete testcase classes. + + Example usage, showing how you can share some common test methods between + subclasses. In this example, only 'BaseTest' will be marked as skipped, and + not RealTest or SecondRealTest: + + @absltest.skipThisClass("Shared functionality") + class BaseTest(absltest.TestCase): + def test_simple_functionality(self): + self.assertEqual(self.system_under_test.method(), 1) + + class RealTest(BaseTest): + def setUp(self): + super().setUp() + self.system_under_test = MakeSystem(argument) + + def test_specific_behavior(self): + ... + + class SecondRealTest(BaseTest): + def setUp(self): + super().setUp() + self.system_under_test = MakeSystem(other_arguments) + + def test_other_behavior(self): + ... + + Args: + reason: The reason we have a skip in place. For instance: 'shared test + methods' or 'shared assertion methods'. + + Returns: + Decorator function that will cause a class to be skipped. + """ + if isinstance(reason, type): + raise TypeError('Got {!r}, expected reason as string'.format(reason)) + + def _skip_class(test_case_class): + if not issubclass(test_case_class, unittest.TestCase): + raise TypeError( + 'Decorating {!r}, expected TestCase subclass'.format(test_case_class)) + + # Only shadow the setUpClass method if it is directly defined. If it is + # in the parent class we invoke it via a super() call instead of holding + # a reference to it. + shadowed_setupclass = test_case_class.__dict__.get('setUpClass', None) + + @classmethod + def replacement_setupclass(cls, *args, **kwargs): + # Skip this class if it is the one that was decorated with @skipThisClass + if cls is test_case_class: + raise SkipTest(reason) + if shadowed_setupclass: + # Pass along `cls` so the MRO chain doesn't break. + # The original method is a `classmethod` descriptor, which can't + # be directly called, but `__func__` has the underlying function. + return shadowed_setupclass.__func__(cls, *args, **kwargs) + else: + # Because there's no setUpClass() defined directly on test_case_class, + # we call super() ourselves to continue execution of the inheritance + # chain. + return super(test_case_class, cls).setUpClass(*args, **kwargs) + + test_case_class.setUpClass = replacement_setupclass + return test_case_class + + return _skip_class + + class TestLoader(unittest.TestLoader): """A test loader which supports common test features. diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py index f7abd39..a54ff55 100644 --- a/absl/testing/tests/absltest_test.py +++ b/absl/testing/tests/absltest_test.py @@ -28,6 +28,7 @@ import string import subprocess import sys import tempfile +import unittest from absl.testing import _bazelize_command from absl.testing import absltest @@ -1486,6 +1487,15 @@ class GetCommandStderrTestCase(absltest.TestCase): self.assertRegex(stderr, 'No such file or directory') +@contextlib.contextmanager +def cm_for_test(obj): + try: + obj.cm_state = 'yielded' + yield 'value' + finally: + obj.cm_state = 'exited' + + @absltest.skipIf(six.PY2, 'Python 2 does not have ExitStack') class EnterContextTest(absltest.TestCase): @@ -1502,15 +1512,33 @@ class EnterContextTest(absltest.TestCase): self.addCleanup(assert_cm_exited) super(EnterContextTest, self).setUp() - self.cm_value = self.enter_context(self.cm_for_test()) + self.cm_value = self.enter_context(cm_for_test(self)) - @contextlib.contextmanager - def cm_for_test(self): - try: - self.cm_state = 'yielded' - yield 'value' - finally: - self.cm_state = 'exited' + def test_enter_context(self): + self.assertEqual(self.cm_value, 'value') + self.assertEqual(self.cm_state, 'yielded') + + +@absltest.skipIf(not hasattr(absltest.TestCase, 'addClassCleanup'), + 'Python 3.8 required for class-level enter_context') +class EnterContextClassmethodTest(absltest.TestCase): + + cm_state = 'unset' + cm_value = 'unset' + + @classmethod + def setUpClass(cls): + + def assert_cm_exited(): + assert cls.cm_state == 'exited' + + # Because cleanup functions are run in reverse order, we have to add + # our assert-cleanup before the exit stack registers its own cleanup. + # This ensures we see state after the stack cleanup runs. + cls.addClassCleanup(assert_cm_exited) + + super(EnterContextClassmethodTest, cls).setUpClass() + cls.cm_value = cls.enter_context(cm_for_test(cls)) def test_enter_context(self): self.assertEqual(self.cm_value, 'value') @@ -2191,6 +2219,201 @@ class TempFileTest(absltest.TestCase, HelperMixin): self.run_tempfile_helper('OFF', expected) +class SkipClassTest(absltest.TestCase): + + def test_incorrect_decorator_call(self): + with self.assertRaises(TypeError): + + @absltest.skipThisClass # pylint: disable=unused-variable + class Test(absltest.TestCase): + pass + + def test_incorrect_decorator_subclass(self): + with self.assertRaises(TypeError): + + @absltest.skipThisClass('reason') + def test_method(): # pylint: disable=unused-variable + pass + + def test_correct_decorator_class(self): + + @absltest.skipThisClass('reason') + class Test(absltest.TestCase): + pass + + with self.assertRaises(absltest.SkipTest): + Test.setUpClass() + + def test_correct_decorator_subclass(self): + + @absltest.skipThisClass('reason') + class Test(absltest.TestCase): + pass + + class Subclass(Test): + pass + + with self.subTest('Base class should be skipped'): + with self.assertRaises(absltest.SkipTest): + Test.setUpClass() + + with self.subTest('Subclass should not be skipped'): + Subclass.setUpClass() # should not raise. + + def test_setup(self): + + @absltest.skipThisClass('reason') + class Test(absltest.TestCase): + + @classmethod + def setUpClass(cls): + super(Test, cls).setUpClass() + cls.foo = 1 + + class Subclass(Test): + pass + + Subclass.setUpClass() + self.assertEqual(Subclass.foo, 1) + + def test_setup_chain(self): + + @absltest.skipThisClass('reason') + class BaseTest(absltest.TestCase): + + @classmethod + def setUpClass(cls): + super(BaseTest, cls).setUpClass() + cls.foo = 1 + + @absltest.skipThisClass('reason') + class SecondBaseTest(BaseTest): + + @classmethod + def setUpClass(cls): + super(SecondBaseTest, cls).setUpClass() + cls.bar = 2 + + class Subclass(SecondBaseTest): + pass + + Subclass.setUpClass() + self.assertEqual(Subclass.foo, 1) + self.assertEqual(Subclass.bar, 2) + + def test_setup_args(self): + + @absltest.skipThisClass('reason') + class Test(absltest.TestCase): + + @classmethod + def setUpClass(cls, foo, bar=None): + super(Test, cls).setUpClass() + cls.foo = foo + cls.bar = bar + + class Subclass(Test): + + @classmethod + def setUpClass(cls): + super(Subclass, cls).setUpClass('foo', bar='baz') + + Subclass.setUpClass() + self.assertEqual(Subclass.foo, 'foo') + self.assertEqual(Subclass.bar, 'baz') + + def test_setup_multiple_inheritance(self): + + # Test that skipping this class doesn't break the MRO chain and stop + # RequiredBase.setUpClass from running. + @absltest.skipThisClass('reason') + class Left(absltest.TestCase): + pass + + class RequiredBase(absltest.TestCase): + + @classmethod + def setUpClass(cls): + super(RequiredBase, cls).setUpClass() + cls.foo = 'foo' + + class Right(RequiredBase): + + @classmethod + def setUpClass(cls): + super(Right, cls).setUpClass() + + # Test will fail unless Left.setUpClass() follows mro properly + # Right.setUpClass() + class Subclass(Left, Right): + + @classmethod + def setUpClass(cls): + super(Subclass, cls).setUpClass() + + class Test(Subclass): + pass + + Test.setUpClass() + self.assertEqual(Test.foo, 'foo') + + def test_skip_class(self): + + @absltest.skipThisClass('reason') + class BaseTest(absltest.TestCase): + + def test_foo(self): + _ = 1 / 0 + + class Test(BaseTest): + + def test_foo(self): + self.assertEqual(1, 1) + + with self.subTest('base class'): + ts = unittest.makeSuite(BaseTest) + self.assertEqual(1, ts.countTestCases()) + + res = unittest.TestResult() + ts.run(res) + self.assertTrue(res.wasSuccessful()) + self.assertLen(res.skipped, 1) + self.assertEqual(0, res.testsRun) + self.assertEmpty(res.failures) + self.assertEmpty(res.errors) + + with self.subTest('real test'): + ts = unittest.makeSuite(Test) + self.assertEqual(1, ts.countTestCases()) + + res = unittest.TestResult() + ts.run(res) + self.assertTrue(res.wasSuccessful()) + self.assertEqual(1, res.testsRun) + self.assertEmpty(res.skipped) + self.assertEmpty(res.failures) + self.assertEmpty(res.errors) + + def test_skip_class_unittest(self): + + @absltest.skipThisClass('reason') + class Test(unittest.TestCase): # note: unittest not absltest + + def test_foo(self): + _ = 1 / 0 + + ts = unittest.makeSuite(Test) + self.assertEqual(1, ts.countTestCases()) + + res = unittest.TestResult() + ts.run(res) + self.assertTrue(res.wasSuccessful()) + self.assertLen(res.skipped, 1) + self.assertEqual(0, res.testsRun) + self.assertEmpty(res.failures) + self.assertEmpty(res.errors) + + def _listdir_recursive(path): for dirname, _, filenames in os.walk(path): yield dirname |