aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYifan Hong <elsk@google.com>2022-06-16 17:11:11 -0700
committerYifan Hong <elsk@google.com>2022-06-16 17:13:28 -0700
commit41ad18dfefd9f807e798ffee08323c84eaa8a9e7 (patch)
treefaf738843747b43d4bc5f9b46ec8ea33436a62ad
parentb6bca275ddb60f54428b9f462c8eec6c714f5543 (diff)
parent58ead8c22230a2493006fa0ab9f76776b6e7280f (diff)
downloadabsl-py-41ad18dfefd9f807e798ffee08323c84eaa8a9e7.tar.gz
Merge tag 'v1.1.0' into master
Test: none Bug: 227705464 Bug: 236118730 Change-Id: I28557fa16390caefaac5d21bbc5c0881ad511797 Signed-off-by: Yifan Hong <elsk@google.com>
-rw-r--r--AUTHORS7
-rw-r--r--BUILD.bazel21
-rw-r--r--CHANGELOG.md302
-rw-r--r--CONTRIBUTING.md69
-rw-r--r--LICENSE202
-rw-r--r--MANIFEST.in1
-rw-r--r--README.md60
-rw-r--r--WORKSPACE14
-rw-r--r--absl/BUILD84
-rw-r--r--absl/__init__.py13
-rw-r--r--absl/app.py484
-rw-r--r--absl/app.pyi99
-rw-r--r--absl/command_name.py67
-rw-r--r--absl/flags/BUILD314
-rw-r--r--absl/flags/__init__.py145
-rw-r--r--absl/flags/__init__.pyi103
-rw-r--r--absl/flags/_argument_parser.py629
-rw-r--r--absl/flags/_argument_parser.pyi127
-rw-r--r--absl/flags/_defines.py912
-rw-r--r--absl/flags/_defines.pyi637
-rw-r--r--absl/flags/_exceptions.py112
-rw-r--r--absl/flags/_flag.py485
-rw-r--r--absl/flags/_flag.pyi133
-rw-r--r--absl/flags/_flagvalues.py1387
-rw-r--r--absl/flags/_flagvalues.pyi148
-rw-r--r--absl/flags/_helpers.py433
-rw-r--r--absl/flags/_validators.py313
-rw-r--r--absl/flags/_validators_classes.py176
-rw-r--r--absl/flags/argparse_flags.py390
-rw-r--r--absl/flags/tests/__init__.py13
-rw-r--r--absl/flags/tests/_argument_parser_test.py214
-rw-r--r--absl/flags/tests/_flag_test.py240
-rw-r--r--absl/flags/tests/_flagvalues_test.py929
-rw-r--r--absl/flags/tests/_helpers_test.py173
-rw-r--r--absl/flags/tests/_validators_test.py744
-rw-r--r--absl/flags/tests/argparse_flags_test.py447
-rw-r--r--absl/flags/tests/argparse_flags_test_helper.py89
-rw-r--r--absl/flags/tests/flags_formatting_test.py217
-rw-r--r--absl/flags/tests/flags_helpxml_test.py659
-rw-r--r--absl/flags/tests/flags_numeric_bounds_test.py105
-rw-r--r--absl/flags/tests/flags_test.py2922
-rw-r--r--absl/flags/tests/flags_unicode_literals_test.py42
-rw-r--r--absl/flags/tests/module_bar.py121
-rw-r--r--absl/flags/tests/module_baz.py29
-rw-r--r--absl/flags/tests/module_foo.py128
-rw-r--r--absl/logging/BUILD100
-rw-r--r--absl/logging/__init__.py1234
-rw-r--r--absl/logging/converter.py211
-rw-r--r--absl/logging/tests/__init__.py13
-rw-r--r--absl/logging/tests/converter_test.py135
-rw-r--r--absl/logging/tests/log_before_import_test.py127
-rw-r--r--absl/logging/tests/logging_functional_test.py732
-rw-r--r--absl/logging/tests/logging_functional_test_helper.py312
-rw-r--r--absl/logging/tests/logging_test.py1002
-rw-r--r--absl/logging/tests/verbosity_flag_test.py56
-rw-r--r--absl/testing/BUILD254
-rw-r--r--absl/testing/__init__.py13
-rw-r--r--absl/testing/_bazelize_command.py72
-rw-r--r--absl/testing/_pretty_print_reporter.py95
-rw-r--r--absl/testing/absltest.py2554
-rw-r--r--absl/testing/flagsaver.py198
-rw-r--r--absl/testing/parameterized.py700
-rw-r--r--absl/testing/tests/__init__.py13
-rw-r--r--absl/testing/tests/absltest_env.py30
-rw-r--r--absl/testing/tests/absltest_fail_fast_test.py109
-rw-r--r--absl/testing/tests/absltest_fail_fast_test_helper.py56
-rw-r--r--absl/testing/tests/absltest_filtering_test.py192
-rw-r--r--absl/testing/tests/absltest_filtering_test_helper.py85
-rw-r--r--absl/testing/tests/absltest_py3_test.py44
-rw-r--r--absl/testing/tests/absltest_randomization_test.py154
-rw-r--r--absl/testing/tests/absltest_randomization_testcase.py47
-rw-r--r--absl/testing/tests/absltest_sharding_test.py165
-rw-r--r--absl/testing/tests/absltest_sharding_test_helper.py60
-rw-r--r--absl/testing/tests/absltest_test.py2374
-rw-r--r--absl/testing/tests/absltest_test_helper.py106
-rw-r--r--absl/testing/tests/flagsaver_test.py467
-rw-r--r--absl/testing/tests/parameterized_test.py1077
-rw-r--r--absl/testing/tests/xml_reporter_helper_test.py97
-rw-r--r--absl/testing/tests/xml_reporter_test.py1108
-rw-r--r--absl/testing/xml_reporter.py562
-rw-r--r--absl/tests/__init__.py13
-rw-r--r--absl/tests/app_test.py359
-rw-r--r--absl/tests/app_test_helper.py151
-rw-r--r--absl/tests/command_name_test.py108
-rw-r--r--absl/tests/python_version_test.py40
-rw-r--r--setup.py77
-rw-r--r--smoke_tests/sample_app.py41
-rw-r--r--smoke_tests/sample_test.py33
-rwxr-xr-xsmoke_tests/smoke_test.sh70
89 files changed, 30385 insertions, 0 deletions
diff --git a/AUTHORS b/AUTHORS
new file mode 100644
index 0000000..23b11ad
--- /dev/null
+++ b/AUTHORS
@@ -0,0 +1,7 @@
+# This is the list of Abseil authors for copyright purposes.
+#
+# This does not necessarily list everyone who has contributed code, since in
+# some cases, their employer may be the copyright holder. To see the full list
+# of contributors, see the revision history in source control.
+
+Google Inc.
diff --git a/BUILD.bazel b/BUILD.bazel
new file mode 100644
index 0000000..d72b220
--- /dev/null
+++ b/BUILD.bazel
@@ -0,0 +1,21 @@
+# Copyright 2021 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])
+
+exports_files([
+ "LICENSE",
+])
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..3c36751
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,302 @@
+# Python Absl Changelog
+
+All notable changes to Python Absl are recorded here.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com).
+
+## Unreleased
+
+Nothing notable unreleased.
+
+## 1.1.0 (2022-06-01)
+
+* `Flag` instances now raise an error if used in a bool context. This prevents
+ the occasional mistake of testing an instance for truthiness rather than
+ testing `flag.value`.
+* `absl-py` no longer depends on `six`.
+
+## 1.0.0 (2021-11-09)
+
+### Changed
+
+* `absl-py` no longer supports Python 2.7, 3.4, 3.5. All versions have reached
+ end-of-life for more than a year now.
+* New releases will be tagged as `vX.Y.Z` instead of `pypi-vX.Y.Z` in the git
+ repo going forward.
+
+## 0.15.0 (2021-10-19)
+
+### Changed
+
+* (testing) #128: When running bazel with its `--test_filter=` flag, it now
+ treats the filters as `unittest`'s `-k` flag in Python 3.7+.
+
+## 0.14.1 (2021-09-30)
+
+### Fixed
+
+* Top-level `LICENSE` file is now exported in bazel.
+
+## 0.14.0 (2021-09-21)
+
+### Fixed
+
+* #171: Creating `argparse_flags.ArgumentParser` with `argument_default=` no
+ longer raises an exception when other `absl.flags` flags are defined.
+* #173: `absltest` now correctly sets up test filtering and fail fast flags
+ when an explicit `argv=` parameter is passed to `absltest.main`.
+
+## 0.13.0 (2021-06-14)
+
+### Added
+
+* (app) Type annotations for public `app` interfaces.
+* (testing) Added new decorator `@absltest.skipThisClass` to indicate a class
+ contains shared functionality to be used as a base class for other
+ TestCases, and therefore should be skipped.
+
+### Changed
+
+* (app) Annotated the `flag_parser` paramteter of `run` as keyword-only. This
+ keyword-only constraint will be enforced at runtime in a future release.
+* (app, flags) Flag validations now include all errors from disjoint flag
+ sets, instead of fail fast upon first error from all validators. Multiple
+ validators on the same flag still fails fast.
+
+## 0.12.0 (2021-03-08)
+
+### Added
+
+* (flags) Made `EnumClassSerializer` and `EnumClassListSerializer` public.
+* (flags) Added a `required: Optional[bool] = False` parameter to `DEFINE_*`
+ functions.
+* (testing) flagsaver overrides can now be specified in terms of FlagHolder.
+* (testing) `parameterized.product`: Allows testing a method over cartesian
+ product of parameters values, specified as a sequences of values for each
+ parameter or as kwargs-like dicts of parameter values.
+* (testing) Added public flag holders for `--test_srcdir` and `--test_tmpdir`.
+ Users should use `absltest.TEST_SRCDIR.value` and
+ `absltest.TEST_TMPDIR.value` instead of `FLAGS.test_srcdir` and
+ `FLAGS.test_tmpdir`.
+
+### Fixed
+
+* (flags) Made `CsvListSerializer` respect its delimiter argument.
+
+## 0.11.0 (2020-10-27)
+
+### Changed
+
+* (testing) Surplus entries in AssertionError stack traces from absltest are
+ now suppressed and no longer reported in the xml_reporter.
+* (logging) An exception is now raised instead of `logging.fatal` when logging
+ directories cannot be found.
+* (testing) Multiple flags are now set together before their validators run.
+ This resolves an issue where multi-flag validators rely on specific flag
+ combinations.
+* (flags) As a deterrent for misuse, FlagHolder objects will now raise a
+ TypeError exception when used in a conditional statement or equality
+ expression.
+
+## 0.10.0 (2020-08-19)
+
+### Added
+
+* (testing) `_TempDir` and `_TempFile` now implement `__fspath__` to satisfy
+ `os.PathLike`
+* (logging) `--logger_levels`: allows specifying the log levels of loggers.
+* (flags) `FLAGS.validate_all_flags`: a new method that validates all flags
+ and raises an exception if one fails.
+* (flags) `FLAGS.get_flags_for_module`: Allows fetching the flags a module
+ defines.
+* (testing) `parameterized.TestCase`: Supports async test definitions.
+* (testing,app) Added `--pdb` flag: When true, uncaught exceptions will be
+ handled by `pdb.post_mortem`. This is an alias for `--pdb_post_mortem`.
+
+### Changed
+
+* (testing) Failed tests output a copy/pastable test id to make it easier to
+ copy the failing test to the command line.
+* (testing) `@parameterized.parameters` now treats a single `abc.Mapping` as a
+ single test case, consistent with `named_parameters`. Previously the
+ `abc.Mapping` is treated as if only its keys are passed as a list of test
+ cases. If you were relying on the old inconsistent behavior, explicitly
+ convert the `abc.Mapping` to a `list`.
+* (flags) `DEFINE_enum_class` and `DEFINE_mutlti_enum_class` accept a
+ `case_sensitive` argument. When `False` (the default), strings are mapped to
+ enum member names without case sensitivity, and member names are serialized
+ in lowercase form. Flag definitions for enums whose members include
+ duplicates when case is ignored must now explicitly pass
+ `case_sensitive=True`.
+
+### Fixed
+
+* (flags) Defining an alias no longer marks the aliased flag as always present
+ on the command line.
+* (flags) Aliasing a multi flag no longer causes the default value to be
+ appended to.
+* (flags) Alias default values now matched the aliased default value.
+* (flags) Alias `present` counter now correctly reflects command line usage.
+
+## 0.9.0 (2019-12-17)
+
+### Added
+
+* (testing) `TestCase.enter_context`: Allows using context managers in setUp
+ and having them automatically exited when a test finishes.
+
+### Fixed
+
+* #126: calling `logging.debug(msg, stack_info=...)` no longer throws an
+ exception in Python 3.8.
+
+## 0.8.1 (2019-10-08)
+
+### Fixed
+
+* (testing) `absl.testing`'s pretty print reporter no longer buffers
+ RUN/OK/FAILED messages.
+* (testing) `create_tempfile` will overwrite pre-existing read-only files.
+
+## 0.8.0 (2019-08-26)
+
+### Added
+
+* (testing) `absltest.expectedFailureIf`: a variant of
+ `unittest.expectedFailure` that allows a condition to be given.
+
+### Changed
+
+* (bazel) Tests now pass when bazel
+ `--incompatible_allow_python_version_transitions=true` is set.
+* (bazel) Both Python 2 and Python 3 versions of tests are now created. To
+ only run one major Python version, use `bazel test
+ --test_tag_filters=-python[23]` to ignore the other version.
+* (testing) `assertTotallyOrdered` no longer requires objects to implement
+ `__hash__`.
+* (testing) `absltest` now integrates better with `--pdb_post_mortem`.
+* (testing) `xml_reporter` now includes timestamps to testcases, test_suite,
+ test_suites elements.
+
+### Fixed
+
+* #99: `absl.logging` no longer registers itself to `logging.root` at import
+ time.
+* #108: Tests now pass with Bazel 0.28.0 on macOS.
+
+## 0.7.1 (2019-03-12)
+
+### Added
+
+* (flags) `flags.mark_bool_flags_as_mutual_exclusive`: convenience function to
+ check that only one, or at most one, flag among a set of boolean flags are
+ True.
+
+### Changed
+
+* (bazel) Bazel 0.23+ or 0.22+ is now required for building/testing.
+ Specifically, a Bazel version that supports
+ `@bazel_tools//tools/python:python_version` for selecting the Python
+ version.
+
+### Fixed
+
+* #94: LICENSE files are now included in sdist.
+* #93: Change log added.
+
+## 0.7.0 (2019-01-11)
+
+### Added
+
+* (bazel) testonly=1 has been removed from the testing libraries, which allows
+ their use outside of testing contexts.
+* (flags) Multi-flags now accept any Iterable type for the default value
+ instead of only lists. Strings are still special cased as before. This
+ allows sets, generators, views, etc to be used naturally.
+* (flags) DEFINE_multi_enum_class: a multi flag variant of enum_class.
+* (testing) Most of absltest is now type-annotated.
+* (testing) Made AbslTest.assertRegex available under Python 2. This allows
+ Python 2 code to write more natural Python 3 compatible code. (Note: this
+ was actually released in 0.6.1, but unannounced)
+* (logging) logging.vlog_is_on: helper to tell if a vlog() call will actually
+ log anything. This allows avoiding computing expansive inputs to a logging
+ call when logging isn't enabled for that level.
+
+### Fixed
+
+* (flags) Pickling flags now raises an clear error instead of a cryptic one.
+ Pickling flags isn't supported; instead use flags_into_string to serialize
+ flags.
+* (flags) Flags serialization works better: the resulting serialized value,
+ when deserialized, won't cause --help to be invoked, thus ending the
+ process.
+* (flags) Several flag fixes to make them behave more like the Absl C++ flags:
+ empty --flagfile is allowed; --nohelp and --help=false don't display help
+* (flags) An empty --flagfile value (e.g. "--flagfile=" or "--flagfile=''"
+ doesn't raise an error; its not just ignored. This matches Abseil C++
+ behavior.
+* (bazel) Building with Bazel 0.2.0 works without extra incompatibility disable
+ build flags.
+
+### Changed
+
+* (flags) Flag serialization is now deterministic: this improves Bazel build
+ caching for tools that are affected by flag serialization.
+
+## 0.6.0 (2018-10-22)
+
+### Added
+
+* Tempfile management APIs for tests: read/write/manage tempfiles for test
+ purposes easily and correctly. See TestCase.create_temp{file/dir} and the
+ corresponding commit for more info.
+
+## 0.5.0 (2018-09-17)
+
+### Added
+
+* Flags enum support: flags.DEFINE_enum_class allows using an `Enum` derived
+ class to define the allowed values for a flag.
+
+## 0.4.1 (2018-08-28)
+
+### Fixed
+
+* Flags no long allow spaces in their names
+
+### Changed
+
+* XML test output is written at the end of all test execution.
+* If the current user's username can't be gotten, fallback to uid, else fall
+ back to a generic 'unknown' string.
+
+## 0.4.0 (2018-08-14)
+
+### Added
+
+* argparse integration: absl-registered flags can now be accessed via argparse
+ using absl.flags.argparse_flags: see that module for more information.
+* TestCase.assertSameStructure now allows mixed set types.
+
+### Changed
+
+* Test output now includes start/end markers for each test ran. This is to
+ help distinguish output from tests clearly.
+
+## 0.3.0 (2018-07-25)
+
+### Added
+
+* `app.call_after_init`: Register functions to be called after app.run() is
+ called. Useful for program-wide initialization that library code may need.
+* `logging.log_every_n_seconds`: like log_every_n, but based on elapsed time
+ between logging calls.
+* `absltest.mock`: alias to unittest.mock (PY3) for better unittest drop-in
+ replacement. For PY2, it will be available if mock is importable.
+
+### Fixed
+
+* `ABSLLogger.findCaller()`: allow stack_info arg and return value for PY2
+* Make stopTest locking reentrant: this prevents deadlocks for test frameworks
+ that customize unittest.TextTestResult.stopTest.
+* Make --helpfull work with unicode flag help strings.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..5134aff
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,69 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+NOTE: If you are new to GitHub, please start by reading the [Pull Request
+howto](https://help.github.com/articles/about-pull-requests/).
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution,
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Coding Style
+
+To keep the source consistent, readable, diffable and easy to merge, we use a
+fairly rigid coding style, as defined by the
+[google-styleguide](https://github.com/google/styleguide) project. All patches
+will be expected to conform to the Python style outlined
+[here](https://google.github.io/styleguide/pyguide.html).
+
+## Guidelines for Pull Requests
+
+* Create **small PRs** that are narrowly focused on **addressing a single
+ concern**. We often receive PRs that are trying to fix several things at a
+ time, but if only one fix is considered acceptable, nothing gets merged and
+ both author's & review's time is wasted. Create more PRs to address
+ different concerns and everyone will be happy.
+
+* For speculative changes, consider opening an
+ [issue](https://github.com/abseil/abseil-py/issues) and discussing it first.
+
+* Provide a good **PR description** as a record of **what** change is being
+ made and **why** it was made. Link to a GitHub issue if it exists.
+
+* Don't fix code style and formatting unless you are already changing that
+ line to address an issue. PRs with irrelevant changes won't be merged. If
+ you do want to fix formatting or style, do that in a separate PR.
+
+* Unless your PR is trivial, you should expect there will be reviewer comments
+ that you'll need to address before merging. We expect you to be reasonably
+ responsive to those comments, otherwise the PR will be closed after 2-3
+ weeks of inactivity.
+
+* Maintain **clean commit history** and use **meaningful commit messages**.
+ PRs with messy commit history are difficult to review and won't be merged.
+ Use `rebase -i upstream/main` to curate your commit history and/or to
+ bring in latest changes from main (but avoid rebasing in the middle of a
+ code review).
+
+* Keep your PR up to date with upstream/main (if there are merge conflicts,
+ we can't really merge your change).
+
+* **All tests need to be passing** before your change can be merged. We
+ recommend you **run tests locally** (see
+ [Running Tests](README.md#running-tests)).
+
+* Exceptions to the rules can be made if there's a compelling reason for doing
+ so. That is - the rules are here to serve us, not the other way around, and
+ the rules need to be serving their intended purpose to be valuable.
+
+* All submissions, including submissions by project members, require review.
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..1aba38f
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include LICENSE
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..5ab2365
--- /dev/null
+++ b/README.md
@@ -0,0 +1,60 @@
+# Abseil Python Common Libraries
+
+This repository is a collection of Python library code for building Python
+applications. The code is collected from Google's own Python code base, and has
+been extensively tested and used in production.
+
+## Features
+
+* Simple application startup
+* Distributed commandline flags system
+* Custom logging module with additional features
+* Testing utilities
+
+## Getting Started
+
+### Installation
+
+To install the package, simply run:
+
+```bash
+pip install absl-py
+```
+
+Or install from source:
+
+```bash
+python setup.py install
+```
+
+### Running Tests
+
+To run Abseil tests, you can clone the git repo and run
+[bazel](https://bazel.build/):
+
+```bash
+git clone https://github.com/abseil/abseil-py.git
+cd abseil-py
+bazel test absl/...
+```
+
+### Example Code
+
+Please refer to
+[smoke_tests/sample_app.py](https://github.com/abseil/abseil-py/blob/main/smoke_tests/sample_app.py)
+as an example to get started.
+
+## Documentation
+
+See the [Abseil Python Developer Guide](https://abseil.io/docs/python/).
+
+## Future Releases
+
+The current repository includes an initial set of libraries for early adoption.
+More components and interoperability with Abseil C++ Common Libraries
+will come in future releases.
+
+## License
+
+The Abseil Python library is licensed under the terms of the Apache
+license. See [LICENSE](LICENSE) for more information.
diff --git a/WORKSPACE b/WORKSPACE
new file mode 100644
index 0000000..a964e21
--- /dev/null
+++ b/WORKSPACE
@@ -0,0 +1,14 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+workspace(name = "io_abseil_py")
diff --git a/absl/BUILD b/absl/BUILD
new file mode 100644
index 0000000..4e747ea
--- /dev/null
+++ b/absl/BUILD
@@ -0,0 +1,84 @@
+licenses(["notice"])
+
+py_library(
+ name = "app",
+ srcs = [
+ "app.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":command_name",
+ "//absl/flags",
+ "//absl/logging",
+ ],
+)
+
+py_library(
+ name = "command_name",
+ srcs = ["command_name.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+)
+
+py_library(
+ name = "tests/app_test_helper",
+ testonly = 1,
+ srcs = ["tests/app_test_helper.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":app",
+ "//absl/flags",
+ ],
+)
+
+py_binary(
+ name = "tests/app_test_helper_pure_python",
+ testonly = 1,
+ srcs = ["tests/app_test_helper.py"],
+ main = "tests/app_test_helper.py",
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":app",
+ "//absl/flags",
+ ],
+)
+
+py_test(
+ name = "tests/app_test",
+ srcs = ["tests/app_test.py"],
+ data = [":tests/app_test_helper_pure_python"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":app",
+ ":tests/app_test_helper",
+ "//absl/flags",
+ "//absl/testing:_bazelize_command",
+ "//absl/testing:absltest",
+ "//absl/testing:flagsaver",
+ ],
+)
+
+py_test(
+ name = "tests/command_name_test",
+ srcs = ["tests/command_name_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":command_name",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "tests/python_version_test",
+ srcs = ["tests/python_version_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ "//absl/flags",
+ "//absl/testing:absltest",
+ ],
+)
diff --git a/absl/__init__.py b/absl/__init__.py
new file mode 100644
index 0000000..a3bd1cd
--- /dev/null
+++ b/absl/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/absl/app.py b/absl/app.py
new file mode 100644
index 0000000..037f75c
--- /dev/null
+++ b/absl/app.py
@@ -0,0 +1,484 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Generic entry point for Abseil Python applications.
+
+To use this module, define a 'main' function with a single 'argv' argument and
+call app.run(main). For example:
+
+ def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Too many command-line arguments.')
+
+ if __name__ == '__main__':
+ app.run(main)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import errno
+import os
+import pdb
+import sys
+import textwrap
+import traceback
+
+from absl import command_name
+from absl import flags
+from absl import logging
+
+try:
+ import faulthandler
+except ImportError:
+ faulthandler = None
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_boolean('run_with_pdb', False, 'Set to true for PDB debug mode')
+flags.DEFINE_boolean('pdb_post_mortem', False,
+ 'Set to true to handle uncaught exceptions with PDB '
+ 'post mortem.')
+flags.DEFINE_alias('pdb', 'pdb_post_mortem')
+flags.DEFINE_boolean('run_with_profiling', False,
+ 'Set to true for profiling the script. '
+ 'Execution will be slower, and the output format might '
+ 'change over time.')
+flags.DEFINE_string('profile_file', None,
+ 'Dump profile information to a file (for python -m '
+ 'pstats). Implies --run_with_profiling.')
+flags.DEFINE_boolean('use_cprofile_for_profiling', True,
+ 'Use cProfile instead of the profile module for '
+ 'profiling. This has no effect unless '
+ '--run_with_profiling is set.')
+flags.DEFINE_boolean('only_check_args', False,
+ 'Set to true to validate args and exit.',
+ allow_hide_cpp=True)
+
+
+# If main() exits via an abnormal exception, call into these
+# handlers before exiting.
+EXCEPTION_HANDLERS = []
+
+
+class Error(Exception):
+ pass
+
+
+class UsageError(Error):
+ """Exception raised when the arguments supplied by the user are invalid.
+
+ Raise this when the arguments supplied are invalid from the point of
+ view of the application. For example when two mutually exclusive
+ flags have been supplied or when there are not enough non-flag
+ arguments. It is distinct from flags.Error which covers the lower
+ level of parsing and validating individual flags.
+ """
+
+ def __init__(self, message, exitcode=1):
+ super(UsageError, self).__init__(message)
+ self.exitcode = exitcode
+
+
+class HelpFlag(flags.BooleanFlag):
+ """Special boolean flag that displays usage and raises SystemExit."""
+ NAME = 'help'
+ SHORT_NAME = '?'
+
+ def __init__(self):
+ super(HelpFlag, self).__init__(
+ self.NAME, False, 'show this help',
+ short_name=self.SHORT_NAME, allow_hide_cpp=True)
+
+ def parse(self, arg):
+ if self._parse(arg):
+ usage(shorthelp=True, writeto_stdout=True)
+ # Advertise --helpfull on stdout, since usage() was on stdout.
+ print()
+ print('Try --helpfull to get a list of all flags.')
+ sys.exit(1)
+
+
+class HelpshortFlag(HelpFlag):
+ """--helpshort is an alias for --help."""
+ NAME = 'helpshort'
+ SHORT_NAME = None
+
+
+class HelpfullFlag(flags.BooleanFlag):
+ """Display help for flags in the main module and all dependent modules."""
+
+ def __init__(self):
+ super(HelpfullFlag, self).__init__(
+ 'helpfull', False, 'show full help', allow_hide_cpp=True)
+
+ def parse(self, arg):
+ if self._parse(arg):
+ usage(writeto_stdout=True)
+ sys.exit(1)
+
+
+class HelpXMLFlag(flags.BooleanFlag):
+ """Similar to HelpfullFlag, but generates output in XML format."""
+
+ def __init__(self):
+ super(HelpXMLFlag, self).__init__(
+ 'helpxml', False, 'like --helpfull, but generates XML output',
+ allow_hide_cpp=True)
+
+ def parse(self, arg):
+ if self._parse(arg):
+ flags.FLAGS.write_help_in_xml_format(sys.stdout)
+ sys.exit(1)
+
+
+def parse_flags_with_usage(args):
+ """Tries to parse the flags, print usage, and exit if unparsable.
+
+ Args:
+ args: [str], a non-empty list of the command line arguments including
+ program name.
+
+ Returns:
+ [str], a non-empty list of remaining command line arguments after parsing
+ flags, including program name.
+ """
+ try:
+ return FLAGS(args)
+ except flags.Error as error:
+ message = str(error)
+ if '\n' in message:
+ final_message = 'FATAL Flags parsing error:\n%s\n' % textwrap.indent(
+ message, ' ')
+ else:
+ final_message = 'FATAL Flags parsing error: %s\n' % message
+ sys.stderr.write(final_message)
+ sys.stderr.write('Pass --helpshort or --helpfull to see help on flags.\n')
+ sys.exit(1)
+
+
+_define_help_flags_called = False
+
+
+def define_help_flags():
+ """Registers help flags. Idempotent."""
+ # Use a global to ensure idempotence.
+ global _define_help_flags_called
+
+ if not _define_help_flags_called:
+ flags.DEFINE_flag(HelpFlag())
+ flags.DEFINE_flag(HelpshortFlag()) # alias for --help
+ flags.DEFINE_flag(HelpfullFlag())
+ flags.DEFINE_flag(HelpXMLFlag())
+ _define_help_flags_called = True
+
+
+def _register_and_parse_flags_with_usage(
+ argv=None,
+ flags_parser=parse_flags_with_usage,
+):
+ """Registers help flags, parses arguments and shows usage if appropriate.
+
+ This also calls sys.exit(0) if flag --only_check_args is True.
+
+ Args:
+ argv: [str], a non-empty list of the command line arguments including
+ program name, sys.argv is used if None.
+ flags_parser: Callable[[List[Text]], Any], the function used to parse flags.
+ The return value of this function is passed to `main` untouched.
+ It must guarantee FLAGS is parsed after this function is called.
+
+ Returns:
+ The return value of `flags_parser`. When using the default `flags_parser`,
+ it returns the following:
+ [str], a non-empty list of remaining command line arguments after parsing
+ flags, including program name.
+
+ Raises:
+ Error: Raised when flags_parser is called, but FLAGS is not parsed.
+ SystemError: Raised when it's called more than once.
+ """
+ if _register_and_parse_flags_with_usage.done:
+ raise SystemError('Flag registration can be done only once.')
+
+ define_help_flags()
+
+ original_argv = sys.argv if argv is None else argv
+ args_to_main = flags_parser(original_argv)
+ if not FLAGS.is_parsed():
+ raise Error('FLAGS must be parsed after flags_parser is called.')
+
+ # Exit when told so.
+ if FLAGS.only_check_args:
+ sys.exit(0)
+ # Immediately after flags are parsed, bump verbosity to INFO if the flag has
+ # not been set.
+ if FLAGS['verbosity'].using_default_value:
+ FLAGS.verbosity = 0
+ _register_and_parse_flags_with_usage.done = True
+
+ return args_to_main
+
+_register_and_parse_flags_with_usage.done = False
+
+
+def _run_main(main, argv):
+ """Calls main, optionally with pdb or profiler."""
+ if FLAGS.run_with_pdb:
+ sys.exit(pdb.runcall(main, argv))
+ elif FLAGS.run_with_profiling or FLAGS.profile_file:
+ # Avoid import overhead since most apps (including performance-sensitive
+ # ones) won't be run with profiling.
+ import atexit
+ if FLAGS.use_cprofile_for_profiling:
+ import cProfile as profile
+ else:
+ import profile
+ profiler = profile.Profile()
+ if FLAGS.profile_file:
+ atexit.register(profiler.dump_stats, FLAGS.profile_file)
+ else:
+ atexit.register(profiler.print_stats)
+ retval = profiler.runcall(main, argv)
+ sys.exit(retval)
+ else:
+ sys.exit(main(argv))
+
+
+def _call_exception_handlers(exception):
+ """Calls any installed exception handlers."""
+ for handler in EXCEPTION_HANDLERS:
+ try:
+ if handler.wants(exception):
+ handler.handle(exception)
+ except: # pylint: disable=bare-except
+ try:
+ # We don't want to stop for exceptions in the exception handlers but
+ # we shouldn't hide them either.
+ logging.error(traceback.format_exc())
+ except: # pylint: disable=bare-except
+ # In case even the logging statement fails, ignore.
+ pass
+
+
+def run(
+ main,
+ argv=None,
+ flags_parser=parse_flags_with_usage,
+):
+ """Begins executing the program.
+
+ Args:
+ main: The main function to execute. It takes an single argument "argv",
+ which is a list of command line arguments with parsed flags removed.
+ The return value is passed to `sys.exit`, and so for example
+ a return value of 0 or None results in a successful termination, whereas
+ a return value of 1 results in abnormal termination.
+ For more details, see https://docs.python.org/3/library/sys#sys.exit
+ argv: A non-empty list of the command line arguments including program name,
+ sys.argv is used if None.
+ flags_parser: Callable[[List[Text]], Any], the function used to parse flags.
+ The return value of this function is passed to `main` untouched.
+ It must guarantee FLAGS is parsed after this function is called.
+ Should be passed as a keyword-only arg which will become mandatory in a
+ future release.
+ - Parses command line flags with the flag module.
+ - If there are any errors, prints usage().
+ - Calls main() with the remaining arguments.
+ - If main() raises a UsageError, prints usage and the error message.
+ """
+ try:
+ args = _run_init(
+ sys.argv if argv is None else argv,
+ flags_parser,
+ )
+ while _init_callbacks:
+ callback = _init_callbacks.popleft()
+ callback()
+ try:
+ _run_main(main, args)
+ except UsageError as error:
+ usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)
+ except:
+ exc = sys.exc_info()[1]
+ # Don't try to post-mortem debug successful SystemExits, since those
+ # mean there wasn't actually an error. In particular, the test framework
+ # raises SystemExit(False) even if all tests passed.
+ if isinstance(exc, SystemExit) and not exc.code:
+ raise
+
+ # Check the tty so that we don't hang waiting for input in an
+ # non-interactive scenario.
+ if FLAGS.pdb_post_mortem and sys.stdout.isatty():
+ traceback.print_exc()
+ print()
+ print(' *** Entering post-mortem debugging ***')
+ print()
+ pdb.post_mortem()
+ raise
+ except Exception as e:
+ _call_exception_handlers(e)
+ raise
+
+# Callbacks which have been deferred until after _run_init has been called.
+_init_callbacks = collections.deque()
+
+
+def call_after_init(callback):
+ """Calls the given callback only once ABSL has finished initialization.
+
+ If ABSL has already finished initialization when `call_after_init` is
+ called then the callback is executed immediately, otherwise `callback` is
+ stored to be executed after `app.run` has finished initializing (aka. just
+ before the main function is called).
+
+ If called after `app.run`, this is equivalent to calling `callback()` in the
+ caller thread. If called before `app.run`, callbacks are run sequentially (in
+ an undefined order) in the same thread as `app.run`.
+
+ Args:
+ callback: a callable to be called once ABSL has finished initialization.
+ This may be immediate if initialization has already finished. It
+ takes no arguments and returns nothing.
+ """
+ if _run_init.done:
+ callback()
+ else:
+ _init_callbacks.append(callback)
+
+
+def _run_init(
+ argv,
+ flags_parser,
+):
+ """Does one-time initialization and re-parses flags on rerun."""
+ if _run_init.done:
+ return flags_parser(argv)
+ command_name.make_process_name_useful()
+ # Set up absl logging handler.
+ logging.use_absl_handler()
+ args = _register_and_parse_flags_with_usage(
+ argv=argv,
+ flags_parser=flags_parser,
+ )
+ if faulthandler:
+ try:
+ faulthandler.enable()
+ except Exception: # pylint: disable=broad-except
+ # Some tests verify stderr output very closely, so don't print anything.
+ # Disabled faulthandler is a low-impact error.
+ pass
+ _run_init.done = True
+ return args
+
+
+_run_init.done = False
+
+
+def usage(shorthelp=False, writeto_stdout=False, detailed_error=None,
+ exitcode=None):
+ """Writes __main__'s docstring to stderr with some help text.
+
+ Args:
+ shorthelp: bool, if True, prints only flags from the main module,
+ rather than all flags.
+ writeto_stdout: bool, if True, writes help message to stdout,
+ rather than to stderr.
+ detailed_error: str, additional detail about why usage info was presented.
+ exitcode: optional integer, if set, exits with this status code after
+ writing help.
+ """
+ if writeto_stdout:
+ stdfile = sys.stdout
+ else:
+ stdfile = sys.stderr
+
+ doc = sys.modules['__main__'].__doc__
+ if not doc:
+ doc = '\nUSAGE: %s [flags]\n' % sys.argv[0]
+ doc = flags.text_wrap(doc, indent=' ', firstline_indent='')
+ else:
+ # Replace all '%s' with sys.argv[0], and all '%%' with '%'.
+ num_specifiers = doc.count('%') - 2 * doc.count('%%')
+ try:
+ doc %= (sys.argv[0],) * num_specifiers
+ except (OverflowError, TypeError, ValueError):
+ # Just display the docstring as-is.
+ pass
+ if shorthelp:
+ flag_str = FLAGS.main_module_help()
+ else:
+ flag_str = FLAGS.get_help()
+ try:
+ stdfile.write(doc)
+ if flag_str:
+ stdfile.write('\nflags:\n')
+ stdfile.write(flag_str)
+ stdfile.write('\n')
+ if detailed_error is not None:
+ stdfile.write('\n%s\n' % detailed_error)
+ except IOError as e:
+ # We avoid printing a huge backtrace if we get EPIPE, because
+ # "foo.par --help | less" is a frequent use case.
+ if e.errno != errno.EPIPE:
+ raise
+ if exitcode is not None:
+ sys.exit(exitcode)
+
+
+class ExceptionHandler(object):
+ """Base exception handler from which other may inherit."""
+
+ def wants(self, exc):
+ """Returns whether this handler wants to handle the exception or not.
+
+ This base class returns True for all exceptions by default. Override in
+ subclass if it wants to be more selective.
+
+ Args:
+ exc: Exception, the current exception.
+ """
+ del exc # Unused.
+ return True
+
+ def handle(self, exc):
+ """Do something with the current exception.
+
+ Args:
+ exc: Exception, the current exception
+
+ This method must be overridden.
+ """
+ raise NotImplementedError()
+
+
+def install_exception_handler(handler):
+ """Installs an exception handler.
+
+ Args:
+ handler: ExceptionHandler, the exception handler to install.
+
+ Raises:
+ TypeError: Raised when the handler was not of the correct type.
+
+ All installed exception handlers will be called if main() exits via
+ an abnormal exception, i.e. not one of SystemExit, KeyboardInterrupt,
+ FlagsError or UsageError.
+ """
+ if not isinstance(handler, ExceptionHandler):
+ raise TypeError('handler of type %s does not inherit from ExceptionHandler'
+ % type(handler))
+ EXCEPTION_HANDLERS.append(handler)
diff --git a/absl/app.pyi b/absl/app.pyi
new file mode 100644
index 0000000..fe5e448
--- /dev/null
+++ b/absl/app.pyi
@@ -0,0 +1,99 @@
+
+from typing import Any, Callable, Collection, Iterable, List, NoReturn, Optional, Text, TypeVar, Union, overload
+
+from absl.flags import _flag
+
+
+_MainArgs = TypeVar('_MainArgs')
+_Exc = TypeVar('_Exc', bound=Exception)
+
+
+class ExceptionHandler():
+
+ def wants(self, exc: _Exc) -> bool:
+ ...
+
+ def handle(self, exc: _Exc):
+ ...
+
+
+EXCEPTION_HANDLERS: List[ExceptionHandler] = ...
+
+
+class HelpFlag(_flag.BooleanFlag):
+ def __init__(self):
+ ...
+
+
+class HelpshortFlag(HelpFlag):
+ ...
+
+
+class HelpfullFlag(_flag.BooleanFlag):
+ def __init__(self):
+ ...
+
+
+class HelpXMLFlag(_flag.BooleanFlag):
+ def __init__(self):
+ ...
+
+
+def define_help_flags() -> None:
+ ...
+
+
+@overload
+def usage(shorthelp: Union[bool, int] = ...,
+ writeto_stdout: Union[bool, int] = ...,
+ detailed_error: Optional[Any] = ...,
+ exitcode: None = ...) -> None:
+ ...
+
+
+@overload
+def usage(shorthelp: Union[bool, int] = ...,
+ writeto_stdout: Union[bool, int] = ...,
+ detailed_error: Optional[Any] = ...,
+ exitcode: int = ...) -> NoReturn:
+ ...
+
+
+def install_exception_handler(handler: ExceptionHandler) -> None:
+ ...
+
+
+class Error(Exception):
+ ...
+
+
+class UsageError(Error):
+ exitcode: int
+
+
+def parse_flags_with_usage(args: List[Text]) -> List[Text]:
+ ...
+
+
+def call_after_init(callback: Callable[[], Any]) -> None:
+ ...
+
+
+# Without the flag_parser argument, `main` should require a List[Text].
+@overload
+def run(
+ main: Callable[[List[Text]], Any],
+ argv: Optional[List[Text]] = ...,
+ *,
+) -> NoReturn:
+ ...
+
+
+@overload
+def run(
+ main: Callable[[_MainArgs], Any],
+ argv: Optional[List[Text]] = ...,
+ *,
+ flags_parser: Callable[[List[Text]], _MainArgs],
+) -> NoReturn:
+ ...
diff --git a/absl/command_name.py b/absl/command_name.py
new file mode 100644
index 0000000..3bf9fad
--- /dev/null
+++ b/absl/command_name.py
@@ -0,0 +1,67 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A tiny stand alone library to change the kernel process name on Linux."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+# This library must be kept small and stand alone. It is used by small things
+# that require no extension modules.
+
+
+def make_process_name_useful():
+ """Sets the process name to something better than 'python' if possible."""
+ set_kernel_process_name(os.path.basename(sys.argv[0]))
+
+
+def set_kernel_process_name(name):
+ """Changes the Kernel's /proc/self/status process name on Linux.
+
+ The kernel name is NOT what will be shown by the ps or top command.
+ It is a 15 character string stored in the kernel's process table that
+ is included in the kernel log when a process is OOM killed.
+ The first 15 bytes of name are used. Non-ASCII unicode is replaced with '?'.
+
+ Does nothing if /proc/self/comm cannot be written or prctl() fails.
+
+ Args:
+ name: bytes|unicode, the Linux kernel's command name to set.
+ """
+ if not isinstance(name, bytes):
+ name = name.encode('ascii', 'replace')
+ try:
+ # This is preferred to using ctypes to try and call prctl() when possible.
+ with open('/proc/self/comm', 'wb') as proc_comm:
+ proc_comm.write(name[:15])
+ except EnvironmentError:
+ try:
+ import ctypes
+ except ImportError:
+ return # No ctypes.
+ try:
+ libc = ctypes.CDLL('libc.so.6')
+ except EnvironmentError:
+ return # No libc.so.6.
+ pr_set_name = ctypes.c_ulong(15) # linux/prctl.h PR_SET_NAME value.
+ zero = ctypes.c_ulong(0)
+ try:
+ libc.prctl(pr_set_name, name, zero, zero, zero)
+ # Ignore the prctl return value. Nothing we can do if it errored.
+ except AttributeError:
+ return # No prctl.
diff --git a/absl/flags/BUILD b/absl/flags/BUILD
new file mode 100644
index 0000000..33a7b07
--- /dev/null
+++ b/absl/flags/BUILD
@@ -0,0 +1,314 @@
+licenses(["notice"])
+
+py_library(
+ name = "flags",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":_argument_parser",
+ ":_defines",
+ ":_exceptions",
+ ":_flag",
+ ":_flagvalues",
+ ":_helpers",
+ ":_validators",
+ ],
+)
+
+py_library(
+ name = "argparse_flags",
+ srcs = ["argparse_flags.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [":flags"],
+)
+
+py_library(
+ name = "_argument_parser",
+ srcs = ["_argument_parser.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_helpers",
+ ],
+)
+
+py_library(
+ name = "_defines",
+ srcs = ["_defines.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_argument_parser",
+ ":_exceptions",
+ ":_flag",
+ ":_flagvalues",
+ ":_helpers",
+ ":_validators",
+ ],
+)
+
+py_library(
+ name = "_exceptions",
+ srcs = ["_exceptions.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_helpers",
+ ],
+)
+
+py_library(
+ name = "_flag",
+ srcs = ["_flag.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_argument_parser",
+ ":_exceptions",
+ ":_helpers",
+ ],
+)
+
+py_library(
+ name = "_flagvalues",
+ srcs = ["_flagvalues.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_exceptions",
+ ":_flag",
+ ":_helpers",
+ ":_validators_classes",
+ ],
+)
+
+py_library(
+ name = "_helpers",
+ srcs = ["_helpers.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "_validators",
+ srcs = [
+ "_validators.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_exceptions",
+ ":_flagvalues",
+ ":_validators_classes",
+ ],
+)
+
+py_library(
+ name = "_validators_classes",
+ srcs = [
+ "_validators_classes.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_exceptions",
+ ],
+)
+
+py_test(
+ name = "tests/_argument_parser_test",
+ srcs = ["tests/_argument_parser_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_argument_parser",
+ "//absl/testing:absltest",
+ "//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "tests/_flag_test",
+ srcs = ["tests/_flag_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_argument_parser",
+ ":_exceptions",
+ ":_flag",
+ "//absl/testing:absltest",
+ "//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "tests/_flagvalues_test",
+ size = "small",
+ srcs = ["tests/_flagvalues_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_defines",
+ ":_exceptions",
+ ":_flagvalues",
+ ":_helpers",
+ ":_validators",
+ ":tests/module_foo",
+ "//absl/logging",
+ "//absl/testing:absltest",
+ "//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "tests/_helpers_test",
+ size = "small",
+ srcs = ["tests/_helpers_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_helpers",
+ ":tests/module_bar",
+ ":tests/module_foo",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "tests/_validators_test",
+ size = "small",
+ srcs = ["tests/_validators_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_defines",
+ ":_exceptions",
+ ":_flagvalues",
+ ":_validators",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "tests/argparse_flags_test",
+ size = "small",
+ srcs = ["tests/argparse_flags_test.py"],
+ data = [":tests/argparse_flags_test_helper"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":argparse_flags",
+ ":flags",
+ "//absl/logging",
+ "//absl/testing:_bazelize_command",
+ "//absl/testing:absltest",
+ "//absl/testing:parameterized",
+ ],
+)
+
+py_binary(
+ name = "tests/argparse_flags_test_helper",
+ testonly = 1,
+ srcs = ["tests/argparse_flags_test_helper.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":argparse_flags",
+ ":flags",
+ "//absl:app",
+ ],
+)
+
+py_test(
+ name = "tests/flags_formatting_test",
+ size = "small",
+ srcs = ["tests/flags_formatting_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_helpers",
+ ":flags",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "tests/flags_helpxml_test",
+ size = "small",
+ srcs = ["tests/flags_helpxml_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_helpers",
+ ":flags",
+ ":tests/module_bar",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "tests/flags_numeric_bounds_test",
+ size = "small",
+ srcs = ["tests/flags_numeric_bounds_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_validators",
+ ":flags",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "tests/flags_test",
+ size = "small",
+ srcs = ["tests/flags_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_exceptions",
+ ":_helpers",
+ ":flags",
+ ":tests/module_bar",
+ ":tests/module_baz",
+ ":tests/module_foo",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "tests/flags_unicode_literals_test",
+ size = "small",
+ srcs = ["tests/flags_unicode_literals_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":flags",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_library(
+ name = "tests/module_bar",
+ testonly = 1,
+ srcs = ["tests/module_bar.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_helpers",
+ ":flags",
+ ],
+)
+
+py_library(
+ name = "tests/module_baz",
+ testonly = 1,
+ srcs = ["tests/module_baz.py"],
+ srcs_version = "PY2AND3",
+ deps = [":flags"],
+)
+
+py_library(
+ name = "tests/module_foo",
+ testonly = 1,
+ srcs = ["tests/module_foo.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":_helpers",
+ ":flags",
+ ":tests/module_bar",
+ ],
+)
diff --git a/absl/flags/__init__.py b/absl/flags/__init__.py
new file mode 100644
index 0000000..e6014a6
--- /dev/null
+++ b/absl/flags/__init__.py
@@ -0,0 +1,145 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This package is used to define and parse command line flags.
+
+This package defines a *distributed* flag-definition policy: rather than
+an application having to define all flags in or near main(), each Python
+module defines flags that are useful to it. When one Python module
+imports another, it gains access to the other's flags. (This is
+implemented by having all modules share a common, global registry object
+containing all the flag information.)
+
+Flags are defined through the use of one of the DEFINE_xxx functions.
+The specific function used determines how the flag is parsed, checked,
+and optionally type-converted, when it's seen on the command line.
+"""
+
+import getopt
+import os
+import re
+import sys
+import types
+import warnings
+
+from absl.flags import _argument_parser
+from absl.flags import _defines
+from absl.flags import _exceptions
+from absl.flags import _flag
+from absl.flags import _flagvalues
+from absl.flags import _helpers
+from absl.flags import _validators
+
+# Initialize the FLAGS_MODULE as early as possible.
+# It's only used by adopt_module_key_flags to take SPECIAL_FLAGS into account.
+_helpers.FLAGS_MODULE = sys.modules[__name__]
+
+# Add current module to disclaimed module ids.
+_helpers.disclaim_module_ids.add(id(sys.modules[__name__]))
+
+# DEFINE functions. They are explained in more details in the module doc string.
+# pylint: disable=invalid-name
+DEFINE = _defines.DEFINE
+DEFINE_flag = _defines.DEFINE_flag
+DEFINE_string = _defines.DEFINE_string
+DEFINE_boolean = _defines.DEFINE_boolean
+DEFINE_bool = DEFINE_boolean # Match C++ API.
+DEFINE_float = _defines.DEFINE_float
+DEFINE_integer = _defines.DEFINE_integer
+DEFINE_enum = _defines.DEFINE_enum
+DEFINE_enum_class = _defines.DEFINE_enum_class
+DEFINE_list = _defines.DEFINE_list
+DEFINE_spaceseplist = _defines.DEFINE_spaceseplist
+DEFINE_multi = _defines.DEFINE_multi
+DEFINE_multi_string = _defines.DEFINE_multi_string
+DEFINE_multi_integer = _defines.DEFINE_multi_integer
+DEFINE_multi_float = _defines.DEFINE_multi_float
+DEFINE_multi_enum = _defines.DEFINE_multi_enum
+DEFINE_multi_enum_class = _defines.DEFINE_multi_enum_class
+DEFINE_alias = _defines.DEFINE_alias
+# pylint: enable=invalid-name
+
+# Flag validators.
+register_validator = _validators.register_validator
+validator = _validators.validator
+register_multi_flags_validator = _validators.register_multi_flags_validator
+multi_flags_validator = _validators.multi_flags_validator
+mark_flag_as_required = _validators.mark_flag_as_required
+mark_flags_as_required = _validators.mark_flags_as_required
+mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive
+mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive
+
+# Key flag related functions.
+declare_key_flag = _defines.declare_key_flag
+adopt_module_key_flags = _defines.adopt_module_key_flags
+disclaim_key_flags = _defines.disclaim_key_flags
+
+# Module exceptions.
+# pylint: disable=invalid-name
+Error = _exceptions.Error
+CantOpenFlagFileError = _exceptions.CantOpenFlagFileError
+DuplicateFlagError = _exceptions.DuplicateFlagError
+IllegalFlagValueError = _exceptions.IllegalFlagValueError
+UnrecognizedFlagError = _exceptions.UnrecognizedFlagError
+UnparsedFlagAccessError = _exceptions.UnparsedFlagAccessError
+ValidationError = _exceptions.ValidationError
+FlagNameConflictsWithMethodError = _exceptions.FlagNameConflictsWithMethodError
+
+# Public classes.
+Flag = _flag.Flag
+BooleanFlag = _flag.BooleanFlag
+EnumFlag = _flag.EnumFlag
+EnumClassFlag = _flag.EnumClassFlag
+MultiFlag = _flag.MultiFlag
+MultiEnumClassFlag = _flag.MultiEnumClassFlag
+FlagHolder = _flagvalues.FlagHolder
+FlagValues = _flagvalues.FlagValues
+ArgumentParser = _argument_parser.ArgumentParser
+BooleanParser = _argument_parser.BooleanParser
+EnumParser = _argument_parser.EnumParser
+EnumClassParser = _argument_parser.EnumClassParser
+ArgumentSerializer = _argument_parser.ArgumentSerializer
+FloatParser = _argument_parser.FloatParser
+IntegerParser = _argument_parser.IntegerParser
+BaseListParser = _argument_parser.BaseListParser
+ListParser = _argument_parser.ListParser
+ListSerializer = _argument_parser.ListSerializer
+EnumClassListSerializer = _argument_parser.EnumClassListSerializer
+CsvListSerializer = _argument_parser.CsvListSerializer
+WhitespaceSeparatedListParser = _argument_parser.WhitespaceSeparatedListParser
+EnumClassSerializer = _argument_parser.EnumClassSerializer
+# pylint: enable=invalid-name
+
+# Helper functions.
+get_help_width = _helpers.get_help_width
+text_wrap = _helpers.text_wrap
+flag_dict_to_args = _helpers.flag_dict_to_args
+doc_to_help = _helpers.doc_to_help
+
+# Special flags.
+_helpers.SPECIAL_FLAGS = FlagValues()
+
+DEFINE_string(
+ 'flagfile', '',
+ 'Insert flag definitions from the given file into the command line.',
+ _helpers.SPECIAL_FLAGS) # pytype: disable=wrong-arg-types
+
+DEFINE_string('undefok', '',
+ 'comma-separated list of flag names that it is okay to specify '
+ 'on the command line even if the program does not define a flag '
+ 'with that name. IMPORTANT: flags in this list that have '
+ 'arguments MUST use the --flag=value format.',
+ _helpers.SPECIAL_FLAGS) # pytype: disable=wrong-arg-types
+
+# The global FlagValues instance.
+FLAGS = _flagvalues.FLAGS
diff --git a/absl/flags/__init__.pyi b/absl/flags/__init__.pyi
new file mode 100644
index 0000000..4eee59e
--- /dev/null
+++ b/absl/flags/__init__.pyi
@@ -0,0 +1,103 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.flags import _argument_parser
+from absl.flags import _defines
+from absl.flags import _exceptions
+from absl.flags import _flag
+from absl.flags import _flagvalues
+from absl.flags import _helpers
+from absl.flags import _validators
+
+# DEFINE functions. They are explained in more details in the module doc string.
+# pylint: disable=invalid-name
+DEFINE = _defines.DEFINE
+DEFINE_flag = _defines.DEFINE_flag
+DEFINE_string = _defines.DEFINE_string
+DEFINE_boolean = _defines.DEFINE_boolean
+DEFINE_bool = DEFINE_boolean # Match C++ API.
+DEFINE_float = _defines.DEFINE_float
+DEFINE_integer = _defines.DEFINE_integer
+DEFINE_enum = _defines.DEFINE_enum
+DEFINE_enum_class = _defines.DEFINE_enum_class
+DEFINE_list = _defines.DEFINE_list
+DEFINE_spaceseplist = _defines.DEFINE_spaceseplist
+DEFINE_multi = _defines.DEFINE_multi
+DEFINE_multi_string = _defines.DEFINE_multi_string
+DEFINE_multi_integer = _defines.DEFINE_multi_integer
+DEFINE_multi_float = _defines.DEFINE_multi_float
+DEFINE_multi_enum = _defines.DEFINE_multi_enum
+DEFINE_multi_enum_class = _defines.DEFINE_multi_enum_class
+DEFINE_alias = _defines.DEFINE_alias
+# pylint: enable=invalid-name
+
+# Flag validators.
+register_validator = _validators.register_validator
+validator = _validators.validator
+register_multi_flags_validator = _validators.register_multi_flags_validator
+multi_flags_validator = _validators.multi_flags_validator
+mark_flag_as_required = _validators.mark_flag_as_required
+mark_flags_as_required = _validators.mark_flags_as_required
+mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive
+mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive
+
+# Key flag related functions.
+declare_key_flag = _defines.declare_key_flag
+adopt_module_key_flags = _defines.adopt_module_key_flags
+disclaim_key_flags = _defines.disclaim_key_flags
+
+# Module exceptions.
+# pylint: disable=invalid-name
+Error = _exceptions.Error
+CantOpenFlagFileError = _exceptions.CantOpenFlagFileError
+DuplicateFlagError = _exceptions.DuplicateFlagError
+IllegalFlagValueError = _exceptions.IllegalFlagValueError
+UnrecognizedFlagError = _exceptions.UnrecognizedFlagError
+UnparsedFlagAccessError = _exceptions.UnparsedFlagAccessError
+ValidationError = _exceptions.ValidationError
+FlagNameConflictsWithMethodError = _exceptions.FlagNameConflictsWithMethodError
+
+# Public classes.
+Flag = _flag.Flag
+BooleanFlag = _flag.BooleanFlag
+EnumFlag = _flag.EnumFlag
+EnumClassFlag = _flag.EnumClassFlag
+MultiFlag = _flag.MultiFlag
+MultiEnumClassFlag = _flag.MultiEnumClassFlag
+FlagHolder = _flagvalues.FlagHolder
+FlagValues = _flagvalues.FlagValues
+ArgumentParser = _argument_parser.ArgumentParser
+BooleanParser = _argument_parser.BooleanParser
+EnumParser = _argument_parser.EnumParser
+EnumClassParser = _argument_parser.EnumClassParser
+ArgumentSerializer = _argument_parser.ArgumentSerializer
+FloatParser = _argument_parser.FloatParser
+IntegerParser = _argument_parser.IntegerParser
+BaseListParser = _argument_parser.BaseListParser
+ListParser = _argument_parser.ListParser
+ListSerializer = _argument_parser.ListSerializer
+CsvListSerializer = _argument_parser.CsvListSerializer
+WhitespaceSeparatedListParser = _argument_parser.WhitespaceSeparatedListParser
+EnumClassSerializer = _argument_parser.EnumClassSerializer
+# pylint: enable=invalid-name
+
+# Helper functions.
+get_help_width = _helpers.get_help_width
+text_wrap = _helpers.text_wrap
+flag_dict_to_args = _helpers.flag_dict_to_args
+doc_to_help = _helpers.doc_to_help
+
+# The global FlagValues instance.
+FLAGS = _flagvalues.FLAGS
+
diff --git a/absl/flags/_argument_parser.py b/absl/flags/_argument_parser.py
new file mode 100644
index 0000000..9c6c8c6
--- /dev/null
+++ b/absl/flags/_argument_parser.py
@@ -0,0 +1,629 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains base classes used to parse and convert arguments.
+
+Do NOT import this module directly. Import the flags package and use the
+aliases defined at the package level instead.
+"""
+
+import collections
+import csv
+import io
+import string
+
+from absl.flags import _helpers
+
+
+def _is_integer_type(instance):
+ """Returns True if instance is an integer, and not a bool."""
+ return (isinstance(instance, int) and
+ not isinstance(instance, bool))
+
+
+class _ArgumentParserCache(type):
+ """Metaclass used to cache and share argument parsers among flags."""
+
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ """Returns an instance of the argument parser cls.
+
+ This method overrides behavior of the __new__ methods in
+ all subclasses of ArgumentParser (inclusive). If an instance
+ for cls with the same set of arguments exists, this instance is
+ returned, otherwise a new instance is created.
+
+ If any keyword arguments are defined, or the values in args
+ are not hashable, this method always returns a new instance of
+ cls.
+
+ Args:
+ *args: Positional initializer arguments.
+ **kwargs: Initializer keyword arguments.
+
+ Returns:
+ An instance of cls, shared or new.
+ """
+ if kwargs:
+ return type.__call__(cls, *args, **kwargs)
+ else:
+ instances = cls._instances
+ key = (cls,) + tuple(args)
+ try:
+ return instances[key]
+ except KeyError:
+ # No cache entry for key exists, create a new one.
+ return instances.setdefault(key, type.__call__(cls, *args))
+ except TypeError:
+ # An object in args cannot be hashed, always return
+ # a new instance.
+ return type.__call__(cls, *args)
+
+
+# NOTE about Genericity and Metaclass of ArgumentParser.
+# (1) In the .py source (this file)
+# - is not declared as Generic
+# - has _ArgumentParserCache as a metaclass
+# (2) In the .pyi source (type stub)
+# - is declared as Generic
+# - doesn't have a metaclass
+# The reason we need this is due to Generic having a different metaclass
+# (for python versions <= 3.7) and a class can have only one metaclass.
+#
+# * Lack of metaclass in .pyi is not a deal breaker, since the metaclass
+# doesn't affect any type information. Also type checkers can check the type
+# parameters.
+# * However, not declaring ArgumentParser as Generic in the source affects
+# runtime annotation processing. In particular this means, subclasses should
+# inherit from `ArgumentParser` and not `ArgumentParser[SomeType]`.
+# The corresponding DEFINE_someType method (the public API) can be annotated
+# to return FlagHolder[SomeType].
+class ArgumentParser(metaclass=_ArgumentParserCache):
+ """Base class used to parse and convert arguments.
+
+ The parse() method checks to make sure that the string argument is a
+ legal value and convert it to a native type. If the value cannot be
+ converted, it should throw a 'ValueError' exception with a human
+ readable explanation of why the value is illegal.
+
+ Subclasses should also define a syntactic_help string which may be
+ presented to the user to describe the form of the legal values.
+
+ Argument parser classes must be stateless, since instances are cached
+ and shared between flags. Initializer arguments are allowed, but all
+ member variables must be derived from initializer arguments only.
+ """
+
+ syntactic_help = ''
+
+ def parse(self, argument):
+ """Parses the string argument and returns the native value.
+
+ By default it returns its argument unmodified.
+
+ Args:
+ argument: string argument passed in the commandline.
+
+ Raises:
+ ValueError: Raised when it fails to parse the argument.
+ TypeError: Raised when the argument has the wrong type.
+
+ Returns:
+ The parsed value in native type.
+ """
+ if not isinstance(argument, str):
+ raise TypeError('flag value must be a string, found "{}"'.format(
+ type(argument)))
+ return argument
+
+ def flag_type(self):
+ """Returns a string representing the type of the flag."""
+ return 'string'
+
+ def _custom_xml_dom_elements(self, doc):
+ """Returns a list of minidom.Element to add additional flag information.
+
+ Args:
+ doc: minidom.Document, the DOM document it should create nodes from.
+ """
+ del doc # Unused.
+ return []
+
+
+class ArgumentSerializer(object):
+ """Base class for generating string representations of a flag value."""
+
+ def serialize(self, value):
+ """Returns a serialized string of the value."""
+ return _helpers.str_or_unicode(value)
+
+
+class NumericParser(ArgumentParser):
+ """Parser of numeric values.
+
+ Parsed value may be bounded to a given upper and lower bound.
+ """
+
+ def is_outside_bounds(self, val):
+ """Returns whether the value is outside the bounds or not."""
+ return ((self.lower_bound is not None and val < self.lower_bound) or
+ (self.upper_bound is not None and val > self.upper_bound))
+
+ def parse(self, argument):
+ """See base class."""
+ val = self.convert(argument)
+ if self.is_outside_bounds(val):
+ raise ValueError('%s is not %s' % (val, self.syntactic_help))
+ return val
+
+ def _custom_xml_dom_elements(self, doc):
+ elements = []
+ if self.lower_bound is not None:
+ elements.append(_helpers.create_xml_dom_element(
+ doc, 'lower_bound', self.lower_bound))
+ if self.upper_bound is not None:
+ elements.append(_helpers.create_xml_dom_element(
+ doc, 'upper_bound', self.upper_bound))
+ return elements
+
+ def convert(self, argument):
+ """Returns the correct numeric value of argument.
+
+ Subclass must implement this method, and raise TypeError if argument is not
+ string or has the right numeric type.
+
+ Args:
+ argument: string argument passed in the commandline, or the numeric type.
+
+ Raises:
+ TypeError: Raised when argument is not a string or the right numeric type.
+ ValueError: Raised when failed to convert argument to the numeric value.
+ """
+ raise NotImplementedError
+
+
+class FloatParser(NumericParser):
+ """Parser of floating point values.
+
+ Parsed value may be bounded to a given upper and lower bound.
+ """
+ number_article = 'a'
+ number_name = 'number'
+ syntactic_help = ' '.join((number_article, number_name))
+
+ def __init__(self, lower_bound=None, upper_bound=None):
+ super(FloatParser, self).__init__()
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+ sh = self.syntactic_help
+ if lower_bound is not None and upper_bound is not None:
+ sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound))
+ elif lower_bound == 0:
+ sh = 'a non-negative %s' % self.number_name
+ elif upper_bound == 0:
+ sh = 'a non-positive %s' % self.number_name
+ elif upper_bound is not None:
+ sh = '%s <= %s' % (self.number_name, upper_bound)
+ elif lower_bound is not None:
+ sh = '%s >= %s' % (self.number_name, lower_bound)
+ self.syntactic_help = sh
+
+ def convert(self, argument):
+ """Returns the float value of argument."""
+ if (_is_integer_type(argument) or isinstance(argument, float) or
+ isinstance(argument, str)):
+ return float(argument)
+ else:
+ raise TypeError(
+ 'Expect argument to be a string, int, or float, found {}'.format(
+ type(argument)))
+
+ def flag_type(self):
+ """See base class."""
+ return 'float'
+
+
+class IntegerParser(NumericParser):
+ """Parser of an integer value.
+
+ Parsed value may be bounded to a given upper and lower bound.
+ """
+ number_article = 'an'
+ number_name = 'integer'
+ syntactic_help = ' '.join((number_article, number_name))
+
+ def __init__(self, lower_bound=None, upper_bound=None):
+ super(IntegerParser, self).__init__()
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+ sh = self.syntactic_help
+ if lower_bound is not None and upper_bound is not None:
+ sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound))
+ elif lower_bound == 1:
+ sh = 'a positive %s' % self.number_name
+ elif upper_bound == -1:
+ sh = 'a negative %s' % self.number_name
+ elif lower_bound == 0:
+ sh = 'a non-negative %s' % self.number_name
+ elif upper_bound == 0:
+ sh = 'a non-positive %s' % self.number_name
+ elif upper_bound is not None:
+ sh = '%s <= %s' % (self.number_name, upper_bound)
+ elif lower_bound is not None:
+ sh = '%s >= %s' % (self.number_name, lower_bound)
+ self.syntactic_help = sh
+
+ def convert(self, argument):
+ """Returns the int value of argument."""
+ if _is_integer_type(argument):
+ return argument
+ elif isinstance(argument, str):
+ base = 10
+ if len(argument) > 2 and argument[0] == '0':
+ if argument[1] == 'o':
+ base = 8
+ elif argument[1] == 'x':
+ base = 16
+ return int(argument, base)
+ else:
+ raise TypeError('Expect argument to be a string or int, found {}'.format(
+ type(argument)))
+
+ def flag_type(self):
+ """See base class."""
+ return 'int'
+
+
+class BooleanParser(ArgumentParser):
+ """Parser of boolean values."""
+
+ def parse(self, argument):
+ """See base class."""
+ if isinstance(argument, str):
+ if argument.lower() in ('true', 't', '1'):
+ return True
+ elif argument.lower() in ('false', 'f', '0'):
+ return False
+ else:
+ raise ValueError('Non-boolean argument to boolean flag', argument)
+ elif isinstance(argument, int):
+ # Only allow bool or integer 0, 1.
+ # Note that float 1.0 == True, 0.0 == False.
+ bool_value = bool(argument)
+ if argument == bool_value:
+ return bool_value
+ else:
+ raise ValueError('Non-boolean argument to boolean flag', argument)
+
+ raise TypeError('Non-boolean argument to boolean flag', argument)
+
+ def flag_type(self):
+ """See base class."""
+ return 'bool'
+
+
+class EnumParser(ArgumentParser):
+ """Parser of a string enum value (a string value from a given set)."""
+
+ def __init__(self, enum_values, case_sensitive=True):
+ """Initializes EnumParser.
+
+ Args:
+ enum_values: [str], a non-empty list of string values in the enum.
+ case_sensitive: bool, whether or not the enum is to be case-sensitive.
+
+ Raises:
+ ValueError: When enum_values is empty.
+ """
+ if not enum_values:
+ raise ValueError(
+ 'enum_values cannot be empty, found "{}"'.format(enum_values))
+ super(EnumParser, self).__init__()
+ self.enum_values = enum_values
+ self.case_sensitive = case_sensitive
+
+ def parse(self, argument):
+ """Determines validity of argument and returns the correct element of enum.
+
+ Args:
+ argument: str, the supplied flag value.
+
+ Returns:
+ The first matching element from enum_values.
+
+ Raises:
+ ValueError: Raised when argument didn't match anything in enum.
+ """
+ if self.case_sensitive:
+ if argument not in self.enum_values:
+ raise ValueError('value should be one of <%s>' %
+ '|'.join(self.enum_values))
+ else:
+ return argument
+ else:
+ if argument.upper() not in [value.upper() for value in self.enum_values]:
+ raise ValueError('value should be one of <%s>' %
+ '|'.join(self.enum_values))
+ else:
+ return [value for value in self.enum_values
+ if value.upper() == argument.upper()][0]
+
+ def flag_type(self):
+ """See base class."""
+ return 'string enum'
+
+
+class EnumClassParser(ArgumentParser):
+ """Parser of an Enum class member."""
+
+ def __init__(self, enum_class, case_sensitive=True):
+ """Initializes EnumParser.
+
+ Args:
+ enum_class: class, the Enum class with all possible flag values.
+ case_sensitive: bool, whether or not the enum is to be case-sensitive. If
+ False, all member names must be unique when case is ignored.
+
+ Raises:
+ TypeError: When enum_class is not a subclass of Enum.
+ ValueError: When enum_class is empty.
+ """
+ # Users must have an Enum class defined before using EnumClass flag.
+ # Therefore this dependency is guaranteed.
+ import enum
+
+ if not issubclass(enum_class, enum.Enum):
+ raise TypeError('{} is not a subclass of Enum.'.format(enum_class))
+ if not enum_class.__members__:
+ raise ValueError('enum_class cannot be empty, but "{}" is empty.'
+ .format(enum_class))
+ if not case_sensitive:
+ members = collections.Counter(
+ name.lower() for name in enum_class.__members__)
+ duplicate_keys = {
+ member for member, count in members.items() if count > 1
+ }
+ if duplicate_keys:
+ raise ValueError(
+ 'Duplicate enum values for {} using case_sensitive=False'.format(
+ duplicate_keys))
+
+ super(EnumClassParser, self).__init__()
+ self.enum_class = enum_class
+ self._case_sensitive = case_sensitive
+ if case_sensitive:
+ self._member_names = tuple(enum_class.__members__)
+ else:
+ self._member_names = tuple(
+ name.lower() for name in enum_class.__members__)
+
+ @property
+ def member_names(self):
+ """The accepted enum names, in lowercase if not case sensitive."""
+ return self._member_names
+
+ def parse(self, argument):
+ """Determines validity of argument and returns the correct element of enum.
+
+ Args:
+ argument: str or Enum class member, the supplied flag value.
+
+ Returns:
+ The first matching Enum class member in Enum class.
+
+ Raises:
+ ValueError: Raised when argument didn't match anything in enum.
+ """
+ if isinstance(argument, self.enum_class):
+ return argument
+ elif not isinstance(argument, str):
+ raise ValueError(
+ '{} is not an enum member or a name of a member in {}'.format(
+ argument, self.enum_class))
+ key = EnumParser(
+ self._member_names, case_sensitive=self._case_sensitive).parse(argument)
+ if self._case_sensitive:
+ return self.enum_class[key]
+ else:
+ # If EnumParser.parse() return a value, we're guaranteed to find it
+ # as a member of the class
+ return next(value for name, value in self.enum_class.__members__.items()
+ if name.lower() == key.lower())
+
+ def flag_type(self):
+ """See base class."""
+ return 'enum class'
+
+
+class ListSerializer(ArgumentSerializer):
+
+ def __init__(self, list_sep):
+ self.list_sep = list_sep
+
+ def serialize(self, value):
+ """See base class."""
+ return self.list_sep.join([_helpers.str_or_unicode(x) for x in value])
+
+
+class EnumClassListSerializer(ListSerializer):
+ """A serializer for MultiEnumClass flags.
+
+ This serializer simply joins the output of `EnumClassSerializer` using a
+ provided separator.
+ """
+
+ def __init__(self, list_sep, **kwargs):
+ """Initializes EnumClassListSerializer.
+
+ Args:
+ list_sep: String to be used as a separator when serializing
+ **kwargs: Keyword arguments to the `EnumClassSerializer` used to serialize
+ individual values.
+ """
+ super(EnumClassListSerializer, self).__init__(list_sep)
+ self._element_serializer = EnumClassSerializer(**kwargs)
+
+ def serialize(self, value):
+ """See base class."""
+ if isinstance(value, list):
+ return self.list_sep.join(
+ self._element_serializer.serialize(x) for x in value)
+ else:
+ return self._element_serializer.serialize(value)
+
+
+class CsvListSerializer(ArgumentSerializer):
+
+ def __init__(self, list_sep):
+ self.list_sep = list_sep
+
+ def serialize(self, value):
+ """Serializes a list as a CSV string or unicode."""
+ output = io.StringIO()
+ writer = csv.writer(output, delimiter=self.list_sep)
+ writer.writerow([str(x) for x in value])
+ serialized_value = output.getvalue().strip()
+
+ # We need the returned value to be pure ascii or Unicodes so that
+ # when the xml help is generated they are usefully encodable.
+ return _helpers.str_or_unicode(serialized_value)
+
+
+class EnumClassSerializer(ArgumentSerializer):
+ """Class for generating string representations of an enum class flag value."""
+
+ def __init__(self, lowercase):
+ """Initializes EnumClassSerializer.
+
+ Args:
+ lowercase: If True, enum member names are lowercased during serialization.
+ """
+ self._lowercase = lowercase
+
+ def serialize(self, value):
+ """Returns a serialized string of the Enum class value."""
+ as_string = _helpers.str_or_unicode(value.name)
+ return as_string.lower() if self._lowercase else as_string
+
+
+class BaseListParser(ArgumentParser):
+ """Base class for a parser of lists of strings.
+
+ To extend, inherit from this class; from the subclass __init__, call
+
+ BaseListParser.__init__(self, token, name)
+
+ where token is a character used to tokenize, and name is a description
+ of the separator.
+ """
+
+ def __init__(self, token=None, name=None):
+ assert name
+ super(BaseListParser, self).__init__()
+ self._token = token
+ self._name = name
+ self.syntactic_help = 'a %s separated list' % self._name
+
+ def parse(self, argument):
+ """See base class."""
+ if isinstance(argument, list):
+ return argument
+ elif not argument:
+ return []
+ else:
+ return [s.strip() for s in argument.split(self._token)]
+
+ def flag_type(self):
+ """See base class."""
+ return '%s separated list of strings' % self._name
+
+
+class ListParser(BaseListParser):
+ """Parser for a comma-separated list of strings."""
+
+ def __init__(self):
+ super(ListParser, self).__init__(',', 'comma')
+
+ def parse(self, argument):
+ """Parses argument as comma-separated list of strings."""
+ if isinstance(argument, list):
+ return argument
+ elif not argument:
+ return []
+ else:
+ try:
+ return [s.strip() for s in list(csv.reader([argument], strict=True))[0]]
+ except csv.Error as e:
+ # Provide a helpful report for case like
+ # --listflag="$(printf 'hello,\nworld')"
+ # IOW, list flag values containing naked newlines. This error
+ # was previously "reported" by allowing csv.Error to
+ # propagate.
+ raise ValueError('Unable to parse the value %r as a %s: %s'
+ % (argument, self.flag_type(), e))
+
+ def _custom_xml_dom_elements(self, doc):
+ elements = super(ListParser, self)._custom_xml_dom_elements(doc)
+ elements.append(_helpers.create_xml_dom_element(
+ doc, 'list_separator', repr(',')))
+ return elements
+
+
+class WhitespaceSeparatedListParser(BaseListParser):
+ """Parser for a whitespace-separated list of strings."""
+
+ def __init__(self, comma_compat=False):
+ """Initializer.
+
+ Args:
+ comma_compat: bool, whether to support comma as an additional separator.
+ If False then only whitespace is supported. This is intended only for
+ backwards compatibility with flags that used to be comma-separated.
+ """
+ self._comma_compat = comma_compat
+ name = 'whitespace or comma' if self._comma_compat else 'whitespace'
+ super(WhitespaceSeparatedListParser, self).__init__(None, name)
+
+ def parse(self, argument):
+ """Parses argument as whitespace-separated list of strings.
+
+ It also parses argument as comma-separated list of strings if requested.
+
+ Args:
+ argument: string argument passed in the commandline.
+
+ Returns:
+ [str], the parsed flag value.
+ """
+ if isinstance(argument, list):
+ return argument
+ elif not argument:
+ return []
+ else:
+ if self._comma_compat:
+ argument = argument.replace(',', ' ')
+ return argument.split()
+
+ def _custom_xml_dom_elements(self, doc):
+ elements = super(WhitespaceSeparatedListParser, self
+ )._custom_xml_dom_elements(doc)
+ separators = list(string.whitespace)
+ if self._comma_compat:
+ separators.append(',')
+ separators.sort()
+ for sep_char in separators:
+ elements.append(_helpers.create_xml_dom_element(
+ doc, 'list_separator', repr(sep_char)))
+ return elements
diff --git a/absl/flags/_argument_parser.pyi b/absl/flags/_argument_parser.pyi
new file mode 100644
index 0000000..7e78d7d
--- /dev/null
+++ b/absl/flags/_argument_parser.pyi
@@ -0,0 +1,127 @@
+# Copyright 2020 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains type annotations for _argument_parser.py."""
+
+
+from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Sequence, Any
+
+import enum
+
+_T = TypeVar('_T')
+_ET = TypeVar('_ET', bound=enum.Enum)
+
+
+class ArgumentSerializer(Generic[_T]):
+ def serialize(self, value: _T) -> Text: ...
+
+
+# The metaclass of ArgumentParser is not reflected here, because it does not
+# affect the provided API.
+class ArgumentParser(Generic[_T]):
+
+ syntactic_help: Text
+
+ def parse(self, argument: Text) -> Optional[_T]: ...
+
+ def flag_type(self) -> Text: ...
+
+
+# Using bound=numbers.Number results in an error: b/153268436
+_N = TypeVar('_N', int, float)
+
+
+class NumericParser(ArgumentParser[_N]):
+
+ def is_outside_bounds(self, val: _N) -> bool: ...
+
+ def parse(self, argument: Text) -> _N: ...
+
+ def convert(self, argument: Text) -> _N: ...
+
+
+class FloatParser(NumericParser[float]):
+
+ def __init__(self, lower_bound:Optional[float]=None,
+ upper_bound:Optional[float]=None) -> None:
+ ...
+
+
+class IntegerParser(NumericParser[int]):
+
+ def __init__(self, lower_bound:Optional[int]=None,
+ upper_bound:Optional[int]=None) -> None:
+ ...
+
+
+class BooleanParser(ArgumentParser[bool]):
+ ...
+
+
+class EnumParser(ArgumentParser[Text]):
+ def __init__(self, enum_values: Sequence[Text], case_sensitive: bool=...) -> None:
+ ...
+
+
+
+class EnumClassParser(ArgumentParser[_ET]):
+
+ def __init__(self, enum_class: Type[_ET], case_sensitive: bool=...) -> None:
+ ...
+
+ @property
+ def member_names(self) -> Sequence[Text]: ...
+
+
+class BaseListParser(ArgumentParser[List[Text]]):
+ def __init__(self, token: Text, name:Text) -> None: ...
+
+ # Unlike baseclass BaseListParser never returns None.
+ def parse(self, argument: Text) -> List[Text]: ...
+
+
+
+class ListParser(BaseListParser):
+ def __init__(self) -> None:
+ ...
+
+
+
+class WhitespaceSeparatedListParser(BaseListParser):
+ def __init__(self, comma_compat: bool=False) -> None:
+ ...
+
+
+
+class ListSerializer(ArgumentSerializer[List[Text]]):
+ list_sep = ... # type: Text
+
+ def __init__(self, list_sep: Text) -> None:
+ ...
+
+
+class EnumClassListSerializer(ArgumentSerializer[List[Text]]):
+ def __init__(self, list_sep: Text, **kwargs: Any) -> None:
+ ...
+
+
+class CsvListSerializer(ArgumentSerializer[List[Any]]):
+
+ def __init__(self, list_sep: Text) -> None:
+ ...
+
+
+class EnumClassSerializer(ArgumentSerializer[_ET]):
+ def __init__(self, lowercase: bool) -> None:
+ ...
diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py
new file mode 100644
index 0000000..4494c3b
--- /dev/null
+++ b/absl/flags/_defines.py
@@ -0,0 +1,912 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This modules contains flags DEFINE functions.
+
+Do NOT import this module directly. Import the flags package and use the
+aliases defined at the package level instead.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+import types
+
+from absl.flags import _argument_parser
+from absl.flags import _exceptions
+from absl.flags import _flag
+from absl.flags import _flagvalues
+from absl.flags import _helpers
+from absl.flags import _validators
+
+# pylint: disable=unused-import
+try:
+ from typing import Text, List, Any
+except ImportError:
+ pass
+
+try:
+ import enum
+except ImportError:
+ pass
+# pylint: enable=unused-import
+
+_helpers.disclaim_module_ids.add(id(sys.modules[__name__]))
+
+
+def _register_bounds_validator_if_needed(parser, name, flag_values):
+ """Enforces lower and upper bounds for numeric flags.
+
+ Args:
+ parser: NumericParser (either FloatParser or IntegerParser), provides lower
+ and upper bounds, and help text to display.
+ name: str, name of the flag
+ flag_values: FlagValues.
+ """
+ if parser.lower_bound is not None or parser.upper_bound is not None:
+
+ def checker(value):
+ if value is not None and parser.is_outside_bounds(value):
+ message = '%s is not %s' % (value, parser.syntactic_help)
+ raise _exceptions.ValidationError(message)
+ return True
+
+ _validators.register_validator(name, checker, flag_values=flag_values)
+
+
+def DEFINE( # pylint: disable=invalid-name
+ parser,
+ name,
+ default,
+ help, # pylint: disable=redefined-builtin
+ flag_values=_flagvalues.FLAGS,
+ serializer=None,
+ module_name=None,
+ required=False,
+ **args):
+ """Registers a generic Flag object.
+
+ NOTE: in the docstrings of all DEFINE* functions, "registers" is short
+ for "creates a new flag and registers it".
+
+ Auxiliary function: clients should use the specialized DEFINE_<type>
+ function instead.
+
+ Args:
+ parser: ArgumentParser, used to parse the flag arguments.
+ name: str, the flag name.
+ default: The default value of the flag.
+ help: str, the help message.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ serializer: ArgumentSerializer, the flag serializer instance.
+ module_name: str, the name of the Python module declaring this flag. If not
+ provided, it will be computed using the stack trace of this call.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: dict, the extra keyword args that are passed to Flag __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ return DEFINE_flag(
+ _flag.Flag(parser, serializer, name, default, help, **args), flag_values,
+ module_name, required)
+
+
+def DEFINE_flag( # pylint: disable=invalid-name
+ flag,
+ flag_values=_flagvalues.FLAGS,
+ module_name=None,
+ required=False):
+ """Registers a 'Flag' object with a 'FlagValues' object.
+
+ By default, the global FLAGS 'FlagValue' object is used.
+
+ Typical users will use one of the more specialized DEFINE_xxx
+ functions, such as DEFINE_string or DEFINE_integer. But developers
+ who need to create Flag objects themselves should use this function
+ to register their flags.
+
+ Args:
+ flag: Flag, a flag that is key to the module.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ module_name: str, the name of the Python module declaring this flag. If not
+ provided, it will be computed using the stack trace of this call.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+
+ Returns:
+ a handle to defined flag.
+ """
+ if required and flag.default is not None:
+ raise ValueError('Required flag --%s cannot have a non-None default' %
+ flag.name)
+ # Copying the reference to flag_values prevents pychecker warnings.
+ fv = flag_values
+ fv[flag.name] = flag
+ # Tell flag_values who's defining the flag.
+ if module_name:
+ module = sys.modules.get(module_name)
+ else:
+ module, module_name = _helpers.get_calling_module_object_and_name()
+ flag_values.register_flag_by_module(module_name, flag)
+ flag_values.register_flag_by_module_id(id(module), flag)
+ if required:
+ _validators.mark_flag_as_required(flag.name, fv)
+ ensure_non_none_value = (flag.default is not None) or required
+ return _flagvalues.FlagHolder(
+ fv, flag, ensure_non_none_value=ensure_non_none_value)
+
+
+def _internal_declare_key_flags(flag_names,
+ flag_values=_flagvalues.FLAGS,
+ key_flag_values=None):
+ """Declares a flag as key for the calling module.
+
+ Internal function. User code should call declare_key_flag or
+ adopt_module_key_flags instead.
+
+ Args:
+ flag_names: [str], a list of strings that are names of already-registered
+ Flag objects.
+ flag_values: FlagValues, the FlagValues instance with which the flags listed
+ in flag_names have registered (the value of the flag_values argument from
+ the DEFINE_* calls that defined those flags). This should almost never
+ need to be overridden.
+ key_flag_values: FlagValues, the FlagValues instance that (among possibly
+ many other things) keeps track of the key flags for each module. Default
+ None means "same as flag_values". This should almost never need to be
+ overridden.
+
+ Raises:
+ UnrecognizedFlagError: Raised when the flag is not defined.
+ """
+ key_flag_values = key_flag_values or flag_values
+
+ module = _helpers.get_calling_module()
+
+ for flag_name in flag_names:
+ flag = flag_values[flag_name]
+ key_flag_values.register_key_flag_for_module(module, flag)
+
+
+def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS):
+ """Declares one flag as key to the current module.
+
+ Key flags are flags that are deemed really important for a module.
+ They are important when listing help messages; e.g., if the
+ --helpshort command-line flag is used, then only the key flags of the
+ main module are listed (instead of all flags, as in the case of
+ --helpfull).
+
+ Sample usage:
+
+ flags.declare_key_flag('flag_1')
+
+ Args:
+ flag_name: str, the name of an already declared flag. (Redeclaring flags as
+ key, including flags implicitly key because they were declared in this
+ module, is a no-op.)
+ flag_values: FlagValues, the FlagValues instance in which the flag will be
+ declared as a key flag. This should almost never need to be overridden.
+
+ Raises:
+ ValueError: Raised if flag_name not defined as a Python flag.
+ """
+ if flag_name in _helpers.SPECIAL_FLAGS:
+ # Take care of the special flags, e.g., --flagfile, --undefok.
+ # These flags are defined in SPECIAL_FLAGS, and are treated
+ # specially during flag parsing, taking precedence over the
+ # user-defined flags.
+ _internal_declare_key_flags([flag_name],
+ flag_values=_helpers.SPECIAL_FLAGS,
+ key_flag_values=flag_values)
+ return
+ try:
+ _internal_declare_key_flags([flag_name], flag_values=flag_values)
+ except KeyError:
+ raise ValueError('Flag --%s is undefined. To set a flag as a key flag '
+ 'first define it in Python.' % flag_name)
+
+
+def adopt_module_key_flags(module, flag_values=_flagvalues.FLAGS):
+ """Declares that all flags key to a module are key to the current module.
+
+ Args:
+ module: module, the module object from which all key flags will be declared
+ as key flags to the current module.
+ flag_values: FlagValues, the FlagValues instance in which the flags will be
+ declared as key flags. This should almost never need to be overridden.
+
+ Raises:
+ Error: Raised when given an argument that is a module name (a string),
+ instead of a module object.
+ """
+ if not isinstance(module, types.ModuleType):
+ raise _exceptions.Error('Expected a module object, not %r.' % (module,))
+ _internal_declare_key_flags(
+ [f.name for f in flag_values.get_key_flags_for_module(module.__name__)],
+ flag_values=flag_values)
+ # If module is this flag module, take _helpers.SPECIAL_FLAGS into account.
+ if module == _helpers.FLAGS_MODULE:
+ _internal_declare_key_flags(
+ # As we associate flags with get_calling_module_object_and_name(), the
+ # special flags defined in this module are incorrectly registered with
+ # a different module. So, we can't use get_key_flags_for_module.
+ # Instead, we take all flags from _helpers.SPECIAL_FLAGS (a private
+ # FlagValues, where no other module should register flags).
+ [_helpers.SPECIAL_FLAGS[name].name for name in _helpers.SPECIAL_FLAGS],
+ flag_values=_helpers.SPECIAL_FLAGS,
+ key_flag_values=flag_values)
+
+
+def disclaim_key_flags():
+ """Declares that the current module will not define any more key flags.
+
+ Normally, the module that calls the DEFINE_xxx functions claims the
+ flag to be its key flag. This is undesirable for modules that
+ define additional DEFINE_yyy functions with its own flag parsers and
+ serializers, since that module will accidentally claim flags defined
+ by DEFINE_yyy as its key flags. After calling this function, the
+ module disclaims flag definitions thereafter, so the key flags will
+ be correctly attributed to the caller of DEFINE_yyy.
+
+ After calling this function, the module will not be able to define
+ any more flags. This function will affect all FlagValues objects.
+ """
+ globals_for_caller = sys._getframe(1).f_globals # pylint: disable=protected-access
+ module, _ = _helpers.get_module_object_and_name(globals_for_caller)
+ _helpers.disclaim_module_ids.add(id(module))
+
+
+def DEFINE_string( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ help,
+ flag_values=_flagvalues.FLAGS,
+ required=False,
+ **args):
+ """Registers a flag whose value can be any string."""
+ parser = _argument_parser.ArgumentParser()
+ serializer = _argument_parser.ArgumentSerializer()
+ return DEFINE(
+ parser,
+ name,
+ default,
+ help,
+ flag_values,
+ serializer,
+ required=required,
+ **args)
+
+
+def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ help,
+ flag_values=_flagvalues.FLAGS,
+ module_name=None,
+ required=False,
+ **args):
+ """Registers a boolean flag.
+
+ Such a boolean flag does not take an argument. If a user wants to
+ specify a false value explicitly, the long option beginning with 'no'
+ must be used: i.e. --noflag
+
+ This flag will have a value of None, True or False. None is possible
+ if default=None and the user does not specify the flag on the command
+ line.
+
+ Args:
+ name: str, the flag name.
+ default: bool|str|None, the default value of the flag.
+ help: str, the help message.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ module_name: str, the name of the Python module declaring this flag. If not
+ provided, it will be computed using the stack trace of this call.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: dict, the extra keyword args that are passed to Flag __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ return DEFINE_flag(
+ _flag.BooleanFlag(name, default, help, **args), flag_values, module_name,
+ required)
+
+
+def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ help,
+ lower_bound=None,
+ upper_bound=None,
+ flag_values=_flagvalues.FLAGS,
+ required=False,
+ **args):
+ """Registers a flag whose value must be a float.
+
+ If lower_bound or upper_bound are set, then this flag must be
+ within the given range.
+
+ Args:
+ name: str, the flag name.
+ default: float|str|None, the default value of the flag.
+ help: str, the help message.
+ lower_bound: float, min value of the flag.
+ upper_bound: float, max value of the flag.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: dict, the extra keyword args that are passed to DEFINE.
+
+ Returns:
+ a handle to defined flag.
+ """
+ parser = _argument_parser.FloatParser(lower_bound, upper_bound)
+ serializer = _argument_parser.ArgumentSerializer()
+ result = DEFINE(
+ parser,
+ name,
+ default,
+ help,
+ flag_values,
+ serializer,
+ required=required,
+ **args)
+ _register_bounds_validator_if_needed(parser, name, flag_values=flag_values)
+ return result
+
+
+def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ help,
+ lower_bound=None,
+ upper_bound=None,
+ flag_values=_flagvalues.FLAGS,
+ required=False,
+ **args):
+ """Registers a flag whose value must be an integer.
+
+ If lower_bound, or upper_bound are set, then this flag must be
+ within the given range.
+
+ Args:
+ name: str, the flag name.
+ default: int|str|None, the default value of the flag.
+ help: str, the help message.
+ lower_bound: int, min value of the flag.
+ upper_bound: int, max value of the flag.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: dict, the extra keyword args that are passed to DEFINE.
+
+ Returns:
+ a handle to defined flag.
+ """
+ parser = _argument_parser.IntegerParser(lower_bound, upper_bound)
+ serializer = _argument_parser.ArgumentSerializer()
+ result = DEFINE(
+ parser,
+ name,
+ default,
+ help,
+ flag_values,
+ serializer,
+ required=required,
+ **args)
+ _register_bounds_validator_if_needed(parser, name, flag_values=flag_values)
+ return result
+
+
+def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ enum_values,
+ help,
+ flag_values=_flagvalues.FLAGS,
+ module_name=None,
+ required=False,
+ **args):
+ """Registers a flag whose value can be any string from enum_values.
+
+ Instead of a string enum, prefer `DEFINE_enum_class`, which allows
+ defining enums from an `enum.Enum` class.
+
+ Args:
+ name: str, the flag name.
+ default: str|None, the default value of the flag.
+ enum_values: [str], a non-empty list of strings with the possible values for
+ the flag.
+ help: str, the help message.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ module_name: str, the name of the Python module declaring this flag. If not
+ provided, it will be computed using the stack trace of this call.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: dict, the extra keyword args that are passed to Flag __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ return DEFINE_flag(
+ _flag.EnumFlag(name, default, help, enum_values, **args), flag_values,
+ module_name, required)
+
+
+def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ enum_class,
+ help,
+ flag_values=_flagvalues.FLAGS,
+ module_name=None,
+ case_sensitive=False,
+ required=False,
+ **args):
+ """Registers a flag whose value can be the name of enum members.
+
+ Args:
+ name: str, the flag name.
+ default: Enum|str|None, the default value of the flag.
+ enum_class: class, the Enum class with all the possible values for the flag.
+ help: str, the help message.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ module_name: str, the name of the Python module declaring this flag. If not
+ provided, it will be computed using the stack trace of this call.
+ case_sensitive: bool, whether to map strings to members of the enum_class
+ without considering case.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: dict, the extra keyword args that are passed to Flag __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ return DEFINE_flag(
+ _flag.EnumClassFlag(
+ name,
+ default,
+ help,
+ enum_class,
+ case_sensitive=case_sensitive,
+ **args), flag_values, module_name, required)
+
+
+def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ help,
+ flag_values=_flagvalues.FLAGS,
+ required=False,
+ **args):
+ """Registers a flag whose value is a comma-separated list of strings.
+
+ The flag value is parsed with a CSV parser.
+
+ Args:
+ name: str, the flag name.
+ default: list|str|None, the default value of the flag.
+ help: str, the help message.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: Dictionary with extra keyword args that are passed to the Flag
+ __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ parser = _argument_parser.ListParser()
+ serializer = _argument_parser.CsvListSerializer(',')
+ return DEFINE(
+ parser,
+ name,
+ default,
+ help,
+ flag_values,
+ serializer,
+ required=required,
+ **args)
+
+
+def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ help,
+ comma_compat=False,
+ flag_values=_flagvalues.FLAGS,
+ required=False,
+ **args):
+ """Registers a flag whose value is a whitespace-separated list of strings.
+
+ Any whitespace can be used as a separator.
+
+ Args:
+ name: str, the flag name.
+ default: list|str|None, the default value of the flag.
+ help: str, the help message.
+ comma_compat: bool - Whether to support comma as an additional separator. If
+ false then only whitespace is supported. This is intended only for
+ backwards compatibility with flags that used to be comma-separated.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: Dictionary with extra keyword args that are passed to the Flag
+ __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ parser = _argument_parser.WhitespaceSeparatedListParser(
+ comma_compat=comma_compat)
+ serializer = _argument_parser.ListSerializer(' ')
+ return DEFINE(
+ parser,
+ name,
+ default,
+ help,
+ flag_values,
+ serializer,
+ required=required,
+ **args)
+
+
+def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin
+ parser,
+ serializer,
+ name,
+ default,
+ help,
+ flag_values=_flagvalues.FLAGS,
+ module_name=None,
+ required=False,
+ **args):
+ """Registers a generic MultiFlag that parses its args with a given parser.
+
+ Auxiliary function. Normal users should NOT use it directly.
+
+ Developers who need to create their own 'Parser' classes for options
+ which can appear multiple times can call this module function to
+ register their flags.
+
+ Args:
+ parser: ArgumentParser, used to parse the flag arguments.
+ serializer: ArgumentSerializer, the flag serializer instance.
+ name: str, the flag name.
+ default: Union[Iterable[T], Text, None], the default value of the flag. If
+ the value is text, it will be parsed as if it was provided from the
+ command line. If the value is a non-string iterable, it will be iterated
+ over to create a shallow copy of the values. If it is None, it is left
+ as-is.
+ help: str, the help message.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ module_name: A string, the name of the Python module declaring this flag. If
+ not provided, it will be computed using the stack trace of this call.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: Dictionary with extra keyword args that are passed to the Flag
+ __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ return DEFINE_flag(
+ _flag.MultiFlag(parser, serializer, name, default, help, **args),
+ flag_values, module_name, required)
+
+
+def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ help,
+ flag_values=_flagvalues.FLAGS,
+ required=False,
+ **args):
+ """Registers a flag whose value can be a list of any strings.
+
+ Use the flag on the command line multiple times to place multiple
+ string values into the list. The 'default' may be a single string
+ (which will be converted into a single-element list) or a list of
+ strings.
+
+
+ Args:
+ name: str, the flag name.
+ default: Union[Iterable[Text], Text, None], the default value of the flag;
+ see `DEFINE_multi`.
+ help: str, the help message.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: Dictionary with extra keyword args that are passed to the Flag
+ __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ parser = _argument_parser.ArgumentParser()
+ serializer = _argument_parser.ArgumentSerializer()
+ return DEFINE_multi(
+ parser,
+ serializer,
+ name,
+ default,
+ help,
+ flag_values,
+ required=required,
+ **args)
+
+
+def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ help,
+ lower_bound=None,
+ upper_bound=None,
+ flag_values=_flagvalues.FLAGS,
+ required=False,
+ **args):
+ """Registers a flag whose value can be a list of arbitrary integers.
+
+ Use the flag on the command line multiple times to place multiple
+ integer values into the list. The 'default' may be a single integer
+ (which will be converted into a single-element list) or a list of
+ integers.
+
+ Args:
+ name: str, the flag name.
+ default: Union[Iterable[int], Text, None], the default value of the flag;
+ see `DEFINE_multi`.
+ help: str, the help message.
+ lower_bound: int, min values of the flag.
+ upper_bound: int, max values of the flag.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: Dictionary with extra keyword args that are passed to the Flag
+ __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ parser = _argument_parser.IntegerParser(lower_bound, upper_bound)
+ serializer = _argument_parser.ArgumentSerializer()
+ return DEFINE_multi(
+ parser,
+ serializer,
+ name,
+ default,
+ help,
+ flag_values,
+ required=required,
+ **args)
+
+
+def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ help,
+ lower_bound=None,
+ upper_bound=None,
+ flag_values=_flagvalues.FLAGS,
+ required=False,
+ **args):
+ """Registers a flag whose value can be a list of arbitrary floats.
+
+ Use the flag on the command line multiple times to place multiple
+ float values into the list. The 'default' may be a single float
+ (which will be converted into a single-element list) or a list of
+ floats.
+
+ Args:
+ name: str, the flag name.
+ default: Union[Iterable[float], Text, None], the default value of the flag;
+ see `DEFINE_multi`.
+ help: str, the help message.
+ lower_bound: float, min values of the flag.
+ upper_bound: float, max values of the flag.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: Dictionary with extra keyword args that are passed to the Flag
+ __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ parser = _argument_parser.FloatParser(lower_bound, upper_bound)
+ serializer = _argument_parser.ArgumentSerializer()
+ return DEFINE_multi(
+ parser,
+ serializer,
+ name,
+ default,
+ help,
+ flag_values,
+ required=required,
+ **args)
+
+
+def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ enum_values,
+ help,
+ flag_values=_flagvalues.FLAGS,
+ case_sensitive=True,
+ required=False,
+ **args):
+ """Registers a flag whose value can be a list strings from enum_values.
+
+ Use the flag on the command line multiple times to place multiple
+ enum values into the list. The 'default' may be a single string
+ (which will be converted into a single-element list) or a list of
+ strings.
+
+ Args:
+ name: str, the flag name.
+ default: Union[Iterable[Text], Text, None], the default value of the flag;
+ see `DEFINE_multi`.
+ enum_values: [str], a non-empty list of strings with the possible values for
+ the flag.
+ help: str, the help message.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ case_sensitive: Whether or not the enum is to be case-sensitive.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: Dictionary with extra keyword args that are passed to the Flag
+ __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ parser = _argument_parser.EnumParser(enum_values, case_sensitive)
+ serializer = _argument_parser.ArgumentSerializer()
+ return DEFINE_multi(
+ parser,
+ serializer,
+ name,
+ default,
+ '<%s>: %s' % ('|'.join(enum_values), help),
+ flag_values,
+ required=required,
+ **args)
+
+
+def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
+ name,
+ default,
+ enum_class,
+ help,
+ flag_values=_flagvalues.FLAGS,
+ module_name=None,
+ case_sensitive=False,
+ required=False,
+ **args):
+ """Registers a flag whose value can be a list of enum members.
+
+ Use the flag on the command line multiple times to place multiple
+ enum values into the list.
+
+ Args:
+ name: str, the flag name.
+ default: Union[Iterable[Enum], Iterable[Text], Enum, Text, None], the
+ default value of the flag; see `DEFINE_multi`; only differences are
+ documented here. If the value is a single Enum, it is treated as a
+ single-item list of that Enum value. If it is an iterable, text values
+ within the iterable will be converted to the equivalent Enum objects.
+ enum_class: class, the Enum class with all the possible values for the flag.
+ help: str, the help message.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ module_name: A string, the name of the Python module declaring this flag. If
+ not provided, it will be computed using the stack trace of this call.
+ case_sensitive: bool, whether to map strings to members of the enum_class
+ without considering case.
+ required: bool, is this a required flag. This must be used as a keyword
+ argument.
+ **args: Dictionary with extra keyword args that are passed to the Flag
+ __init__.
+
+ Returns:
+ a handle to defined flag.
+ """
+ return DEFINE_flag(
+ _flag.MultiEnumClassFlag(
+ name, default, help, enum_class, case_sensitive=case_sensitive),
+ flag_values,
+ module_name,
+ required=required,
+ **args)
+
+
+def DEFINE_alias( # pylint: disable=invalid-name
+ name,
+ original_name,
+ flag_values=_flagvalues.FLAGS,
+ module_name=None):
+ """Defines an alias flag for an existing one.
+
+ Args:
+ name: str, the flag name.
+ original_name: str, the original flag name.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ registered. This should almost never need to be overridden.
+ module_name: A string, the name of the module that defines this flag.
+
+ Returns:
+ a handle to defined flag.
+
+ Raises:
+ flags.FlagError:
+ UnrecognizedFlagError: if the referenced flag doesn't exist.
+ DuplicateFlagError: if the alias name has been used by some existing flag.
+ """
+ if original_name not in flag_values:
+ raise _exceptions.UnrecognizedFlagError(original_name)
+ flag = flag_values[original_name]
+
+ class _FlagAlias(_flag.Flag):
+ """Overrides Flag class so alias value is copy of original flag value."""
+
+ def parse(self, argument):
+ flag.parse(argument)
+ self.present += 1
+
+ def _parse_from_default(self, value):
+ # The value was already parsed by the aliased flag, so there is no
+ # need to call the parser on it a second time.
+ # Additionally, because of how MultiFlag parses and merges values,
+ # it isn't possible to delegate to the aliased flag and still get
+ # the correct values.
+ return value
+
+ @property
+ def value(self):
+ return flag.value
+
+ @value.setter
+ def value(self, value):
+ flag.value = value
+
+ help_msg = 'Alias for --%s.' % flag.name
+ # If alias_name has been used, flags.DuplicatedFlag will be raised.
+ return DEFINE_flag(
+ _FlagAlias(
+ flag.parser,
+ flag.serializer,
+ name,
+ flag.default,
+ help_msg,
+ boolean=flag.boolean), flag_values, module_name)
diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi
new file mode 100644
index 0000000..1f482a2
--- /dev/null
+++ b/absl/flags/_defines.pyi
@@ -0,0 +1,637 @@
+# Copyright 2020 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This modules contains type annotated stubs for DEFINE functions."""
+
+
+from absl.flags import _argument_parser
+from absl.flags import _flag
+from absl.flags import _flagvalues
+
+import enum
+
+from typing import Text, List, Any, TypeVar, Optional, Union, Type, Iterable, overload, Literal
+
+_T = TypeVar('_T')
+_ET = TypeVar('_ET', bound=enum.Enum)
+
+
+@overload
+def DEFINE(
+ parser: _argument_parser.ArgumentParser[_T],
+ name: Text,
+ default: Any,
+ help: Optional[Text],
+ flag_values : _flagvalues.FlagValues = ...,
+ serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = ...,
+ module_name: Optional[Text] = ...,
+ required: Literal[True] = ...,
+ **args: Any) -> _flagvalues.FlagHolder[_T]:
+ ...
+
+
+@overload
+def DEFINE(
+ parser: _argument_parser.ArgumentParser[_T],
+ name: Text,
+ default: Any,
+ help: Optional[Text],
+ flag_values : _flagvalues.FlagValues = ...,
+ serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = ...,
+ module_name: Optional[Text] = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[_T]]:
+ ...
+
+
+@overload
+def DEFINE_flag(
+ flag: _flag.Flag[_T],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ required: Literal[True] = ...
+) -> _flagvalues.FlagHolder[_T]:
+ ...
+
+@overload
+def DEFINE_flag(
+ flag: _flag.Flag[_T],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ required: bool = ...) -> _flagvalues.FlagHolder[Optional[_T]]:
+ ...
+
+# typing overloads for DEFINE_* methods...
+#
+# - DEFINE_* method return FlagHolder[Optional[T]] or FlagHolder[T] depending
+# on the arguments.
+# - If the flag value is guaranteed to be not None, the return type is
+# FlagHolder[T].
+# - If the flag is required OR has a non-None default, the flag value i
+# guaranteed to be not None after flag parsing has finished.
+# The information above is captured with three overloads as follows.
+#
+# (if required=True and passed in as a keyword argument,
+# return type is FlagHolder[Y])
+# @overload
+# def DEFINE_xxx(
+# ... arguments...
+# default: Union[None, X] = ...,
+# *,
+# required: Literal[True]) -> _flagvalues.FlagHolder[Y]:
+# ...
+#
+# (if default=None, return type is FlagHolder[Optional[Y]])
+# @overload
+# def DEFINE_xxx(
+# ... arguments...
+# default: None,
+# required: bool = ...) -> _flagvalues.FlagHolder[Optional[Y]]:
+# ...
+#
+# (if default!=None, return type is FlagHolder[Y]):
+# @overload
+# def DEFINE_xxx(
+# ... arguments...
+# default: X,
+# required: bool = ...) -> _flagvalues.FlagHolder[Y]:
+# ...
+#
+# where X = type of non-None default values for the flag
+# and Y = non-None type for flag value
+
+@overload
+def DEFINE_string(
+ name: Text,
+ default: Optional[Text],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[Text]:
+ ...
+
+@overload
+def DEFINE_string(
+ name: Text,
+ default: None,
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[Text]]:
+ ...
+
+@overload
+def DEFINE_string(
+ name: Text,
+ default: Text,
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Text]:
+ ...
+
+@overload
+def DEFINE_boolean(
+ name : Text,
+ default: Union[None, Text, bool, int],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[bool]:
+ ...
+
+@overload
+def DEFINE_boolean(
+ name : Text,
+ default: None,
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[bool]]:
+ ...
+
+@overload
+def DEFINE_boolean(
+ name : Text,
+ default: Union[Text, bool, int],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[bool]:
+ ...
+
+@overload
+def DEFINE_float(
+ name: Text,
+ default: Union[None, float, Text],
+ help: Optional[Text],
+ lower_bound: Optional[float] = ...,
+ upper_bound: Optional[float] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[float]:
+ ...
+
+@overload
+def DEFINE_float(
+ name: Text,
+ default: None,
+ help: Optional[Text],
+ lower_bound: Optional[float] = ...,
+ upper_bound: Optional[float] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[float]]:
+ ...
+
+@overload
+def DEFINE_float(
+ name: Text,
+ default: Union[float, Text],
+ help: Optional[Text],
+ lower_bound: Optional[float] = ...,
+ upper_bound: Optional[float] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[float]:
+ ...
+
+
+@overload
+def DEFINE_integer(
+ name: Text,
+ default: Union[None, int, Text],
+ help: Optional[Text],
+ lower_bound: Optional[int] = ...,
+ upper_bound: Optional[int] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[int]:
+ ...
+
+@overload
+def DEFINE_integer(
+ name: Text,
+ default: None,
+ help: Optional[Text],
+ lower_bound: Optional[int] = ...,
+ upper_bound: Optional[int] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[int]]:
+ ...
+
+@overload
+def DEFINE_integer(
+ name: Text,
+ default: Union[int, Text],
+ help: Optional[Text],
+ lower_bound: Optional[int] = ...,
+ upper_bound: Optional[int] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[int]:
+ ...
+
+@overload
+def DEFINE_enum(
+ name : Text,
+ default: Optional[Text],
+ enum_values: Iterable[Text],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[Text]:
+ ...
+
+@overload
+def DEFINE_enum(
+ name : Text,
+ default: None,
+ enum_values: Iterable[Text],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[Text]]:
+ ...
+
+@overload
+def DEFINE_enum(
+ name : Text,
+ default: Text,
+ enum_values: Iterable[Text],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Text]:
+ ...
+
+@overload
+def DEFINE_enum_class(
+ name: Text,
+ default: Union[None, _ET, Text],
+ enum_class: Type[_ET],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ case_sensitive: bool = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[_ET]:
+ ...
+
+@overload
+def DEFINE_enum_class(
+ name: Text,
+ default: None,
+ enum_class: Type[_ET],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ case_sensitive: bool = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[_ET]]:
+ ...
+
+@overload
+def DEFINE_enum_class(
+ name: Text,
+ default: Union[_ET, Text],
+ enum_class: Type[_ET],
+ help: Optional[Text],
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ case_sensitive: bool = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[_ET]:
+ ...
+
+
+@overload
+def DEFINE_list(
+ name: Text,
+ default: Union[None, Iterable[Text], Text],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_list(
+ name: Text,
+ default: None,
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
+ ...
+
+@overload
+def DEFINE_list(
+ name: Text,
+ default: Union[Iterable[Text], Text],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_spaceseplist(
+ name: Text,
+ default: Union[None, Iterable[Text], Text],
+ help: Text,
+ comma_compat: bool = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_spaceseplist(
+ name: Text,
+ default: None,
+ help: Text,
+ comma_compat: bool = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
+ ...
+
+@overload
+def DEFINE_spaceseplist(
+ name: Text,
+ default: Union[Iterable[Text], Text],
+ help: Text,
+ comma_compat: bool = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_multi(
+ parser : _argument_parser.ArgumentParser[_T],
+ serializer: _argument_parser.ArgumentSerializer[_T],
+ name: Text,
+ default: Union[None, Iterable[_T], _T, Text],
+ help: Text,
+ flag_values:_flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[_T]]:
+ ...
+
+@overload
+def DEFINE_multi(
+ parser : _argument_parser.ArgumentParser[_T],
+ serializer: _argument_parser.ArgumentSerializer[_T],
+ name: Text,
+ default: None,
+ help: Text,
+ flag_values:_flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[List[_T]]]:
+ ...
+
+@overload
+def DEFINE_multi(
+ parser : _argument_parser.ArgumentParser[_T],
+ serializer: _argument_parser.ArgumentSerializer[_T],
+ name: Text,
+ default: Union[Iterable[_T], _T, Text],
+ help: Text,
+ flag_values:_flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[List[_T]]:
+ ...
+
+@overload
+def DEFINE_multi_string(
+ name: Text,
+ default: Union[None, Iterable[Text], Text],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_multi_string(
+ name: Text,
+ default: None,
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
+ ...
+
+@overload
+def DEFINE_multi_string(
+ name: Text,
+ default: Union[Iterable[Text], Text],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_multi_integer(
+ name: Text,
+ default: Union[None, Iterable[int], int, Text],
+ help: Text,
+ lower_bound: Optional[int] = ...,
+ upper_bound: Optional[int] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[int]]:
+ ...
+
+@overload
+def DEFINE_multi_integer(
+ name: Text,
+ default: None,
+ help: Text,
+ lower_bound: Optional[int] = ...,
+ upper_bound: Optional[int] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[List[int]]]:
+ ...
+
+@overload
+def DEFINE_multi_integer(
+ name: Text,
+ default: Union[Iterable[int], int, Text],
+ help: Text,
+ lower_bound: Optional[int] = ...,
+ upper_bound: Optional[int] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[List[int]]:
+ ...
+
+@overload
+def DEFINE_multi_float(
+ name: Text,
+ default: Union[None, Iterable[float], float, Text],
+ help: Text,
+ lower_bound: Optional[float] = ...,
+ upper_bound: Optional[float] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[float]]:
+ ...
+
+@overload
+def DEFINE_multi_float(
+ name: Text,
+ default: None,
+ help: Text,
+ lower_bound: Optional[float] = ...,
+ upper_bound: Optional[float] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[List[float]]]:
+ ...
+
+@overload
+def DEFINE_multi_float(
+ name: Text,
+ default: Union[Iterable[float], float, Text],
+ help: Text,
+ lower_bound: Optional[float] = ...,
+ upper_bound: Optional[float] = ...,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[List[float]]:
+ ...
+
+
+@overload
+def DEFINE_multi_enum(
+ name: Text,
+ default: Union[None, Iterable[Text], Text],
+ enum_values: Iterable[Text],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_multi_enum(
+ name: Text,
+ default: None,
+ enum_values: Iterable[Text],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
+ ...
+
+@overload
+def DEFINE_multi_enum(
+ name: Text,
+ default: Union[Iterable[Text], Text],
+ enum_values: Iterable[Text],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[List[Text]]:
+ ...
+
+@overload
+def DEFINE_multi_enum_class(
+ name: Text,
+ default: Union[None, Iterable[_ET], _ET, Text],
+ enum_class: Type[_ET],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ *,
+ required: Literal[True],
+ **args: Any) -> _flagvalues.FlagHolder[List[_ET]]:
+ ...
+
+@overload
+def DEFINE_multi_enum_class(
+ name: Text,
+ default: None,
+ enum_class: Type[_ET],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[Optional[List[_ET]]]:
+ ...
+
+@overload
+def DEFINE_multi_enum_class(
+ name: Text,
+ default: Union[Iterable[_ET], _ET, Text],
+ enum_class: Type[_ET],
+ help: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...,
+ required: bool = ...,
+ **args: Any) -> _flagvalues.FlagHolder[List[_ET]]:
+ ...
+
+
+
+def DEFINE_alias(
+ name: Text,
+ original_name: Text,
+ flag_values: _flagvalues.FlagValues = ...,
+ module_name: Optional[Text] = ...) -> _flagvalues.FlagHolder[Any]:
+ ...
+
+
+
+def declare_key_flag(flag_name: Text,
+ flag_values: _flagvalues.FlagValues = ...) -> None:
+ ...
+
+
+
+def adopt_module_key_flags(module: Any,
+ flag_values: _flagvalues.FlagValues = ...) -> None:
+ ...
+
+
+
+def disclaim_key_flags() -> None:
+ ...
diff --git a/absl/flags/_exceptions.py b/absl/flags/_exceptions.py
new file mode 100644
index 0000000..254eb9b
--- /dev/null
+++ b/absl/flags/_exceptions.py
@@ -0,0 +1,112 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Exception classes in ABSL flags library.
+
+Do NOT import this module directly. Import the flags package and use the
+aliases defined at the package level instead.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from absl.flags import _helpers
+
+
+_helpers.disclaim_module_ids.add(id(sys.modules[__name__]))
+
+
+class Error(Exception):
+ """The base class for all flags errors."""
+
+
+class CantOpenFlagFileError(Error):
+ """Raised when flagfile fails to open.
+
+ E.g. the file doesn't exist, or has wrong permissions.
+ """
+
+
+class DuplicateFlagError(Error):
+ """Raised if there is a flag naming conflict."""
+
+ @classmethod
+ def from_flag(cls, flagname, flag_values, other_flag_values=None):
+ """Creates a DuplicateFlagError by providing flag name and values.
+
+ Args:
+ flagname: str, the name of the flag being redefined.
+ flag_values: FlagValues, the FlagValues instance containing the first
+ definition of flagname.
+ other_flag_values: FlagValues, if it is not None, it should be the
+ FlagValues object where the second definition of flagname occurs.
+ If it is None, we assume that we're being called when attempting
+ to create the flag a second time, and we use the module calling
+ this one as the source of the second definition.
+
+ Returns:
+ An instance of DuplicateFlagError.
+ """
+ first_module = flag_values.find_module_defining_flag(
+ flagname, default='<unknown>')
+ if other_flag_values is None:
+ second_module = _helpers.get_calling_module()
+ else:
+ second_module = other_flag_values.find_module_defining_flag(
+ flagname, default='<unknown>')
+ flag_summary = flag_values[flagname].help
+ msg = ("The flag '%s' is defined twice. First from %s, Second from %s. "
+ "Description from first occurrence: %s") % (
+ flagname, first_module, second_module, flag_summary)
+ return cls(msg)
+
+
+class IllegalFlagValueError(Error):
+ """Raised when the flag command line argument is illegal."""
+
+
+class UnrecognizedFlagError(Error):
+ """Raised when a flag is unrecognized.
+
+ Attributes:
+ flagname: str, the name of the unrecognized flag.
+ flagvalue: The value of the flag, empty if the flag is not defined.
+ """
+
+ def __init__(self, flagname, flagvalue='', suggestions=None):
+ self.flagname = flagname
+ self.flagvalue = flagvalue
+ if suggestions:
+ # Space before the question mark is intentional to not include it in the
+ # selection when copy-pasting the suggestion from (some) terminals.
+ tip = '. Did you mean: %s ?' % ', '.join(suggestions)
+ else:
+ tip = ''
+ super(UnrecognizedFlagError, self).__init__(
+ 'Unknown command line flag \'%s\'%s' % (flagname, tip))
+
+
+class UnparsedFlagAccessError(Error):
+ """Raised when accessing the flag value from unparsed FlagValues."""
+
+
+class ValidationError(Error):
+ """Raised when flag validator constraint is not satisfied."""
+
+
+class FlagNameConflictsWithMethodError(Error):
+ """Raised when a flag name conflicts with FlagValues methods."""
diff --git a/absl/flags/_flag.py b/absl/flags/_flag.py
new file mode 100644
index 0000000..d7ad944
--- /dev/null
+++ b/absl/flags/_flag.py
@@ -0,0 +1,485 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains Flag class - information about single command-line flag.
+
+Do NOT import this module directly. Import the flags package and use the
+aliases defined at the package level instead.
+"""
+
+from collections import abc
+import copy
+import functools
+
+from absl.flags import _argument_parser
+from absl.flags import _exceptions
+from absl.flags import _helpers
+
+
+@functools.total_ordering
+class Flag(object):
+ """Information about a command-line flag.
+
+ 'Flag' objects define the following fields:
+ .name - the name for this flag;
+ .default - the default value for this flag;
+ .default_unparsed - the unparsed default value for this flag.
+ .default_as_str - default value as repr'd string, e.g., "'true'" (or None);
+ .value - the most recent parsed value of this flag; set by parse();
+ .help - a help string or None if no help is available;
+ .short_name - the single letter alias for this flag (or None);
+ .boolean - if 'true', this flag does not accept arguments;
+ .present - true if this flag was parsed from command line flags;
+ .parser - an ArgumentParser object;
+ .serializer - an ArgumentSerializer object;
+ .allow_override - the flag may be redefined without raising an error, and
+ newly defined flag overrides the old one.
+ .allow_override_cpp - use the flag from C++ if available; the flag
+ definition is replaced by the C++ flag after init;
+ .allow_hide_cpp - use the Python flag despite having a C++ flag with
+ the same name (ignore the C++ flag);
+ .using_default_value - the flag value has not been set by user;
+ .allow_overwrite - the flag may be parsed more than once without raising
+ an error, the last set value will be used;
+ .allow_using_method_names - whether this flag can be defined even if it has
+ a name that conflicts with a FlagValues method.
+
+ The only public method of a 'Flag' object is parse(), but it is
+ typically only called by a 'FlagValues' object. The parse() method is
+ a thin wrapper around the 'ArgumentParser' parse() method. The parsed
+ value is saved in .value, and the .present attribute is updated. If
+ this flag was already present, an Error is raised.
+
+ parse() is also called during __init__ to parse the default value and
+ initialize the .value attribute. This enables other python modules to
+ safely use flags even if the __main__ module neglects to parse the
+ command line arguments. The .present attribute is cleared after
+ __init__ parsing. If the default value is set to None, then the
+ __init__ parsing step is skipped and the .value attribute is
+ initialized to None.
+
+ Note: The default value is also presented to the user in the help
+ string, so it is important that it be a legal value for this flag.
+ """
+
+ def __init__(self, parser, serializer, name, default, help_string,
+ short_name=None, boolean=False, allow_override=False,
+ allow_override_cpp=False, allow_hide_cpp=False,
+ allow_overwrite=True, allow_using_method_names=False):
+ self.name = name
+
+ if not help_string:
+ help_string = '(no help available)'
+
+ self.help = help_string
+ self.short_name = short_name
+ self.boolean = boolean
+ self.present = 0
+ self.parser = parser
+ self.serializer = serializer
+ self.allow_override = allow_override
+ self.allow_override_cpp = allow_override_cpp
+ self.allow_hide_cpp = allow_hide_cpp
+ self.allow_overwrite = allow_overwrite
+ self.allow_using_method_names = allow_using_method_names
+
+ self.using_default_value = True
+ self._value = None
+ self.validators = []
+ if self.allow_hide_cpp and self.allow_override_cpp:
+ raise _exceptions.Error(
+ "Can't have both allow_hide_cpp (means use Python flag) and "
+ 'allow_override_cpp (means use C++ flag after InitGoogle)')
+
+ self._set_default(default)
+
+ @property
+ def value(self):
+ return self._value
+
+ @value.setter
+ def value(self, value):
+ self._value = value
+
+ def __hash__(self):
+ return hash(id(self))
+
+ def __eq__(self, other):
+ return self is other
+
+ def __lt__(self, other):
+ if isinstance(other, Flag):
+ return id(self) < id(other)
+ return NotImplemented
+
+ def __bool__(self):
+ raise TypeError('A Flag instance would always be True. '
+ 'Did you mean to test the `.value` attribute?')
+
+ def __getstate__(self):
+ raise TypeError("can't pickle Flag objects")
+
+ def __copy__(self):
+ raise TypeError('%s does not support shallow copies. '
+ 'Use copy.deepcopy instead.' % type(self).__name__)
+
+ def __deepcopy__(self, memo):
+ result = object.__new__(type(self))
+ result.__dict__ = copy.deepcopy(self.__dict__, memo)
+ return result
+
+ def _get_parsed_value_as_string(self, value):
+ """Returns parsed flag value as string."""
+ if value is None:
+ return None
+ if self.serializer:
+ return repr(self.serializer.serialize(value))
+ if self.boolean:
+ if value:
+ return repr('true')
+ else:
+ return repr('false')
+ return repr(_helpers.str_or_unicode(value))
+
+ def parse(self, argument):
+ """Parses string and sets flag value.
+
+ Args:
+ argument: str or the correct flag value type, argument to be parsed.
+ """
+ if self.present and not self.allow_overwrite:
+ raise _exceptions.IllegalFlagValueError(
+ 'flag --%s=%s: already defined as %s' % (
+ self.name, argument, self.value))
+ self.value = self._parse(argument)
+ self.present += 1
+
+ def _parse(self, argument):
+ """Internal parse function.
+
+ It returns the parsed value, and does not modify class states.
+
+ Args:
+ argument: str or the correct flag value type, argument to be parsed.
+
+ Returns:
+ The parsed value.
+ """
+ try:
+ return self.parser.parse(argument)
+ except (TypeError, ValueError) as e: # Recast as IllegalFlagValueError.
+ raise _exceptions.IllegalFlagValueError(
+ 'flag --%s=%s: %s' % (self.name, argument, e))
+
+ def unparse(self):
+ self.value = self.default
+ self.using_default_value = True
+ self.present = 0
+
+ def serialize(self):
+ """Serializes the flag."""
+ return self._serialize(self.value)
+
+ def _serialize(self, value):
+ """Internal serialize function."""
+ if value is None:
+ return ''
+ if self.boolean:
+ if value:
+ return '--%s' % self.name
+ else:
+ return '--no%s' % self.name
+ else:
+ if not self.serializer:
+ raise _exceptions.Error(
+ 'Serializer not present for flag %s' % self.name)
+ return '--%s=%s' % (self.name, self.serializer.serialize(value))
+
+ def _set_default(self, value):
+ """Changes the default value (and current value too) for this Flag."""
+ self.default_unparsed = value
+ if value is None:
+ self.default = None
+ else:
+ self.default = self._parse_from_default(value)
+ self.default_as_str = self._get_parsed_value_as_string(self.default)
+ if self.using_default_value:
+ self.value = self.default
+
+ # This is split out so that aliases can skip regular parsing of the default
+ # value.
+ def _parse_from_default(self, value):
+ return self._parse(value)
+
+ def flag_type(self):
+ """Returns a str that describes the type of the flag.
+
+ NOTE: we use strings, and not the types.*Type constants because
+ our flags can have more exotic types, e.g., 'comma separated list
+ of strings', 'whitespace separated list of strings', etc.
+ """
+ return self.parser.flag_type()
+
+ def _create_xml_dom_element(self, doc, module_name, is_key=False):
+ """Returns an XML element that contains this flag's information.
+
+ This is information that is relevant to all flags (e.g., name,
+ meaning, etc.). If you defined a flag that has some other pieces of
+ info, then please override _ExtraXMLInfo.
+
+ Please do NOT override this method.
+
+ Args:
+ doc: minidom.Document, the DOM document it should create nodes from.
+ module_name: str,, the name of the module that defines this flag.
+ is_key: boolean, True iff this flag is key for main module.
+
+ Returns:
+ A minidom.Element instance.
+ """
+ element = doc.createElement('flag')
+ if is_key:
+ element.appendChild(_helpers.create_xml_dom_element(doc, 'key', 'yes'))
+ element.appendChild(_helpers.create_xml_dom_element(
+ doc, 'file', module_name))
+ # Adds flag features that are relevant for all flags.
+ element.appendChild(_helpers.create_xml_dom_element(doc, 'name', self.name))
+ if self.short_name:
+ element.appendChild(_helpers.create_xml_dom_element(
+ doc, 'short_name', self.short_name))
+ if self.help:
+ element.appendChild(_helpers.create_xml_dom_element(
+ doc, 'meaning', self.help))
+ # The default flag value can either be represented as a string like on the
+ # command line, or as a Python object. We serialize this value in the
+ # latter case in order to remain consistent.
+ if self.serializer and not isinstance(self.default, str):
+ if self.default is not None:
+ default_serialized = self.serializer.serialize(self.default)
+ else:
+ default_serialized = ''
+ else:
+ default_serialized = self.default
+ element.appendChild(_helpers.create_xml_dom_element(
+ doc, 'default', default_serialized))
+ value_serialized = self._serialize_value_for_xml(self.value)
+ element.appendChild(_helpers.create_xml_dom_element(
+ doc, 'current', value_serialized))
+ element.appendChild(_helpers.create_xml_dom_element(
+ doc, 'type', self.flag_type()))
+ # Adds extra flag features this flag may have.
+ for e in self._extra_xml_dom_elements(doc):
+ element.appendChild(e)
+ return element
+
+ def _serialize_value_for_xml(self, value):
+ """Returns the serialized value, for use in an XML help text."""
+ return value
+
+ def _extra_xml_dom_elements(self, doc):
+ """Returns extra info about this flag in XML.
+
+ "Extra" means "not already included by _create_xml_dom_element above."
+
+ Args:
+ doc: minidom.Document, the DOM document it should create nodes from.
+
+ Returns:
+ A list of minidom.Element.
+ """
+ # Usually, the parser knows the extra details about the flag, so
+ # we just forward the call to it.
+ return self.parser._custom_xml_dom_elements(doc) # pylint: disable=protected-access
+
+
+class BooleanFlag(Flag):
+ """Basic boolean flag.
+
+ Boolean flags do not take any arguments, and their value is either
+ True (1) or False (0). The false value is specified on the command
+ line by prepending the word 'no' to either the long or the short flag
+ name.
+
+ For example, if a Boolean flag was created whose long name was
+ 'update' and whose short name was 'x', then this flag could be
+ explicitly unset through either --noupdate or --nox.
+ """
+
+ def __init__(self, name, default, help, short_name=None, **args): # pylint: disable=redefined-builtin
+ p = _argument_parser.BooleanParser()
+ super(BooleanFlag, self).__init__(
+ p, None, name, default, help, short_name, 1, **args)
+
+
+class EnumFlag(Flag):
+ """Basic enum flag; its value can be any string from list of enum_values."""
+
+ def __init__(self, name, default, help, enum_values, # pylint: disable=redefined-builtin
+ short_name=None, case_sensitive=True, **args):
+ p = _argument_parser.EnumParser(enum_values, case_sensitive)
+ g = _argument_parser.ArgumentSerializer()
+ super(EnumFlag, self).__init__(
+ p, g, name, default, help, short_name, **args)
+ self.help = '<%s>: %s' % ('|'.join(enum_values), self.help)
+
+ def _extra_xml_dom_elements(self, doc):
+ elements = []
+ for enum_value in self.parser.enum_values:
+ elements.append(_helpers.create_xml_dom_element(
+ doc, 'enum_value', enum_value))
+ return elements
+
+
+class EnumClassFlag(Flag):
+ """Basic enum flag; its value is an enum class's member."""
+
+ def __init__(
+ self,
+ name,
+ default,
+ help, # pylint: disable=redefined-builtin
+ enum_class,
+ short_name=None,
+ case_sensitive=False,
+ **args):
+ p = _argument_parser.EnumClassParser(
+ enum_class, case_sensitive=case_sensitive)
+ g = _argument_parser.EnumClassSerializer(lowercase=not case_sensitive)
+ super(EnumClassFlag, self).__init__(
+ p, g, name, default, help, short_name, **args)
+ self.help = '<%s>: %s' % ('|'.join(p.member_names), self.help)
+
+ def _extra_xml_dom_elements(self, doc):
+ elements = []
+ for enum_value in self.parser.enum_class.__members__.keys():
+ elements.append(_helpers.create_xml_dom_element(
+ doc, 'enum_value', enum_value))
+ return elements
+
+
+class MultiFlag(Flag):
+ """A flag that can appear multiple time on the command-line.
+
+ The value of such a flag is a list that contains the individual values
+ from all the appearances of that flag on the command-line.
+
+ See the __doc__ for Flag for most behavior of this class. Only
+ differences in behavior are described here:
+
+ * The default value may be either a single value or an iterable of values.
+ A single value is transformed into a single-item list of that value.
+
+ * The value of the flag is always a list, even if the option was
+ only supplied once, and even if the default value is a single
+ value
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(MultiFlag, self).__init__(*args, **kwargs)
+ self.help += ';\n repeat this option to specify a list of values'
+
+ def parse(self, arguments):
+ """Parses one or more arguments with the installed parser.
+
+ Args:
+ arguments: a single argument or a list of arguments (typically a
+ list of default values); a single argument is converted
+ internally into a list containing one item.
+ """
+ new_values = self._parse(arguments)
+ if self.present:
+ self.value.extend(new_values)
+ else:
+ self.value = new_values
+ self.present += len(new_values)
+
+ def _parse(self, arguments):
+ if (isinstance(arguments, abc.Iterable) and
+ not isinstance(arguments, str)):
+ arguments = list(arguments)
+
+ if not isinstance(arguments, list):
+ # Default value may be a list of values. Most other arguments
+ # will not be, so convert them into a single-item list to make
+ # processing simpler below.
+ arguments = [arguments]
+
+ return [super(MultiFlag, self)._parse(item) for item in arguments]
+
+ def _serialize(self, value):
+ """See base class."""
+ if not self.serializer:
+ raise _exceptions.Error(
+ 'Serializer not present for flag %s' % self.name)
+ if value is None:
+ return ''
+
+ serialized_items = [
+ super(MultiFlag, self)._serialize(value_item) for value_item in value
+ ]
+
+ return '\n'.join(serialized_items)
+
+ def flag_type(self):
+ """See base class."""
+ return 'multi ' + self.parser.flag_type()
+
+ def _extra_xml_dom_elements(self, doc):
+ elements = []
+ if hasattr(self.parser, 'enum_values'):
+ for enum_value in self.parser.enum_values:
+ elements.append(_helpers.create_xml_dom_element(
+ doc, 'enum_value', enum_value))
+ return elements
+
+
+class MultiEnumClassFlag(MultiFlag):
+ """A multi_enum_class flag.
+
+ See the __doc__ for MultiFlag for most behaviors of this class. In addition,
+ this class knows how to handle enum.Enum instances as values for this flag
+ type.
+ """
+
+ def __init__(self,
+ name,
+ default,
+ help_string,
+ enum_class,
+ case_sensitive=False,
+ **args):
+ p = _argument_parser.EnumClassParser(
+ enum_class, case_sensitive=case_sensitive)
+ g = _argument_parser.EnumClassListSerializer(
+ list_sep=',', lowercase=not case_sensitive)
+ super(MultiEnumClassFlag, self).__init__(
+ p, g, name, default, help_string, **args)
+ self.help = (
+ '<%s>: %s;\n repeat this option to specify a list of values' %
+ ('|'.join(p.member_names), help_string or '(no help available)'))
+
+ def _extra_xml_dom_elements(self, doc):
+ elements = []
+ for enum_value in self.parser.enum_class.__members__.keys():
+ elements.append(_helpers.create_xml_dom_element(
+ doc, 'enum_value', enum_value))
+ return elements
+
+ def _serialize_value_for_xml(self, value):
+ """See base class."""
+ if value is not None:
+ value_serialized = self.serializer.serialize(value)
+ else:
+ value_serialized = ''
+ return value_serialized
diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi
new file mode 100644
index 0000000..9b4a3d3
--- /dev/null
+++ b/absl/flags/_flag.pyi
@@ -0,0 +1,133 @@
+# Copyright 2020 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains type annotations for Flag class."""
+
+import copy
+import functools
+
+from absl.flags import _argument_parser
+import enum
+
+from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence
+
+_T = TypeVar('_T')
+_ET = TypeVar('_ET', bound=enum.Enum)
+
+
+class Flag(Generic[_T]):
+
+ name = ... # type: Text
+ default = ... # type: Any
+ default_unparsed = ... # type: Any
+ default_as_str = ... # type: Optional[Text]
+ help = ... # type: Text
+ short_name = ... # type: Text
+ boolean = ... # type: bool
+ present = ... # type: bool
+ parser = ... # type: _argument_parser.ArgumentParser[_T]
+ serializer = ... # type: _argument_parser.ArgumentSerializer[_T]
+ allow_override = ... # type: bool
+ allow_override_cpp = ... # type: bool
+ allow_hide_cpp = ... # type: bool
+ using_default_value = ... # type: bool
+ allow_overwrite = ... # type: bool
+ allow_using_method_names = ... # type: bool
+
+ def __init__(self,
+ parser: _argument_parser.ArgumentParser[_T],
+ serializer: Optional[_argument_parser.ArgumentSerializer[_T]],
+ name: Text,
+ default: Any,
+ help_string: Optional[Text],
+ short_name: Optional[Text] = ...,
+ boolean: bool = ...,
+ allow_override: bool = ...,
+ allow_override_cpp: bool = ...,
+ allow_hide_cpp: bool = ...,
+ allow_overwrite: bool = ...,
+ allow_using_method_names: bool = ...) -> None:
+ ...
+
+
+ @property
+ def value(self) -> Optional[_T]: ...
+
+ def parse(self, argument: Union[_T, Text, None]) -> None: ...
+
+ def unparse(self) -> None: ...
+
+ def _parse(self, argument: Any) -> Any: ...
+
+ def __deepcopy__(self, memo: dict) -> Flag: ...
+
+ def _get_parsed_value_as_string(self, value: Optional[_T]) -> Optional[Text]:
+ ...
+
+ def serialize(self) -> Text: ...
+
+ def flag_type(self) -> Text: ...
+
+
+class BooleanFlag(Flag[bool]):
+ def __init__(self,
+ name: Text,
+ default: Any,
+ help: Optional[Text],
+ short_name: Optional[Text]=None,
+ **args: Any) -> None:
+ ...
+
+
+
+class EnumFlag(Flag[Text]):
+ def __init__(self,
+ name: Text,
+ default: Any,
+ help: Optional[Text],
+ enum_values: Sequence[Text],
+ short_name: Optional[Text] = ...,
+ case_sensitive: bool = ...,
+ **args: Any):
+ ...
+
+
+
+class EnumClassFlag(Flag[_ET]):
+
+ def __init__(self,
+ name: Text,
+ default: Any,
+ help: Optional[Text],
+ enum_class: Type[_ET],
+ short_name: Optional[Text]=None,
+ **args: Any):
+ ...
+
+
+
+class MultiFlag(Flag[List[_T]]):
+ ...
+
+
+class MultiEnumClassFlag(MultiFlag[_ET]):
+ def __init__(self,
+ name: Text,
+ default: Any,
+ help_string: Optional[Text],
+ enum_class: Type[_ET],
+ **args: Any):
+ ...
+
+
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
new file mode 100644
index 0000000..1b54fb3
--- /dev/null
+++ b/absl/flags/_flagvalues.py
@@ -0,0 +1,1387 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Defines the FlagValues class - registry of 'Flag' objects.
+
+Do NOT import this module directly. Import the flags package and use the
+aliases defined at the package level instead.
+"""
+
+import copy
+import itertools
+import logging
+import os
+import sys
+from typing import Generic, TypeVar
+from xml.dom import minidom
+
+from absl.flags import _exceptions
+from absl.flags import _flag
+from absl.flags import _helpers
+from absl.flags import _validators_classes
+
+# Add flagvalues module to disclaimed module ids.
+_helpers.disclaim_module_ids.add(id(sys.modules[__name__]))
+
+_T = TypeVar('_T')
+
+
+class FlagValues:
+ """Registry of 'Flag' objects.
+
+ A 'FlagValues' can then scan command line arguments, passing flag
+ arguments through to the 'Flag' objects that it owns. It also
+ provides easy access to the flag values. Typically only one
+ 'FlagValues' object is needed by an application: flags.FLAGS
+
+ This class is heavily overloaded:
+
+ 'Flag' objects are registered via __setitem__:
+ FLAGS['longname'] = x # register a new flag
+
+ The .value attribute of the registered 'Flag' objects can be accessed
+ as attributes of this 'FlagValues' object, through __getattr__. Both
+ the long and short name of the original 'Flag' objects can be used to
+ access its value:
+ FLAGS.longname # parsed flag value
+ FLAGS.x # parsed flag value (short name)
+
+ Command line arguments are scanned and passed to the registered 'Flag'
+ objects through the __call__ method. Unparsed arguments, including
+ argv[0] (e.g. the program name) are returned.
+ argv = FLAGS(sys.argv) # scan command line arguments
+
+ The original registered Flag objects can be retrieved through the use
+ of the dictionary-like operator, __getitem__:
+ x = FLAGS['longname'] # access the registered Flag object
+
+ The str() operator of a 'FlagValues' object provides help for all of
+ the registered 'Flag' objects.
+ """
+
+ # A note on collections.abc.Mapping:
+ # FlagValues defines __getitem__, __iter__, and __len__. It makes perfect
+ # sense to let it be a collections.abc.Mapping class. However, we are not
+ # able to do so. The mixin methods, e.g. keys, values, are not uncommon flag
+ # names. Those flag values would not be accessible via the FLAGS.xxx form.
+
+ def __init__(self):
+ # Since everything in this class is so heavily overloaded, the only
+ # way of defining and using fields is to access __dict__ directly.
+
+ # Dictionary: flag name (string) -> Flag object.
+ self.__dict__['__flags'] = {}
+
+ # Set: name of hidden flag (string).
+ # Holds flags that should not be directly accessible from Python.
+ self.__dict__['__hiddenflags'] = set()
+
+ # Dictionary: module name (string) -> list of Flag objects that are defined
+ # by that module.
+ self.__dict__['__flags_by_module'] = {}
+ # Dictionary: module id (int) -> list of Flag objects that are defined by
+ # that module.
+ self.__dict__['__flags_by_module_id'] = {}
+ # Dictionary: module name (string) -> list of Flag objects that are
+ # key for that module.
+ self.__dict__['__key_flags_by_module'] = {}
+
+ # Bool: True if flags were parsed.
+ self.__dict__['__flags_parsed'] = False
+
+ # Bool: True if unparse_flags() was called.
+ self.__dict__['__unparse_flags_called'] = False
+
+ # None or Method(name, value) to call from __setattr__ for an unknown flag.
+ self.__dict__['__set_unknown'] = None
+
+ # A set of banned flag names. This is to prevent users from accidentally
+ # defining a flag that has the same name as a method on this class.
+ # Users can still allow defining the flag by passing
+ # allow_using_method_names=True in DEFINE_xxx functions.
+ self.__dict__['__banned_flag_names'] = frozenset(dir(FlagValues))
+
+ # Bool: Whether to use GNU style scanning.
+ self.__dict__['__use_gnu_getopt'] = True
+
+ # Bool: Whether use_gnu_getopt has been explicitly set by the user.
+ self.__dict__['__use_gnu_getopt_explicitly_set'] = False
+
+ # Function: Takes a flag name as parameter, returns a tuple
+ # (is_retired, type_is_bool).
+ self.__dict__['__is_retired_flag_func'] = None
+
+ def set_gnu_getopt(self, gnu_getopt=True):
+ """Sets whether or not to use GNU style scanning.
+
+ GNU style allows mixing of flag and non-flag arguments. See
+ http://docs.python.org/library/getopt.html#getopt.gnu_getopt
+
+ Args:
+ gnu_getopt: bool, whether or not to use GNU style scanning.
+ """
+ self.__dict__['__use_gnu_getopt'] = gnu_getopt
+ self.__dict__['__use_gnu_getopt_explicitly_set'] = True
+
+ def is_gnu_getopt(self):
+ return self.__dict__['__use_gnu_getopt']
+
+ def _flags(self):
+ return self.__dict__['__flags']
+
+ def flags_by_module_dict(self):
+ """Returns the dictionary of module_name -> list of defined flags.
+
+ Returns:
+ A dictionary. Its keys are module names (strings). Its values
+ are lists of Flag objects.
+ """
+ return self.__dict__['__flags_by_module']
+
+ def flags_by_module_id_dict(self):
+ """Returns the dictionary of module_id -> list of defined flags.
+
+ Returns:
+ A dictionary. Its keys are module IDs (ints). Its values
+ are lists of Flag objects.
+ """
+ return self.__dict__['__flags_by_module_id']
+
+ def key_flags_by_module_dict(self):
+ """Returns the dictionary of module_name -> list of key flags.
+
+ Returns:
+ A dictionary. Its keys are module names (strings). Its values
+ are lists of Flag objects.
+ """
+ return self.__dict__['__key_flags_by_module']
+
+ def register_flag_by_module(self, module_name, flag):
+ """Records the module that defines a specific flag.
+
+ We keep track of which flag is defined by which module so that we
+ can later sort the flags by module.
+
+ Args:
+ module_name: str, the name of a Python module.
+ flag: Flag, the Flag instance that is key to the module.
+ """
+ flags_by_module = self.flags_by_module_dict()
+ flags_by_module.setdefault(module_name, []).append(flag)
+
+ def register_flag_by_module_id(self, module_id, flag):
+ """Records the module that defines a specific flag.
+
+ Args:
+ module_id: int, the ID of the Python module.
+ flag: Flag, the Flag instance that is key to the module.
+ """
+ flags_by_module_id = self.flags_by_module_id_dict()
+ flags_by_module_id.setdefault(module_id, []).append(flag)
+
+ def register_key_flag_for_module(self, module_name, flag):
+ """Specifies that a flag is a key flag for a module.
+
+ Args:
+ module_name: str, the name of a Python module.
+ flag: Flag, the Flag instance that is key to the module.
+ """
+ key_flags_by_module = self.key_flags_by_module_dict()
+ # The list of key flags for the module named module_name.
+ key_flags = key_flags_by_module.setdefault(module_name, [])
+ # Add flag, but avoid duplicates.
+ if flag not in key_flags:
+ key_flags.append(flag)
+
+ def _flag_is_registered(self, flag_obj):
+ """Checks whether a Flag object is registered under long name or short name.
+
+ Args:
+ flag_obj: Flag, the Flag instance to check for.
+
+ Returns:
+ bool, True iff flag_obj is registered under long name or short name.
+ """
+ flag_dict = self._flags()
+ # Check whether flag_obj is registered under its long name.
+ name = flag_obj.name
+ if flag_dict.get(name, None) == flag_obj:
+ return True
+ # Check whether flag_obj is registered under its short name.
+ short_name = flag_obj.short_name
+ if (short_name is not None and flag_dict.get(short_name, None) == flag_obj):
+ return True
+ return False
+
+ def _cleanup_unregistered_flag_from_module_dicts(self, flag_obj):
+ """Cleans up unregistered flags from all module -> [flags] dictionaries.
+
+ If flag_obj is registered under either its long name or short name, it
+ won't be removed from the dictionaries.
+
+ Args:
+ flag_obj: Flag, the Flag instance to clean up for.
+ """
+ if self._flag_is_registered(flag_obj):
+ return
+ for flags_by_module_dict in (self.flags_by_module_dict(),
+ self.flags_by_module_id_dict(),
+ self.key_flags_by_module_dict()):
+ for flags_in_module in flags_by_module_dict.values():
+ # While (as opposed to if) takes care of multiple occurrences of a
+ # flag in the list for the same module.
+ while flag_obj in flags_in_module:
+ flags_in_module.remove(flag_obj)
+
+ def get_flags_for_module(self, module):
+ """Returns the list of flags defined by a module.
+
+ Args:
+ module: module|str, the module to get flags from.
+
+ Returns:
+ [Flag], a new list of Flag instances. Caller may update this list as
+ desired: none of those changes will affect the internals of this
+ FlagValue instance.
+ """
+ if not isinstance(module, str):
+ module = module.__name__
+ if module == '__main__':
+ module = sys.argv[0]
+
+ return list(self.flags_by_module_dict().get(module, []))
+
+ def get_key_flags_for_module(self, module):
+ """Returns the list of key flags for a module.
+
+ Args:
+ module: module|str, the module to get key flags from.
+
+ Returns:
+ [Flag], a new list of Flag instances. Caller may update this list as
+ desired: none of those changes will affect the internals of this
+ FlagValue instance.
+ """
+ if not isinstance(module, str):
+ module = module.__name__
+ if module == '__main__':
+ module = sys.argv[0]
+
+ # Any flag is a key flag for the module that defined it. NOTE:
+ # key_flags is a fresh list: we can update it without affecting the
+ # internals of this FlagValues object.
+ key_flags = self.get_flags_for_module(module)
+
+ # Take into account flags explicitly declared as key for a module.
+ for flag in self.key_flags_by_module_dict().get(module, []):
+ if flag not in key_flags:
+ key_flags.append(flag)
+ return key_flags
+
+ def find_module_defining_flag(self, flagname, default=None):
+ """Return the name of the module defining this flag, or default.
+
+ Args:
+ flagname: str, name of the flag to lookup.
+ default: Value to return if flagname is not defined. Defaults to None.
+
+ Returns:
+ The name of the module which registered the flag with this name.
+ If no such module exists (i.e. no flag with this name exists),
+ we return default.
+ """
+ registered_flag = self._flags().get(flagname)
+ if registered_flag is None:
+ return default
+ for module, flags in self.flags_by_module_dict().items():
+ for flag in flags:
+ # It must compare the flag with the one in _flags. This is because a
+ # flag might be overridden only for its long name (or short name),
+ # and only its short name (or long name) is considered registered.
+ if (flag.name == registered_flag.name and
+ flag.short_name == registered_flag.short_name):
+ return module
+ return default
+
+ def find_module_id_defining_flag(self, flagname, default=None):
+ """Return the ID of the module defining this flag, or default.
+
+ Args:
+ flagname: str, name of the flag to lookup.
+ default: Value to return if flagname is not defined. Defaults to None.
+
+ Returns:
+ The ID of the module which registered the flag with this name.
+ If no such module exists (i.e. no flag with this name exists),
+ we return default.
+ """
+ registered_flag = self._flags().get(flagname)
+ if registered_flag is None:
+ return default
+ for module_id, flags in self.flags_by_module_id_dict().items():
+ for flag in flags:
+ # It must compare the flag with the one in _flags. This is because a
+ # flag might be overridden only for its long name (or short name),
+ # and only its short name (or long name) is considered registered.
+ if (flag.name == registered_flag.name and
+ flag.short_name == registered_flag.short_name):
+ return module_id
+ return default
+
+ def _register_unknown_flag_setter(self, setter):
+ """Allow set default values for undefined flags.
+
+ Args:
+ setter: Method(name, value) to call to __setattr__ an unknown flag. Must
+ raise NameError or ValueError for invalid name/value.
+ """
+ self.__dict__['__set_unknown'] = setter
+
+ def _set_unknown_flag(self, name, value):
+ """Returns value if setting flag |name| to |value| returned True.
+
+ Args:
+ name: str, name of the flag to set.
+ value: Value to set.
+
+ Returns:
+ Flag value on successful call.
+
+ Raises:
+ UnrecognizedFlagError
+ IllegalFlagValueError
+ """
+ setter = self.__dict__['__set_unknown']
+ if setter:
+ try:
+ setter(name, value)
+ return value
+ except (TypeError, ValueError): # Flag value is not valid.
+ raise _exceptions.IllegalFlagValueError(
+ '"{1}" is not valid for --{0}'.format(name, value))
+ except NameError: # Flag name is not valid.
+ pass
+ raise _exceptions.UnrecognizedFlagError(name, value)
+
+ def append_flag_values(self, flag_values):
+ """Appends flags registered in another FlagValues instance.
+
+ Args:
+ flag_values: FlagValues, the FlagValues instance from which to copy flags.
+ """
+ for flag_name, flag in flag_values._flags().items(): # pylint: disable=protected-access
+ # Each flags with short_name appears here twice (once under its
+ # normal name, and again with its short name). To prevent
+ # problems (DuplicateFlagError) with double flag registration, we
+ # perform a check to make sure that the entry we're looking at is
+ # for its normal name.
+ if flag_name == flag.name:
+ try:
+ self[flag_name] = flag
+ except _exceptions.DuplicateFlagError:
+ raise _exceptions.DuplicateFlagError.from_flag(
+ flag_name, self, other_flag_values=flag_values)
+
+ def remove_flag_values(self, flag_values):
+ """Remove flags that were previously appended from another FlagValues.
+
+ Args:
+ flag_values: FlagValues, the FlagValues instance containing flags to
+ remove.
+ """
+ for flag_name in flag_values:
+ self.__delattr__(flag_name)
+
+ def __setitem__(self, name, flag):
+ """Registers a new flag variable."""
+ fl = self._flags()
+ if not isinstance(flag, _flag.Flag):
+ raise _exceptions.IllegalFlagValueError(flag)
+ if str is bytes and isinstance(name, unicode):
+ # When using Python 2 with unicode_literals, allow it but encode it
+ # into the bytes type we require.
+ name = name.encode('utf-8')
+ if not isinstance(name, type('')):
+ raise _exceptions.Error('Flag name must be a string')
+ if not name:
+ raise _exceptions.Error('Flag name cannot be empty')
+ if ' ' in name:
+ raise _exceptions.Error('Flag name cannot contain a space')
+ self._check_method_name_conflicts(name, flag)
+ if name in fl and not flag.allow_override and not fl[name].allow_override:
+ module, module_name = _helpers.get_calling_module_object_and_name()
+ if (self.find_module_defining_flag(name) == module_name and
+ id(module) != self.find_module_id_defining_flag(name)):
+ # If the flag has already been defined by a module with the same name,
+ # but a different ID, we can stop here because it indicates that the
+ # module is simply being imported a subsequent time.
+ return
+ raise _exceptions.DuplicateFlagError.from_flag(name, self)
+ short_name = flag.short_name
+ # If a new flag overrides an old one, we need to cleanup the old flag's
+ # modules if it's not registered.
+ flags_to_cleanup = set()
+ if short_name is not None:
+ if (short_name in fl and not flag.allow_override and
+ not fl[short_name].allow_override):
+ raise _exceptions.DuplicateFlagError.from_flag(short_name, self)
+ if short_name in fl and fl[short_name] != flag:
+ flags_to_cleanup.add(fl[short_name])
+ fl[short_name] = flag
+ if (name not in fl # new flag
+ or fl[name].using_default_value or not flag.using_default_value):
+ if name in fl and fl[name] != flag:
+ flags_to_cleanup.add(fl[name])
+ fl[name] = flag
+ for f in flags_to_cleanup:
+ self._cleanup_unregistered_flag_from_module_dicts(f)
+
+ def __dir__(self):
+ """Returns list of names of all defined flags.
+
+ Useful for TAB-completion in ipython.
+
+ Returns:
+ [str], a list of names of all defined flags.
+ """
+ return sorted(self.__dict__['__flags'])
+
+ def __getitem__(self, name):
+ """Returns the Flag object for the flag --name."""
+ return self._flags()[name]
+
+ def _hide_flag(self, name):
+ """Marks the flag --name as hidden."""
+ self.__dict__['__hiddenflags'].add(name)
+
+ def __getattr__(self, name):
+ """Retrieves the 'value' attribute of the flag --name."""
+ fl = self._flags()
+ if name not in fl:
+ raise AttributeError(name)
+ if name in self.__dict__['__hiddenflags']:
+ raise AttributeError(name)
+
+ if self.__dict__['__flags_parsed'] or fl[name].present:
+ return fl[name].value
+ else:
+ raise _exceptions.UnparsedFlagAccessError(
+ 'Trying to access flag --%s before flags were parsed.' % name)
+
+ def __setattr__(self, name, value):
+ """Sets the 'value' attribute of the flag --name."""
+ self._set_attributes(**{name: value})
+ return value
+
+ def _set_attributes(self, **attributes):
+ """Sets multiple flag values together, triggers validators afterwards."""
+ fl = self._flags()
+ known_flags = set()
+ for name, value in attributes.items():
+ if name in self.__dict__['__hiddenflags']:
+ raise AttributeError(name)
+ if name in fl:
+ fl[name].value = value
+ known_flags.add(name)
+ else:
+ self._set_unknown_flag(name, value)
+ for name in known_flags:
+ self._assert_validators(fl[name].validators)
+ fl[name].using_default_value = False
+
+ def validate_all_flags(self):
+ """Verifies whether all flags pass validation.
+
+ Raises:
+ AttributeError: Raised if validators work with a non-existing flag.
+ IllegalFlagValueError: Raised if validation fails for at least one
+ validator.
+ """
+ all_validators = set()
+ for flag in self._flags().values():
+ all_validators.update(flag.validators)
+ self._assert_validators(all_validators)
+
+ def _assert_validators(self, validators):
+ """Asserts if all validators in the list are satisfied.
+
+ It asserts validators in the order they were created.
+
+ Args:
+ validators: Iterable(validators.Validator), validators to be verified.
+
+ Raises:
+ AttributeError: Raised if validators work with a non-existing flag.
+ IllegalFlagValueError: Raised if validation fails for at least one
+ validator.
+ """
+ messages = []
+ bad_flags = set()
+ for validator in sorted(
+ validators, key=lambda validator: validator.insertion_index):
+ try:
+ if isinstance(validator, _validators_classes.SingleFlagValidator):
+ if validator.flag_name in bad_flags:
+ continue
+ elif isinstance(validator, _validators_classes.MultiFlagsValidator):
+ if bad_flags & set(validator.flag_names):
+ continue
+ validator.verify(self)
+ except _exceptions.ValidationError as e:
+ if isinstance(validator, _validators_classes.SingleFlagValidator):
+ bad_flags.add(validator.flag_name)
+ elif isinstance(validator, _validators_classes.MultiFlagsValidator):
+ bad_flags.update(set(validator.flag_names))
+ message = validator.print_flags_with_values(self)
+ messages.append('%s: %s' % (message, str(e)))
+ if messages:
+ raise _exceptions.IllegalFlagValueError('\n'.join(messages))
+
+ def __delattr__(self, flag_name):
+ """Deletes a previously-defined flag from a flag object.
+
+ This method makes sure we can delete a flag by using
+
+ del FLAGS.<flag_name>
+
+ E.g.,
+
+ flags.DEFINE_integer('foo', 1, 'Integer flag.')
+ del flags.FLAGS.foo
+
+ If a flag is also registered by its the other name (long name or short
+ name), the other name won't be deleted.
+
+ Args:
+ flag_name: str, the name of the flag to be deleted.
+
+ Raises:
+ AttributeError: Raised when there is no registered flag named flag_name.
+ """
+ fl = self._flags()
+ if flag_name not in fl:
+ raise AttributeError(flag_name)
+
+ flag_obj = fl[flag_name]
+ del fl[flag_name]
+
+ self._cleanup_unregistered_flag_from_module_dicts(flag_obj)
+
+ def set_default(self, name, value):
+ """Changes the default value of the named flag object.
+
+ The flag's current value is also updated if the flag is currently using
+ the default value, i.e. not specified in the command line, and not set
+ by FLAGS.name = value.
+
+ Args:
+ name: str, the name of the flag to modify.
+ value: The new default value.
+
+ Raises:
+ UnrecognizedFlagError: Raised when there is no registered flag named name.
+ IllegalFlagValueError: Raised when value is not valid.
+ """
+ fl = self._flags()
+ if name not in fl:
+ self._set_unknown_flag(name, value)
+ return
+ fl[name]._set_default(value) # pylint: disable=protected-access
+ self._assert_validators(fl[name].validators)
+
+ def __contains__(self, name):
+ """Returns True if name is a value (flag) in the dict."""
+ return name in self._flags()
+
+ def __len__(self):
+ return len(self.__dict__['__flags'])
+
+ def __iter__(self):
+ return iter(self._flags())
+
+ def __call__(self, argv, known_only=False):
+ """Parses flags from argv; stores parsed flags into this FlagValues object.
+
+ All unparsed arguments are returned.
+
+ Args:
+ argv: a tuple/list of strings.
+ known_only: bool, if True, parse and remove known flags; return the rest
+ untouched. Unknown flags specified by --undefok are not returned.
+
+ Returns:
+ The list of arguments not parsed as options, including argv[0].
+
+ Raises:
+ Error: Raised on any parsing error.
+ TypeError: Raised on passing wrong type of arguments.
+ ValueError: Raised on flag value parsing error.
+ """
+ if _helpers.is_bytes_or_string(argv):
+ raise TypeError(
+ 'argv should be a tuple/list of strings, not bytes or string.')
+ if not argv:
+ raise ValueError(
+ 'argv cannot be an empty list, and must contain the program name as '
+ 'the first element.')
+
+ # This pre parses the argv list for --flagfile=<> options.
+ program_name = argv[0]
+ args = self.read_flags_from_files(argv[1:], force_gnu=False)
+
+ # Parse the arguments.
+ unknown_flags, unparsed_args = self._parse_args(args, known_only)
+
+ # Handle unknown flags by raising UnrecognizedFlagError.
+ # Note some users depend on us raising this particular error.
+ for name, value in unknown_flags:
+ suggestions = _helpers.get_flag_suggestions(name, list(self))
+ raise _exceptions.UnrecognizedFlagError(
+ name, value, suggestions=suggestions)
+
+ self.mark_as_parsed()
+ self.validate_all_flags()
+ return [program_name] + unparsed_args
+
+ def __getstate__(self):
+ raise TypeError("can't pickle FlagValues")
+
+ def __copy__(self):
+ raise TypeError('FlagValues does not support shallow copies. '
+ 'Use absl.testing.flagsaver or copy.deepcopy instead.')
+
+ def __deepcopy__(self, memo):
+ result = object.__new__(type(self))
+ result.__dict__.update(copy.deepcopy(self.__dict__, memo))
+ return result
+
+ def _set_is_retired_flag_func(self, is_retired_flag_func):
+ """Sets a function for checking retired flags.
+
+ Do not use it. This is a private absl API used to check retired flags
+ registered by the absl C++ flags library.
+
+ Args:
+ is_retired_flag_func: Callable(str) -> (bool, bool), a function takes flag
+ name as parameter, returns a tuple (is_retired, type_is_bool).
+ """
+ self.__dict__['__is_retired_flag_func'] = is_retired_flag_func
+
+ def _parse_args(self, args, known_only):
+ """Helper function to do the main argument parsing.
+
+ This function goes through args and does the bulk of the flag parsing.
+ It will find the corresponding flag in our flag dictionary, and call its
+ .parse() method on the flag value.
+
+ Args:
+ args: [str], a list of strings with the arguments to parse.
+ known_only: bool, if True, parse and remove known flags; return the rest
+ untouched. Unknown flags specified by --undefok are not returned.
+
+ Returns:
+ A tuple with the following:
+ unknown_flags: List of (flag name, arg) for flags we don't know about.
+ unparsed_args: List of arguments we did not parse.
+
+ Raises:
+ Error: Raised on any parsing error.
+ ValueError: Raised on flag value parsing error.
+ """
+ unparsed_names_and_args = [] # A list of (flag name or None, arg).
+ undefok = set()
+ retired_flag_func = self.__dict__['__is_retired_flag_func']
+
+ flag_dict = self._flags()
+ args = iter(args)
+ for arg in args:
+ value = None
+
+ def get_value():
+ # pylint: disable=cell-var-from-loop
+ try:
+ return next(args) if value is None else value
+ except StopIteration:
+ raise _exceptions.Error('Missing value for flag ' + arg) # pylint: disable=undefined-loop-variable
+
+ if not arg.startswith('-'):
+ # A non-argument: default is break, GNU is skip.
+ unparsed_names_and_args.append((None, arg))
+ if self.is_gnu_getopt():
+ continue
+ else:
+ break
+
+ if arg == '--':
+ if known_only:
+ unparsed_names_and_args.append((None, arg))
+ break
+
+ # At this point, arg must start with '-'.
+ if arg.startswith('--'):
+ arg_without_dashes = arg[2:]
+ else:
+ arg_without_dashes = arg[1:]
+
+ if '=' in arg_without_dashes:
+ name, value = arg_without_dashes.split('=', 1)
+ else:
+ name, value = arg_without_dashes, None
+
+ if not name:
+ # The argument is all dashes (including one dash).
+ unparsed_names_and_args.append((None, arg))
+ if self.is_gnu_getopt():
+ continue
+ else:
+ break
+
+ # --undefok is a special case.
+ if name == 'undefok':
+ value = get_value()
+ undefok.update(v.strip() for v in value.split(','))
+ undefok.update('no' + v.strip() for v in value.split(','))
+ continue
+
+ flag = flag_dict.get(name)
+ if flag is not None:
+ if flag.boolean and value is None:
+ value = 'true'
+ else:
+ value = get_value()
+ elif name.startswith('no') and len(name) > 2:
+ # Boolean flags can take the form of --noflag, with no value.
+ noflag = flag_dict.get(name[2:])
+ if noflag is not None and noflag.boolean:
+ if value is not None:
+ raise ValueError(arg + ' does not take an argument')
+ flag = noflag
+ value = 'false'
+
+ if retired_flag_func and flag is None:
+ is_retired, is_bool = retired_flag_func(name)
+
+ # If we didn't recognize that flag, but it starts with
+ # "no" then maybe it was a boolean flag specified in the
+ # --nofoo form.
+ if not is_retired and name.startswith('no'):
+ is_retired, is_bool = retired_flag_func(name[2:])
+ is_retired = is_retired and is_bool
+
+ if is_retired:
+ if not is_bool and value is None:
+ # This happens when a non-bool retired flag is specified
+ # in format of "--flag value".
+ get_value()
+ logging.error(
+ 'Flag "%s" is retired and should no longer '
+ 'be specified. See go/totw/90.', name)
+ continue
+
+ if flag is not None:
+ flag.parse(value)
+ flag.using_default_value = False
+ else:
+ unparsed_names_and_args.append((name, arg))
+
+ unknown_flags = []
+ unparsed_args = []
+ for name, arg in unparsed_names_and_args:
+ if name is None:
+ # Positional arguments.
+ unparsed_args.append(arg)
+ elif name in undefok:
+ # Remove undefok flags.
+ continue
+ else:
+ # This is an unknown flag.
+ if known_only:
+ unparsed_args.append(arg)
+ else:
+ unknown_flags.append((name, arg))
+
+ unparsed_args.extend(list(args))
+ return unknown_flags, unparsed_args
+
+ def is_parsed(self):
+ """Returns whether flags were parsed."""
+ return self.__dict__['__flags_parsed']
+
+ def mark_as_parsed(self):
+ """Explicitly marks flags as parsed.
+
+ Use this when the caller knows that this FlagValues has been parsed as if
+ a __call__() invocation has happened. This is only a public method for
+ use by things like appcommands which do additional command like parsing.
+ """
+ self.__dict__['__flags_parsed'] = True
+
+ def unparse_flags(self):
+ """Unparses all flags to the point before any FLAGS(argv) was called."""
+ for f in self._flags().values():
+ f.unparse()
+ # We log this message before marking flags as unparsed to avoid a
+ # problem when the logging library causes flags access.
+ logging.info('unparse_flags() called; flags access will now raise errors.')
+ self.__dict__['__flags_parsed'] = False
+ self.__dict__['__unparse_flags_called'] = True
+
+ def flag_values_dict(self):
+ """Returns a dictionary that maps flag names to flag values."""
+ return {name: flag.value for name, flag in self._flags().items()}
+
+ def __str__(self):
+ """Returns a help string for all known flags."""
+ return self.get_help()
+
+ def get_help(self, prefix='', include_special_flags=True):
+ """Returns a help string for all known flags.
+
+ Args:
+ prefix: str, per-line output prefix.
+ include_special_flags: bool, whether to include description of
+ SPECIAL_FLAGS, i.e. --flagfile and --undefok.
+
+ Returns:
+ str, formatted help message.
+ """
+ flags_by_module = self.flags_by_module_dict()
+ if flags_by_module:
+ modules = sorted(flags_by_module)
+ # Print the help for the main module first, if possible.
+ main_module = sys.argv[0]
+ if main_module in modules:
+ modules.remove(main_module)
+ modules = [main_module] + modules
+ return self._get_help_for_modules(modules, prefix, include_special_flags)
+ else:
+ output_lines = []
+ # Just print one long list of flags.
+ values = self._flags().values()
+ if include_special_flags:
+ values = itertools.chain(
+ values, _helpers.SPECIAL_FLAGS._flags().values()) # pylint: disable=protected-access
+ self._render_flag_list(values, output_lines, prefix)
+ return '\n'.join(output_lines)
+
+ def _get_help_for_modules(self, modules, prefix, include_special_flags):
+ """Returns the help string for a list of modules.
+
+ Private to absl.flags package.
+
+ Args:
+ modules: List[str], a list of modules to get the help string for.
+ prefix: str, a string that is prepended to each generated help line.
+ include_special_flags: bool, whether to include description of
+ SPECIAL_FLAGS, i.e. --flagfile and --undefok.
+ """
+ output_lines = []
+ for module in modules:
+ self._render_our_module_flags(module, output_lines, prefix)
+ if include_special_flags:
+ self._render_module_flags(
+ 'absl.flags',
+ _helpers.SPECIAL_FLAGS._flags().values(), # pylint: disable=protected-access
+ output_lines,
+ prefix)
+ return '\n'.join(output_lines)
+
+ def _render_module_flags(self, module, flags, output_lines, prefix=''):
+ """Returns a help string for a given module."""
+ if not isinstance(module, str):
+ module = module.__name__
+ output_lines.append('\n%s%s:' % (prefix, module))
+ self._render_flag_list(flags, output_lines, prefix + ' ')
+
+ def _render_our_module_flags(self, module, output_lines, prefix=''):
+ """Returns a help string for a given module."""
+ flags = self.get_flags_for_module(module)
+ if flags:
+ self._render_module_flags(module, flags, output_lines, prefix)
+
+ def _render_our_module_key_flags(self, module, output_lines, prefix=''):
+ """Returns a help string for the key flags of a given module.
+
+ Args:
+ module: module|str, the module to render key flags for.
+ output_lines: [str], a list of strings. The generated help message lines
+ will be appended to this list.
+ prefix: str, a string that is prepended to each generated help line.
+ """
+ key_flags = self.get_key_flags_for_module(module)
+ if key_flags:
+ self._render_module_flags(module, key_flags, output_lines, prefix)
+
+ def module_help(self, module):
+ """Describes the key flags of a module.
+
+ Args:
+ module: module|str, the module to describe the key flags for.
+
+ Returns:
+ str, describing the key flags of a module.
+ """
+ helplist = []
+ self._render_our_module_key_flags(module, helplist)
+ return '\n'.join(helplist)
+
+ def main_module_help(self):
+ """Describes the key flags of the main module.
+
+ Returns:
+ str, describing the key flags of the main module.
+ """
+ return self.module_help(sys.argv[0])
+
+ def _render_flag_list(self, flaglist, output_lines, prefix=' '):
+ fl = self._flags()
+ special_fl = _helpers.SPECIAL_FLAGS._flags() # pylint: disable=protected-access
+ flaglist = [(flag.name, flag) for flag in flaglist]
+ flaglist.sort()
+ flagset = {}
+ for (name, flag) in flaglist:
+ # It's possible this flag got deleted or overridden since being
+ # registered in the per-module flaglist. Check now against the
+ # canonical source of current flag information, the _flags.
+ if fl.get(name, None) != flag and special_fl.get(name, None) != flag:
+ # a different flag is using this name now
+ continue
+ # only print help once
+ if flag in flagset:
+ continue
+ flagset[flag] = 1
+ flaghelp = ''
+ if flag.short_name:
+ flaghelp += '-%s,' % flag.short_name
+ if flag.boolean:
+ flaghelp += '--[no]%s:' % flag.name
+ else:
+ flaghelp += '--%s:' % flag.name
+ flaghelp += ' '
+ if flag.help:
+ flaghelp += flag.help
+ flaghelp = _helpers.text_wrap(
+ flaghelp, indent=prefix + ' ', firstline_indent=prefix)
+ if flag.default_as_str:
+ flaghelp += '\n'
+ flaghelp += _helpers.text_wrap(
+ '(default: %s)' % flag.default_as_str, indent=prefix + ' ')
+ if flag.parser.syntactic_help:
+ flaghelp += '\n'
+ flaghelp += _helpers.text_wrap(
+ '(%s)' % flag.parser.syntactic_help, indent=prefix + ' ')
+ output_lines.append(flaghelp)
+
+ def get_flag_value(self, name, default): # pylint: disable=invalid-name
+ """Returns the value of a flag (if not None) or a default value.
+
+ Args:
+ name: str, the name of a flag.
+ default: Default value to use if the flag value is None.
+
+ Returns:
+ Requested flag value or default.
+ """
+
+ value = self.__getattr__(name)
+ if value is not None: # Can't do if not value, b/c value might be '0' or ""
+ return value
+ else:
+ return default
+
+ def _is_flag_file_directive(self, flag_string):
+ """Checks whether flag_string contain a --flagfile=<foo> directive."""
+ if isinstance(flag_string, type('')):
+ if flag_string.startswith('--flagfile='):
+ return 1
+ elif flag_string == '--flagfile':
+ return 1
+ elif flag_string.startswith('-flagfile='):
+ return 1
+ elif flag_string == '-flagfile':
+ return 1
+ else:
+ return 0
+ return 0
+
+ def _extract_filename(self, flagfile_str):
+ """Returns filename from a flagfile_str of form -[-]flagfile=filename.
+
+ The cases of --flagfile foo and -flagfile foo shouldn't be hitting
+ this function, as they are dealt with in the level above this
+ function.
+
+ Args:
+ flagfile_str: str, the flagfile string.
+
+ Returns:
+ str, the filename from a flagfile_str of form -[-]flagfile=filename.
+
+ Raises:
+ Error: Raised when illegal --flagfile is provided.
+ """
+ if flagfile_str.startswith('--flagfile='):
+ return os.path.expanduser((flagfile_str[(len('--flagfile=')):]).strip())
+ elif flagfile_str.startswith('-flagfile='):
+ return os.path.expanduser((flagfile_str[(len('-flagfile=')):]).strip())
+ else:
+ raise _exceptions.Error('Hit illegal --flagfile type: %s' % flagfile_str)
+
+ def _get_flag_file_lines(self, filename, parsed_file_stack=None):
+ """Returns the useful (!=comments, etc) lines from a file with flags.
+
+ Args:
+ filename: str, the name of the flag file.
+ parsed_file_stack: [str], a list of the names of the files that we have
+ recursively encountered at the current depth. MUTATED BY THIS FUNCTION
+ (but the original value is preserved upon successfully returning from
+ function call).
+
+ Returns:
+ List of strings. See the note below.
+
+ NOTE(springer): This function checks for a nested --flagfile=<foo>
+ tag and handles the lower file recursively. It returns a list of
+ all the lines that _could_ contain command flags. This is
+ EVERYTHING except whitespace lines and comments (lines starting
+ with '#' or '//').
+ """
+ # For consistency with the cpp version, ignore empty values.
+ if not filename:
+ return []
+ if parsed_file_stack is None:
+ parsed_file_stack = []
+ # We do a little safety check for reparsing a file we've already encountered
+ # at a previous depth.
+ if filename in parsed_file_stack:
+ sys.stderr.write('Warning: Hit circular flagfile dependency. Ignoring'
+ ' flagfile: %s\n' % (filename,))
+ return []
+ else:
+ parsed_file_stack.append(filename)
+
+ line_list = [] # All line from flagfile.
+ flag_line_list = [] # Subset of lines w/o comments, blanks, flagfile= tags.
+ try:
+ file_obj = open(filename, 'r')
+ except IOError as e_msg:
+ raise _exceptions.CantOpenFlagFileError(
+ 'ERROR:: Unable to open flagfile: %s' % e_msg)
+
+ with file_obj:
+ line_list = file_obj.readlines()
+
+ # This is where we check each line in the file we just read.
+ for line in line_list:
+ if line.isspace():
+ pass
+ # Checks for comment (a line that starts with '#').
+ elif line.startswith('#') or line.startswith('//'):
+ pass
+ # Checks for a nested "--flagfile=<bar>" flag in the current file.
+ # If we find one, recursively parse down into that file.
+ elif self._is_flag_file_directive(line):
+ sub_filename = self._extract_filename(line)
+ included_flags = self._get_flag_file_lines(
+ sub_filename, parsed_file_stack=parsed_file_stack)
+ flag_line_list.extend(included_flags)
+ else:
+ # Any line that's not a comment or a nested flagfile should get
+ # copied into 2nd position. This leaves earlier arguments
+ # further back in the list, thus giving them higher priority.
+ flag_line_list.append(line.strip())
+
+ parsed_file_stack.pop()
+ return flag_line_list
+
+ def read_flags_from_files(self, argv, force_gnu=True):
+ """Processes command line args, but also allow args to be read from file.
+
+ Args:
+ argv: [str], a list of strings, usually sys.argv[1:], which may contain
+ one or more flagfile directives of the form --flagfile="./filename".
+ Note that the name of the program (sys.argv[0]) should be omitted.
+ force_gnu: bool, if False, --flagfile parsing obeys the
+ FLAGS.is_gnu_getopt() value. If True, ignore the value and always follow
+ gnu_getopt semantics.
+
+ Returns:
+ A new list which has the original list combined with what we read
+ from any flagfile(s).
+
+ Raises:
+ IllegalFlagValueError: Raised when --flagfile is provided with no
+ argument.
+
+ This function is called by FLAGS(argv).
+ It scans the input list for a flag that looks like:
+ --flagfile=<somefile>. Then it opens <somefile>, reads all valid key
+ and value pairs and inserts them into the input list in exactly the
+ place where the --flagfile arg is found.
+
+ Note that your application's flags are still defined the usual way
+ using absl.flags DEFINE_flag() type functions.
+
+ Notes (assuming we're getting a commandline of some sort as our input):
+ --> For duplicate flags, the last one we hit should "win".
+ --> Since flags that appear later win, a flagfile's settings can be "weak"
+ if the --flagfile comes at the beginning of the argument sequence,
+ and it can be "strong" if the --flagfile comes at the end.
+ --> A further "--flagfile=<otherfile.cfg>" CAN be nested in a flagfile.
+ It will be expanded in exactly the spot where it is found.
+ --> In a flagfile, a line beginning with # or // is a comment.
+ --> Entirely blank lines _should_ be ignored.
+ """
+ rest_of_args = argv
+ new_argv = []
+ while rest_of_args:
+ current_arg = rest_of_args[0]
+ rest_of_args = rest_of_args[1:]
+ if self._is_flag_file_directive(current_arg):
+ # This handles the case of -(-)flagfile foo. In this case the
+ # next arg really is part of this one.
+ if current_arg == '--flagfile' or current_arg == '-flagfile':
+ if not rest_of_args:
+ raise _exceptions.IllegalFlagValueError(
+ '--flagfile with no argument')
+ flag_filename = os.path.expanduser(rest_of_args[0])
+ rest_of_args = rest_of_args[1:]
+ else:
+ # This handles the case of (-)-flagfile=foo.
+ flag_filename = self._extract_filename(current_arg)
+ new_argv.extend(self._get_flag_file_lines(flag_filename))
+ else:
+ new_argv.append(current_arg)
+ # Stop parsing after '--', like getopt and gnu_getopt.
+ if current_arg == '--':
+ break
+ # Stop parsing after a non-flag, like getopt.
+ if not current_arg.startswith('-'):
+ if not force_gnu and not self.__dict__['__use_gnu_getopt']:
+ break
+ else:
+ if ('=' not in current_arg and rest_of_args and
+ not rest_of_args[0].startswith('-')):
+ # If this is an occurrence of a legitimate --x y, skip the value
+ # so that it won't be mistaken for a standalone arg.
+ fl = self._flags()
+ name = current_arg.lstrip('-')
+ if name in fl and not fl[name].boolean:
+ current_arg = rest_of_args[0]
+ rest_of_args = rest_of_args[1:]
+ new_argv.append(current_arg)
+
+ if rest_of_args:
+ new_argv.extend(rest_of_args)
+
+ return new_argv
+
+ def flags_into_string(self):
+ """Returns a string with the flags assignments from this FlagValues object.
+
+ This function ignores flags whose value is None. Each flag
+ assignment is separated by a newline.
+
+ NOTE: MUST mirror the behavior of the C++ CommandlineFlagsIntoString
+ from https://github.com/gflags/gflags.
+
+ Returns:
+ str, the string with the flags assignments from this FlagValues object.
+ The flags are ordered by (module_name, flag_name).
+ """
+ module_flags = sorted(self.flags_by_module_dict().items())
+ s = ''
+ for unused_module_name, flags in module_flags:
+ flags = sorted(flags, key=lambda f: f.name)
+ for flag in flags:
+ if flag.value is not None:
+ s += flag.serialize() + '\n'
+ return s
+
+ def append_flags_into_file(self, filename):
+ """Appends all flags assignments from this FlagInfo object to a file.
+
+ Output will be in the format of a flagfile.
+
+ NOTE: MUST mirror the behavior of the C++ AppendFlagsIntoFile
+ from https://github.com/gflags/gflags.
+
+ Args:
+ filename: str, name of the file.
+ """
+ with open(filename, 'a') as out_file:
+ out_file.write(self.flags_into_string())
+
+ def write_help_in_xml_format(self, outfile=None):
+ """Outputs flag documentation in XML format.
+
+ NOTE: We use element names that are consistent with those used by
+ the C++ command-line flag library, from
+ https://github.com/gflags/gflags.
+ We also use a few new elements (e.g., <key>), but we do not
+ interfere / overlap with existing XML elements used by the C++
+ library. Please maintain this consistency.
+
+ Args:
+ outfile: File object we write to. Default None means sys.stdout.
+ """
+ doc = minidom.Document()
+ all_flag = doc.createElement('AllFlags')
+ doc.appendChild(all_flag)
+
+ all_flag.appendChild(
+ _helpers.create_xml_dom_element(doc, 'program',
+ os.path.basename(sys.argv[0])))
+
+ usage_doc = sys.modules['__main__'].__doc__
+ if not usage_doc:
+ usage_doc = '\nUSAGE: %s [flags]\n' % sys.argv[0]
+ else:
+ usage_doc = usage_doc.replace('%s', sys.argv[0])
+ all_flag.appendChild(
+ _helpers.create_xml_dom_element(doc, 'usage', usage_doc))
+
+ # Get list of key flags for the main module.
+ key_flags = self.get_key_flags_for_module(sys.argv[0])
+
+ # Sort flags by declaring module name and next by flag name.
+ flags_by_module = self.flags_by_module_dict()
+ all_module_names = list(flags_by_module.keys())
+ all_module_names.sort()
+ for module_name in all_module_names:
+ flag_list = [(f.name, f) for f in flags_by_module[module_name]]
+ flag_list.sort()
+ for unused_flag_name, flag in flag_list:
+ is_key = flag in key_flags
+ all_flag.appendChild(
+ flag._create_xml_dom_element( # pylint: disable=protected-access
+ doc,
+ module_name,
+ is_key=is_key))
+
+ outfile = outfile or sys.stdout
+ outfile.write(
+ doc.toprettyxml(indent=' ', encoding='utf-8').decode('utf-8'))
+ outfile.flush()
+
+ def _check_method_name_conflicts(self, name, flag):
+ if flag.allow_using_method_names:
+ return
+ short_name = flag.short_name
+ flag_names = {name} if short_name is None else {name, short_name}
+ for flag_name in flag_names:
+ if flag_name in self.__dict__['__banned_flag_names']:
+ raise _exceptions.FlagNameConflictsWithMethodError(
+ 'Cannot define a flag named "{name}". It conflicts with a method '
+ 'on class "{class_name}". To allow defining it, use '
+ 'allow_using_method_names and access the flag value with '
+ "FLAGS['{name}'].value. FLAGS.{name} returns the method, "
+ 'not the flag value.'.format(
+ name=flag_name, class_name=type(self).__name__))
+
+
+FLAGS = FlagValues()
+
+
+class FlagHolder(Generic[_T]):
+ """Holds a defined flag.
+
+ This facilitates a cleaner api around global state. Instead of
+
+ ```
+ flags.DEFINE_integer('foo', ...)
+ flags.DEFINE_integer('bar', ...)
+ ...
+ def method():
+ # prints parsed value of 'bar' flag
+ print(flags.FLAGS.foo)
+ # runtime error due to typo or possibly bad coding style.
+ print(flags.FLAGS.baz)
+ ```
+
+ it encourages code like
+
+ ```
+ FOO_FLAG = flags.DEFINE_integer('foo', ...)
+ BAR_FLAG = flags.DEFINE_integer('bar', ...)
+ ...
+ def method():
+ print(FOO_FLAG.value)
+ print(BAR_FLAG.value)
+ ```
+
+ since the name of the flag appears only once in the source code.
+ """
+
+ def __init__(self, flag_values, flag, ensure_non_none_value=False):
+ """Constructs a FlagHolder instance providing typesafe access to flag.
+
+ Args:
+ flag_values: The container the flag is registered to.
+ flag: The flag object for this flag.
+ ensure_non_none_value: Is the value of the flag allowed to be None.
+ """
+ self._flagvalues = flag_values
+ # We take the entire flag object, but only keep the name. Why?
+ # - We want FlagHolder[T] to be generic container
+ # - flag_values contains all flags, so has no reference to T.
+ # - typecheckers don't like to see a generic class where none of the ctor
+ # arguments refer to the generic type.
+ self._name = flag.name
+ # We intentionally do NOT check if the default value is None.
+ # This allows future use of this for "required flags with None default"
+ self._ensure_non_none_value = ensure_non_none_value
+
+ def __eq__(self, other):
+ raise TypeError(
+ "unsupported operand type(s) for ==: '{0}' and '{1}' "
+ "(did you mean to use '{0}.value' instead?)".format(
+ type(self).__name__, type(other).__name__))
+
+ def __bool__(self):
+ raise TypeError(
+ "bool() not supported for instances of type '{0}' "
+ "(did you mean to use '{0}.value' instead?)".format(
+ type(self).__name__))
+
+ __nonzero__ = __bool__
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def value(self):
+ """Returns the value of the flag.
+
+ If _ensure_non_none_value is True, then return value is not None.
+
+ Raises:
+ UnparsedFlagAccessError: if flag parsing has not finished.
+ IllegalFlagValueError: if value is None unexpectedly.
+ """
+ val = getattr(self._flagvalues, self._name)
+ if self._ensure_non_none_value and val is None:
+ raise _exceptions.IllegalFlagValueError(
+ 'Unexpected None value for flag %s' % self._name)
+ return val
+
+ @property
+ def default(self):
+ """Returns the default value of the flag."""
+ return self._flagvalues[self._name].default
+
+ @property
+ def present(self):
+ """Returns True if the flag was parsed from command-line flags."""
+ return bool(self._flagvalues[self._name].present)
diff --git a/absl/flags/_flagvalues.pyi b/absl/flags/_flagvalues.pyi
new file mode 100644
index 0000000..e25c6dd
--- /dev/null
+++ b/absl/flags/_flagvalues.pyi
@@ -0,0 +1,148 @@
+# Copyright 2020 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines type annotations for _flagvalues."""
+
+
+from absl.flags import _flag
+
+from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, Sequence, Text, Type, TypeVar
+
+
+class FlagValues:
+
+ def __getitem__(self, name: Text) -> _flag.Flag: ...
+
+ def __setitem__(self, name: Text, flag: _flag.Flag) -> None: ...
+
+ def __getattr__(self, name: Text) -> Any: ...
+
+ def __setattr__(self, name: Text, value: Any) -> Any: ...
+
+ def __call__(
+ self,
+ argv: Sequence[Text],
+ known_only: bool = ...,
+ ) -> List[Text]: ...
+
+ def __contains__(self, name: Text) -> bool: ...
+
+ def __copy__(self) -> Any: ...
+
+ def __deepcopy__(self, memo) -> Any: ...
+
+ def __delattr__(self, flag_name: Text) -> None: ...
+
+ def __dir__(self) -> List[Text]: ...
+
+ def __getstate__(self) -> Any: ...
+
+ def __iter__(self) -> Iterator[Text]: ...
+
+ def __len__(self) -> int: ...
+
+ def get_help(self,
+ prefix: Text = ...,
+ include_special_flags: bool = ...) -> Text:
+ ...
+
+
+ def set_gnu_getopt(self, gnu_getopt: bool = ...) -> None: ...
+
+ def is_gnu_getopt(self) -> bool: ...
+
+ def flags_by_module_dict(self) -> Dict[Text, List[_flag.Flag]]: ...
+
+ def flags_by_module_id_dict(self) -> Dict[Text, List[_flag.Flag]]: ...
+
+ def key_flags_by_module_dict(self) -> Dict[Text, List[_flag.Flag]]: ...
+
+ def register_flag_by_module(
+ self, module_name: Text, flag: _flag.Flag) -> None: ...
+
+ def register_flag_by_module_id(
+ self, module_id: Text, flag: _flag.Flag) -> None: ...
+
+ def register_key_flag_for_module(
+ self, module_name: Text, flag: _flag.Flag) -> None: ...
+
+ def get_key_flags_for_module(self, module: Any) -> List[_flag.Flag]: ...
+
+ def find_module_defining_flag(
+ self, flagname: Text, default: Any = ...) -> Any:
+ ...
+
+ def find_module_id_defining_flag(
+ self, flagname: Text, default: Any = ...) -> Any:
+ ...
+
+ def append_flag_values(self, flag_values: Any) -> None: ...
+
+ def remove_flag_values(self, flag_values: Any) -> None: ...
+
+ def validate_all_flags(self) -> None: ...
+
+ def set_default(self, name: Text, value: Any) -> None: ...
+
+ def is_parsed(self) -> bool: ...
+
+ def mark_as_parsed(self) -> None: ...
+
+ def unparse_flags(self) -> None: ...
+
+ def flag_values_dict(self) -> Dict[Text, Any]: ...
+
+ def module_help(self, module: Any) -> Text: ...
+
+ def main_module_help(self) -> Text: ...
+
+ def get_flag_value(self, name: Text, default: Any) -> Any: ...
+
+ def read_flags_from_files(
+ self, argv: List[Text], force_gnu: bool = ...) -> List[Text]: ...
+
+ def flags_into_string(self) -> Text: ...
+
+ def append_flags_into_file(self, filename: Text) -> None:...
+
+ # outfile is Optional[fileobject]
+ def write_help_in_xml_format(self, outfile: Any = ...) -> None: ...
+
+
+FLAGS = ... # type: FlagValues
+
+
+_T = TypeVar('_T') # The type of parsed default value of the flag.
+
+# We assume that default and value are guaranteed to have the same type.
+class FlagHolder(Generic[_T]):
+ def __init__(
+ self,
+ flag_values: FlagValues,
+ # NOTE: Use Flag instead of Flag[T] is used to work around some superficial
+ # differences between Flag and FlagHolder typing.
+ flag: _flag.Flag,
+ ensure_non_none_value: bool=False) -> None: ...
+
+ @property
+ def name(self) -> Text: ...
+
+ @property
+ def value(self) -> _T: ...
+
+ @property
+ def default(self) -> _T: ...
+
+ @property
+ def present(self) -> bool: ...
diff --git a/absl/flags/_helpers.py b/absl/flags/_helpers.py
new file mode 100644
index 0000000..37ae360
--- /dev/null
+++ b/absl/flags/_helpers.py
@@ -0,0 +1,433 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Internal helper functions for Abseil Python flags library."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+import re
+import struct
+import sys
+import textwrap
+try:
+ import fcntl
+except ImportError:
+ fcntl = None
+try:
+ # Importing termios will fail on non-unix platforms.
+ import termios
+except ImportError:
+ termios = None
+
+
+_DEFAULT_HELP_WIDTH = 80 # Default width of help output.
+_MIN_HELP_WIDTH = 40 # Minimal "sane" width of help output. We assume that any
+ # value below 40 is unreasonable.
+
+# Define the allowed error rate in an input string to get suggestions.
+#
+# We lean towards a high threshold because we tend to be matching a phrase,
+# and the simple algorithm used here is geared towards correcting word
+# spellings.
+#
+# For manual testing, consider "<command> --list" which produced a large number
+# of spurious suggestions when we used "least_errors > 0.5" instead of
+# "least_erros >= 0.5".
+_SUGGESTION_ERROR_RATE_THRESHOLD = 0.50
+
+# Characters that cannot appear or are highly discouraged in an XML 1.0
+# document. (See http://www.w3.org/TR/REC-xml/#charsets or
+# https://en.wikipedia.org/wiki/Valid_characters_in_XML#XML_1.0)
+_ILLEGAL_XML_CHARS_REGEX = re.compile(
+ u'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x84\x86-\x9f\ud800-\udfff\ufffe\uffff]')
+
+# This is a set of module ids for the modules that disclaim key flags.
+# This module is explicitly added to this set so that we never consider it to
+# define key flag.
+disclaim_module_ids = set([id(sys.modules[__name__])])
+
+
+# Define special flags here so that help may be generated for them.
+# NOTE: Please do NOT use SPECIAL_FLAGS from outside flags module.
+# Initialized inside flagvalues.py.
+SPECIAL_FLAGS = None
+
+
+# This points to the flags module, initialized in flags/__init__.py.
+# This should only be used in adopt_module_key_flags to take SPECIAL_FLAGS into
+# account.
+FLAGS_MODULE = None
+
+
+class _ModuleObjectAndName(
+ collections.namedtuple('_ModuleObjectAndName', 'module module_name')):
+ """Module object and name.
+
+ Fields:
+ - module: object, module object.
+ - module_name: str, module name.
+ """
+
+
+def get_module_object_and_name(globals_dict):
+ """Returns the module that defines a global environment, and its name.
+
+ Args:
+ globals_dict: A dictionary that should correspond to an environment
+ providing the values of the globals.
+
+ Returns:
+ _ModuleObjectAndName - pair of module object & module name.
+ Returns (None, None) if the module could not be identified.
+ """
+ name = globals_dict.get('__name__', None)
+ module = sys.modules.get(name, None)
+ # Pick a more informative name for the main module.
+ return _ModuleObjectAndName(module,
+ (sys.argv[0] if name == '__main__' else name))
+
+
+def get_calling_module_object_and_name():
+ """Returns the module that's calling into this module.
+
+ We generally use this function to get the name of the module calling a
+ DEFINE_foo... function.
+
+ Returns:
+ The module object that called into this one.
+
+ Raises:
+ AssertionError: Raised when no calling module could be identified.
+ """
+ for depth in range(1, sys.getrecursionlimit()):
+ # sys._getframe is the right thing to use here, as it's the best
+ # way to walk up the call stack.
+ globals_for_frame = sys._getframe(depth).f_globals # pylint: disable=protected-access
+ module, module_name = get_module_object_and_name(globals_for_frame)
+ if id(module) not in disclaim_module_ids and module_name is not None:
+ return _ModuleObjectAndName(module, module_name)
+ raise AssertionError('No module was found')
+
+
+def get_calling_module():
+ """Returns the name of the module that's calling into this module."""
+ return get_calling_module_object_and_name().module_name
+
+
+def str_or_unicode(value):
+ """Converts a value to a python string.
+
+ Behavior of this function is intentionally different in Python2/3.
+
+ In Python2, the given value is attempted to convert to a str (byte string).
+ If it contains non-ASCII characters, it is converted to a unicode instead.
+
+ In Python3, the given value is always converted to a str (unicode string).
+
+ This behavior reflects the (bad) practice in Python2 to try to represent
+ a string as str as long as it contains ASCII characters only.
+
+ Args:
+ value: An object to be converted to a string.
+
+ Returns:
+ A string representation of the given value. See the description above
+ for its type.
+ """
+ try:
+ return str(value)
+ except UnicodeEncodeError:
+ return unicode(value) # Python3 should never come here
+
+
+def create_xml_dom_element(doc, name, value):
+ """Returns an XML DOM element with name and text value.
+
+ Args:
+ doc: minidom.Document, the DOM document it should create nodes from.
+ name: str, the tag of XML element.
+ value: object, whose string representation will be used
+ as the value of the XML element. Illegal or highly discouraged xml 1.0
+ characters are stripped.
+
+ Returns:
+ An instance of minidom.Element.
+ """
+ s = str_or_unicode(value)
+ if isinstance(value, bool):
+ # Display boolean values as the C++ flag library does: no caps.
+ s = s.lower()
+ # Remove illegal xml characters.
+ s = _ILLEGAL_XML_CHARS_REGEX.sub(u'', s)
+
+ e = doc.createElement(name)
+ e.appendChild(doc.createTextNode(s))
+ return e
+
+
+def get_help_width():
+ """Returns the integer width of help lines that is used in TextWrap."""
+ if not sys.stdout.isatty() or termios is None or fcntl is None:
+ return _DEFAULT_HELP_WIDTH
+ try:
+ data = fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, '1234')
+ columns = struct.unpack('hh', data)[1]
+ # Emacs mode returns 0.
+ # Here we assume that any value below 40 is unreasonable.
+ if columns >= _MIN_HELP_WIDTH:
+ return columns
+ # Returning an int as default is fine, int(int) just return the int.
+ return int(os.getenv('COLUMNS', _DEFAULT_HELP_WIDTH))
+
+ except (TypeError, IOError, struct.error):
+ return _DEFAULT_HELP_WIDTH
+
+
+def get_flag_suggestions(attempt, longopt_list):
+ """Returns helpful similar matches for an invalid flag."""
+ # Don't suggest on very short strings, or if no longopts are specified.
+ if len(attempt) <= 2 or not longopt_list:
+ return []
+
+ option_names = [v.split('=')[0] for v in longopt_list]
+
+ # Find close approximations in flag prefixes.
+ # This also handles the case where the flag is spelled right but ambiguous.
+ distances = [(_damerau_levenshtein(attempt, option[0:len(attempt)]), option)
+ for option in option_names]
+ # t[0] is distance, and sorting by t[1] allows us to have stable output.
+ distances.sort()
+
+ least_errors, _ = distances[0]
+ # Don't suggest excessively bad matches.
+ if least_errors >= _SUGGESTION_ERROR_RATE_THRESHOLD * len(attempt):
+ return []
+
+ suggestions = []
+ for errors, name in distances:
+ if errors == least_errors:
+ suggestions.append(name)
+ else:
+ break
+ return suggestions
+
+
+def _damerau_levenshtein(a, b):
+ """Returns Damerau-Levenshtein edit distance from a to b."""
+ memo = {}
+
+ def distance(x, y):
+ """Recursively defined string distance with memoization."""
+ if (x, y) in memo:
+ return memo[x, y]
+ if not x:
+ d = len(y)
+ elif not y:
+ d = len(x)
+ else:
+ d = min(
+ distance(x[1:], y) + 1, # correct an insertion error
+ distance(x, y[1:]) + 1, # correct a deletion error
+ distance(x[1:], y[1:]) + (x[0] != y[0])) # correct a wrong character
+ if len(x) >= 2 and len(y) >= 2 and x[0] == y[1] and x[1] == y[0]:
+ # Correct a transposition.
+ t = distance(x[2:], y[2:]) + 1
+ if d > t:
+ d = t
+
+ memo[x, y] = d
+ return d
+ return distance(a, b)
+
+
+def text_wrap(text, length=None, indent='', firstline_indent=None):
+ """Wraps a given text to a maximum line length and returns it.
+
+ It turns lines that only contain whitespace into empty lines, keeps new lines,
+ and expands tabs using 4 spaces.
+
+ Args:
+ text: str, text to wrap.
+ length: int, maximum length of a line, includes indentation.
+ If this is None then use get_help_width()
+ indent: str, indent for all but first line.
+ firstline_indent: str, indent for first line; if None, fall back to indent.
+
+ Returns:
+ str, the wrapped text.
+
+ Raises:
+ ValueError: Raised if indent or firstline_indent not shorter than length.
+ """
+ # Get defaults where callee used None
+ if length is None:
+ length = get_help_width()
+ if indent is None:
+ indent = ''
+ if firstline_indent is None:
+ firstline_indent = indent
+
+ if len(indent) >= length:
+ raise ValueError('Length of indent exceeds length')
+ if len(firstline_indent) >= length:
+ raise ValueError('Length of first line indent exceeds length')
+
+ text = text.expandtabs(4)
+
+ result = []
+ # Create one wrapper for the first paragraph and one for subsequent
+ # paragraphs that does not have the initial wrapping.
+ wrapper = textwrap.TextWrapper(
+ width=length, initial_indent=firstline_indent, subsequent_indent=indent)
+ subsequent_wrapper = textwrap.TextWrapper(
+ width=length, initial_indent=indent, subsequent_indent=indent)
+
+ # textwrap does not have any special treatment for newlines. From the docs:
+ # "...newlines may appear in the middle of a line and cause strange output.
+ # For this reason, text should be split into paragraphs (using
+ # str.splitlines() or similar) which are wrapped separately."
+ for paragraph in (p.strip() for p in text.splitlines()):
+ if paragraph:
+ result.extend(wrapper.wrap(paragraph))
+ else:
+ result.append('') # Keep empty lines.
+ # Replace initial wrapper with wrapper for subsequent paragraphs.
+ wrapper = subsequent_wrapper
+
+ return '\n'.join(result)
+
+
+def flag_dict_to_args(flag_map, multi_flags=None):
+ """Convert a dict of values into process call parameters.
+
+ This method is used to convert a dictionary into a sequence of parameters
+ for a binary that parses arguments using this module.
+
+ Args:
+ flag_map: dict, a mapping where the keys are flag names (strings).
+ values are treated according to their type:
+ * If value is None, then only the name is emitted.
+ * If value is True, then only the name is emitted.
+ * If value is False, then only the name prepended with 'no' is emitted.
+ * If value is a string then --name=value is emitted.
+ * If value is a collection, this will emit --name=value1,value2,value3,
+ unless the flag name is in multi_flags, in which case this will emit
+ --name=value1 --name=value2 --name=value3.
+ * Everything else is converted to string an passed as such.
+ multi_flags: set, names (strings) of flags that should be treated as
+ multi-flags.
+ Yields:
+ sequence of string suitable for a subprocess execution.
+ """
+ for key, value in flag_map.items():
+ if value is None:
+ yield '--%s' % key
+ elif isinstance(value, bool):
+ if value:
+ yield '--%s' % key
+ else:
+ yield '--no%s' % key
+ elif isinstance(value, (bytes, type(u''))):
+ # We don't want strings to be handled like python collections.
+ yield '--%s=%s' % (key, value)
+ else:
+ # Now we attempt to deal with collections.
+ try:
+ if multi_flags and key in multi_flags:
+ for item in value:
+ yield '--%s=%s' % (key, str(item))
+ else:
+ yield '--%s=%s' % (key, ','.join(str(item) for item in value))
+ except TypeError:
+ # Default case.
+ yield '--%s=%s' % (key, value)
+
+
+def trim_docstring(docstring):
+ """Removes indentation from triple-quoted strings.
+
+ This is the function specified in PEP 257 to handle docstrings:
+ https://www.python.org/dev/peps/pep-0257/.
+
+ Args:
+ docstring: str, a python docstring.
+
+ Returns:
+ str, docstring with indentation removed.
+ """
+ if not docstring:
+ return ''
+
+ # If you've got a line longer than this you have other problems...
+ max_indent = 1 << 29
+
+ # Convert tabs to spaces (following the normal Python rules)
+ # and split into a list of lines:
+ lines = docstring.expandtabs().splitlines()
+
+ # Determine minimum indentation (first line doesn't count):
+ indent = max_indent
+ for line in lines[1:]:
+ stripped = line.lstrip()
+ if stripped:
+ indent = min(indent, len(line) - len(stripped))
+ # Remove indentation (first line is special):
+ trimmed = [lines[0].strip()]
+ if indent < max_indent:
+ for line in lines[1:]:
+ trimmed.append(line[indent:].rstrip())
+ # Strip off trailing and leading blank lines:
+ while trimmed and not trimmed[-1]:
+ trimmed.pop()
+ while trimmed and not trimmed[0]:
+ trimmed.pop(0)
+ # Return a single string:
+ return '\n'.join(trimmed)
+
+
+def doc_to_help(doc):
+ """Takes a __doc__ string and reformats it as help."""
+
+ # Get rid of starting and ending white space. Using lstrip() or even
+ # strip() could drop more than maximum of first line and right space
+ # of last line.
+ doc = doc.strip()
+
+ # Get rid of all empty lines.
+ whitespace_only_line = re.compile('^[ \t]+$', re.M)
+ doc = whitespace_only_line.sub('', doc)
+
+ # Cut out common space at line beginnings.
+ doc = trim_docstring(doc)
+
+ # Just like this module's comment, comments tend to be aligned somehow.
+ # In other words they all start with the same amount of white space.
+ # 1) keep double new lines;
+ # 2) keep ws after new lines if not empty line;
+ # 3) all other new lines shall be changed to a space;
+ # Solution: Match new lines between non white space and replace with space.
+ doc = re.sub(r'(?<=\S)\n(?=\S)', ' ', doc, flags=re.M)
+
+ return doc
+
+
+def is_bytes_or_string(maybe_string):
+ if str is bytes:
+ return isinstance(maybe_string, basestring)
+ else:
+ return isinstance(maybe_string, (str, bytes))
diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py
new file mode 100644
index 0000000..af66050
--- /dev/null
+++ b/absl/flags/_validators.py
@@ -0,0 +1,313 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module to enforce different constraints on flags.
+
+Flags validators can be registered using following functions / decorators:
+ flags.register_validator
+ @flags.validator
+ flags.register_multi_flags_validator
+ @flags.multi_flags_validator
+
+Three convenience functions are also provided for common flag constraints:
+ flags.mark_flag_as_required
+ flags.mark_flags_as_required
+ flags.mark_flags_as_mutual_exclusive
+ flags.mark_bool_flags_as_mutual_exclusive
+
+See their docstring in this module for a usage manual.
+
+Do NOT import this module directly. Import the flags package and use the
+aliases defined at the package level instead.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from absl.flags import _exceptions
+from absl.flags import _flagvalues
+from absl.flags import _validators_classes
+
+
+def register_validator(flag_name,
+ checker,
+ message='Flag validation failed',
+ flag_values=_flagvalues.FLAGS):
+ """Adds a constraint, which will be enforced during program execution.
+
+ The constraint is validated when flags are initially parsed, and after each
+ change of the corresponding flag's value.
+ Args:
+ flag_name: str, name of the flag to be checked.
+ checker: callable, a function to validate the flag.
+ input - A single positional argument: The value of the corresponding
+ flag (string, boolean, etc. This value will be passed to checker
+ by the library).
+ output - bool, True if validator constraint is satisfied.
+ If constraint is not satisfied, it should either return False or
+ raise flags.ValidationError(desired_error_message).
+ message: str, error text to be shown to the user if checker returns False.
+ If checker raises flags.ValidationError, message from the raised
+ error will be shown.
+ flag_values: flags.FlagValues, optional FlagValues instance to validate
+ against.
+ Raises:
+ AttributeError: Raised when flag_name is not registered as a valid flag
+ name.
+ """
+ v = _validators_classes.SingleFlagValidator(flag_name, checker, message)
+ _add_validator(flag_values, v)
+
+
+def validator(flag_name, message='Flag validation failed',
+ flag_values=_flagvalues.FLAGS):
+ """A function decorator for defining a flag validator.
+
+ Registers the decorated function as a validator for flag_name, e.g.
+
+ @flags.validator('foo')
+ def _CheckFoo(foo):
+ ...
+
+ See register_validator() for the specification of checker function.
+
+ Args:
+ flag_name: str, name of the flag to be checked.
+ message: str, error text to be shown to the user if checker returns False.
+ If checker raises flags.ValidationError, message from the raised
+ error will be shown.
+ flag_values: flags.FlagValues, optional FlagValues instance to validate
+ against.
+ Returns:
+ A function decorator that registers its function argument as a validator.
+ Raises:
+ AttributeError: Raised when flag_name is not registered as a valid flag
+ name.
+ """
+
+ def decorate(function):
+ register_validator(flag_name, function,
+ message=message,
+ flag_values=flag_values)
+ return function
+ return decorate
+
+
+def register_multi_flags_validator(flag_names,
+ multi_flags_checker,
+ message='Flags validation failed',
+ flag_values=_flagvalues.FLAGS):
+ """Adds a constraint to multiple flags.
+
+ The constraint is validated when flags are initially parsed, and after each
+ change of the corresponding flag's value.
+
+ Args:
+ flag_names: [str], a list of the flag names to be checked.
+ multi_flags_checker: callable, a function to validate the flag.
+ input - dict, with keys() being flag_names, and value for each key
+ being the value of the corresponding flag (string, boolean, etc).
+ output - bool, True if validator constraint is satisfied.
+ If constraint is not satisfied, it should either return False or
+ raise flags.ValidationError.
+ message: str, error text to be shown to the user if checker returns False.
+ If checker raises flags.ValidationError, message from the raised
+ error will be shown.
+ flag_values: flags.FlagValues, optional FlagValues instance to validate
+ against.
+
+ Raises:
+ AttributeError: Raised when a flag is not registered as a valid flag name.
+ """
+ v = _validators_classes.MultiFlagsValidator(
+ flag_names, multi_flags_checker, message)
+ _add_validator(flag_values, v)
+
+
+def multi_flags_validator(flag_names,
+ message='Flag validation failed',
+ flag_values=_flagvalues.FLAGS):
+ """A function decorator for defining a multi-flag validator.
+
+ Registers the decorated function as a validator for flag_names, e.g.
+
+ @flags.multi_flags_validator(['foo', 'bar'])
+ def _CheckFooBar(flags_dict):
+ ...
+
+ See register_multi_flags_validator() for the specification of checker
+ function.
+
+ Args:
+ flag_names: [str], a list of the flag names to be checked.
+ message: str, error text to be shown to the user if checker returns False.
+ If checker raises flags.ValidationError, message from the raised
+ error will be shown.
+ flag_values: flags.FlagValues, optional FlagValues instance to validate
+ against.
+
+ Returns:
+ A function decorator that registers its function argument as a validator.
+
+ Raises:
+ AttributeError: Raised when a flag is not registered as a valid flag name.
+ """
+
+ def decorate(function):
+ register_multi_flags_validator(flag_names,
+ function,
+ message=message,
+ flag_values=flag_values)
+ return function
+
+ return decorate
+
+
+def mark_flag_as_required(flag_name, flag_values=_flagvalues.FLAGS):
+ """Ensures that flag is not None during program execution.
+
+ Registers a flag validator, which will follow usual validator rules.
+ Important note: validator will pass for any non-None value, such as False,
+ 0 (zero), '' (empty string) and so on.
+
+ If your module might be imported by others, and you only wish to make the flag
+ required when the module is directly executed, call this method like this:
+
+ if __name__ == '__main__':
+ flags.mark_flag_as_required('your_flag_name')
+ app.run()
+
+ Args:
+ flag_name: str, name of the flag
+ flag_values: flags.FlagValues, optional FlagValues instance where the flag
+ is defined.
+ Raises:
+ AttributeError: Raised when flag_name is not registered as a valid flag
+ name.
+ """
+ if flag_values[flag_name].default is not None:
+ warnings.warn(
+ 'Flag --%s has a non-None default value; therefore, '
+ 'mark_flag_as_required will pass even if flag is not specified in the '
+ 'command line!' % flag_name,
+ stacklevel=2)
+ register_validator(
+ flag_name,
+ lambda value: value is not None,
+ message='Flag --{} must have a value other than None.'.format(flag_name),
+ flag_values=flag_values)
+
+
+def mark_flags_as_required(flag_names, flag_values=_flagvalues.FLAGS):
+ """Ensures that flags are not None during program execution.
+
+ If your module might be imported by others, and you only wish to make the flag
+ required when the module is directly executed, call this method like this:
+
+ if __name__ == '__main__':
+ flags.mark_flags_as_required(['flag1', 'flag2', 'flag3'])
+ app.run()
+
+ Args:
+ flag_names: Sequence[str], names of the flags.
+ flag_values: flags.FlagValues, optional FlagValues instance where the flags
+ are defined.
+ Raises:
+ AttributeError: If any of flag name has not already been defined as a flag.
+ """
+ for flag_name in flag_names:
+ mark_flag_as_required(flag_name, flag_values)
+
+
+def mark_flags_as_mutual_exclusive(flag_names, required=False,
+ flag_values=_flagvalues.FLAGS):
+ """Ensures that only one flag among flag_names is not None.
+
+ Important note: This validator checks if flag values are None, and it does not
+ distinguish between default and explicit values. Therefore, this validator
+ does not make sense when applied to flags with default values other than None,
+ including other false values (e.g. False, 0, '', []). That includes multi
+ flags with a default value of [] instead of None.
+
+ Args:
+ flag_names: [str], names of the flags.
+ required: bool. If true, exactly one of the flags must have a value other
+ than None. Otherwise, at most one of the flags can have a value other
+ than None, and it is valid for all of the flags to be None.
+ flag_values: flags.FlagValues, optional FlagValues instance where the flags
+ are defined.
+ """
+ for flag_name in flag_names:
+ if flag_values[flag_name].default is not None:
+ warnings.warn(
+ 'Flag --{} has a non-None default value. That does not make sense '
+ 'with mark_flags_as_mutual_exclusive, which checks whether the '
+ 'listed flags have a value other than None.'.format(flag_name),
+ stacklevel=2)
+
+ def validate_mutual_exclusion(flags_dict):
+ flag_count = sum(1 for val in flags_dict.values() if val is not None)
+ if flag_count == 1 or (not required and flag_count == 0):
+ return True
+ raise _exceptions.ValidationError(
+ '{} one of ({}) must have a value other than None.'.format(
+ 'Exactly' if required else 'At most', ', '.join(flag_names)))
+
+ register_multi_flags_validator(
+ flag_names, validate_mutual_exclusion, flag_values=flag_values)
+
+
+def mark_bool_flags_as_mutual_exclusive(flag_names, required=False,
+ flag_values=_flagvalues.FLAGS):
+ """Ensures that only one flag among flag_names is True.
+
+ Args:
+ flag_names: [str], names of the flags.
+ required: bool. If true, exactly one flag must be True. Otherwise, at most
+ one flag can be True, and it is valid for all flags to be False.
+ flag_values: flags.FlagValues, optional FlagValues instance where the flags
+ are defined.
+ """
+ for flag_name in flag_names:
+ if not flag_values[flag_name].boolean:
+ raise _exceptions.ValidationError(
+ 'Flag --{} is not Boolean, which is required for flags used in '
+ 'mark_bool_flags_as_mutual_exclusive.'.format(flag_name))
+
+ def validate_boolean_mutual_exclusion(flags_dict):
+ flag_count = sum(bool(val) for val in flags_dict.values())
+ if flag_count == 1 or (not required and flag_count == 0):
+ return True
+ raise _exceptions.ValidationError(
+ '{} one of ({}) must be True.'.format(
+ 'Exactly' if required else 'At most', ', '.join(flag_names)))
+
+ register_multi_flags_validator(
+ flag_names, validate_boolean_mutual_exclusion, flag_values=flag_values)
+
+
+def _add_validator(fv, validator_instance):
+ """Register new flags validator to be checked.
+
+ Args:
+ fv: flags.FlagValues, the FlagValues instance to add the validator.
+ validator_instance: validators.Validator, the validator to add.
+ Raises:
+ KeyError: Raised when validators work with a non-existing flag.
+ """
+ for flag_name in validator_instance.get_flags_names():
+ fv[flag_name].validators.append(validator_instance)
diff --git a/absl/flags/_validators_classes.py b/absl/flags/_validators_classes.py
new file mode 100644
index 0000000..d8996e0
--- /dev/null
+++ b/absl/flags/_validators_classes.py
@@ -0,0 +1,176 @@
+# Copyright 2021 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines *private* classes used for flag validators.
+
+Do NOT import this module. DO NOT use anything from this module. They are
+private APIs.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.flags import _exceptions
+
+
+class Validator(object):
+ """Base class for flags validators.
+
+ Users should NOT overload these classes, and use flags.Register...
+ methods instead.
+ """
+
+ # Used to assign each validator an unique insertion_index
+ validators_count = 0
+
+ def __init__(self, checker, message):
+ """Constructor to create all validators.
+
+ Args:
+ checker: function to verify the constraint.
+ Input of this method varies, see SingleFlagValidator and
+ multi_flags_validator for a detailed description.
+ message: str, error message to be shown to the user.
+ """
+ self.checker = checker
+ self.message = message
+ Validator.validators_count += 1
+ # Used to assert validators in the order they were registered.
+ self.insertion_index = Validator.validators_count
+
+ def verify(self, flag_values):
+ """Verifies that constraint is satisfied.
+
+ flags library calls this method to verify Validator's constraint.
+
+ Args:
+ flag_values: flags.FlagValues, the FlagValues instance to get flags from.
+ Raises:
+ Error: Raised if constraint is not satisfied.
+ """
+ param = self._get_input_to_checker_function(flag_values)
+ if not self.checker(param):
+ raise _exceptions.ValidationError(self.message)
+
+ def get_flags_names(self):
+ """Returns the names of the flags checked by this validator.
+
+ Returns:
+ [string], names of the flags.
+ """
+ raise NotImplementedError('This method should be overloaded')
+
+ def print_flags_with_values(self, flag_values):
+ raise NotImplementedError('This method should be overloaded')
+
+ def _get_input_to_checker_function(self, flag_values):
+ """Given flag values, returns the input to be given to checker.
+
+ Args:
+ flag_values: flags.FlagValues, containing all flags.
+ Returns:
+ The input to be given to checker. The return type depends on the specific
+ validator.
+ """
+ raise NotImplementedError('This method should be overloaded')
+
+
+class SingleFlagValidator(Validator):
+ """Validator behind register_validator() method.
+
+ Validates that a single flag passes its checker function. The checker function
+ takes the flag value and returns True (if value looks fine) or, if flag value
+ is not valid, either returns False or raises an Exception.
+ """
+
+ def __init__(self, flag_name, checker, message):
+ """Constructor.
+
+ Args:
+ flag_name: string, name of the flag.
+ checker: function to verify the validator.
+ input - value of the corresponding flag (string, boolean, etc).
+ output - bool, True if validator constraint is satisfied.
+ If constraint is not satisfied, it should either return False or
+ raise flags.ValidationError(desired_error_message).
+ message: str, error message to be shown to the user if validator's
+ condition is not satisfied.
+ """
+ super(SingleFlagValidator, self).__init__(checker, message)
+ self.flag_name = flag_name
+
+ def get_flags_names(self):
+ return [self.flag_name]
+
+ def print_flags_with_values(self, flag_values):
+ return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value)
+
+ def _get_input_to_checker_function(self, flag_values):
+ """Given flag values, returns the input to be given to checker.
+
+ Args:
+ flag_values: flags.FlagValues, the FlagValues instance to get flags from.
+ Returns:
+ object, the input to be given to checker.
+ """
+ return flag_values[self.flag_name].value
+
+
+class MultiFlagsValidator(Validator):
+ """Validator behind register_multi_flags_validator method.
+
+ Validates that flag values pass their common checker function. The checker
+ function takes flag values and returns True (if values look fine) or,
+ if values are not valid, either returns False or raises an Exception.
+ """
+
+ def __init__(self, flag_names, checker, message):
+ """Constructor.
+
+ Args:
+ flag_names: [str], containing names of the flags used by checker.
+ checker: function to verify the validator.
+ input - dict, with keys() being flag_names, and value for each
+ key being the value of the corresponding flag (string, boolean,
+ etc).
+ output - bool, True if validator constraint is satisfied.
+ If constraint is not satisfied, it should either return False or
+ raise flags.ValidationError(desired_error_message).
+ message: str, error message to be shown to the user if validator's
+ condition is not satisfied
+ """
+ super(MultiFlagsValidator, self).__init__(checker, message)
+ self.flag_names = flag_names
+
+ def _get_input_to_checker_function(self, flag_values):
+ """Given flag values, returns the input to be given to checker.
+
+ Args:
+ flag_values: flags.FlagValues, the FlagValues instance to get flags from.
+ Returns:
+ dict, with keys() being self.lag_names, and value for each key
+ being the value of the corresponding flag (string, boolean, etc).
+ """
+ return dict([key, flag_values[key].value] for key in self.flag_names)
+
+ def print_flags_with_values(self, flag_values):
+ prefix = 'flags '
+ flags_with_values = []
+ for key in self.flag_names:
+ flags_with_values.append('%s=%s' % (key, flag_values[key].value))
+ return prefix + ', '.join(flags_with_values)
+
+ def get_flags_names(self):
+ return self.flag_names
diff --git a/absl/flags/argparse_flags.py b/absl/flags/argparse_flags.py
new file mode 100644
index 0000000..4f78f50
--- /dev/null
+++ b/absl/flags/argparse_flags.py
@@ -0,0 +1,390 @@
+# Copyright 2018 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module provides argparse integration with absl.flags.
+
+argparse_flags.ArgumentParser is a drop-in replacement for
+argparse.ArgumentParser. It takes care of collecting and defining absl flags
+in argparse.
+
+Here is a simple example:
+
+ # Assume the following absl.flags is defined in another module:
+ #
+ # from absl import flags
+ # flags.DEFINE_string('echo', None, 'The echo message.')
+ #
+ parser = argparse_flags.ArgumentParser(
+ description='A demo of absl.flags and argparse integration.')
+ parser.add_argument('--header', help='Header message to print.')
+
+ # The parser will also accept the absl flag `--echo`.
+ # The `header` value is available as `args.header` just like a regular
+ # argparse flag. The absl flag `--echo` continues to be available via
+ # `absl.flags.FLAGS` if you want to access it.
+ args = parser.parse_args()
+
+ # Example usages:
+ # ./program --echo='A message.' --header='A header'
+ # ./program --header 'A header' --echo 'A message.'
+
+
+Here is another example demonstrates subparsers:
+
+ parser = argparse_flags.ArgumentParser(description='A subcommands demo.')
+ parser.add_argument('--header', help='The header message to print.')
+
+ subparsers = parser.add_subparsers(help='The command to execute.')
+
+ roll_dice_parser = subparsers.add_parser(
+ 'roll_dice', help='Roll a dice.',
+ # By default, absl flags can also be specified after the sub-command.
+ # To only allow them before sub-command, pass
+ # `inherited_absl_flags=None`.
+ inherited_absl_flags=None)
+ roll_dice_parser.add_argument('--num_faces', type=int, default=6)
+ roll_dice_parser.set_defaults(command=roll_dice)
+
+ shuffle_parser = subparsers.add_parser('shuffle', help='Shuffle inputs.')
+ shuffle_parser.add_argument(
+ 'inputs', metavar='I', nargs='+', help='Inputs to shuffle.')
+ shuffle_parser.set_defaults(command=shuffle)
+
+ args = parser.parse_args(argv[1:])
+ args.command(args)
+
+ # Example usages:
+ # ./program --echo='A message.' roll_dice --num_faces=6
+ # ./program shuffle --echo='A message.' 1 2 3 4
+
+
+There are several differences between absl.flags and argparse_flags:
+
+1. Flags defined with absl.flags are parsed differently when using the
+ argparse parser. Notably:
+
+ 1) absl.flags allows both single-dash and double-dash for any flag, and
+ doesn't distinguish them; argparse_flags only allows double-dash for
+ flag's regular name, and single-dash for flag's `short_name`.
+ 2) Boolean flags in absl.flags can be specified with `--bool`, `--nobool`,
+ as well as `--bool=true/false` (though not recommended);
+ in argparse_flags, it only allows `--bool`, `--nobool`.
+
+2. Help related flag differences:
+ 1) absl.flags does not define help flags, absl.app does that; argparse_flags
+ defines help flags unless passed with `add_help=False`.
+ 2) absl.app supports `--helpxml`; argparse_flags does not.
+ 3) argparse_flags supports `-h`; absl.app does not.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from absl import flags
+
+
+_BUILT_IN_FLAGS = frozenset({
+ 'help',
+ 'helpshort',
+ 'helpfull',
+ 'helpxml',
+ 'flagfile',
+ 'undefok',
+})
+
+
+class ArgumentParser(argparse.ArgumentParser):
+ """Custom ArgumentParser class to support special absl flags."""
+
+ def __init__(self, **kwargs):
+ """Initializes ArgumentParser.
+
+ Args:
+ **kwargs: same as argparse.ArgumentParser, except:
+ 1. It also accepts `inherited_absl_flags`: the absl flags to inherit.
+ The default is the global absl.flags.FLAGS instance. Pass None to
+ ignore absl flags.
+ 2. The `prefix_chars` argument must be the default value '-'.
+
+ Raises:
+ ValueError: Raised when prefix_chars is not '-'.
+ """
+ prefix_chars = kwargs.get('prefix_chars', '-')
+ if prefix_chars != '-':
+ raise ValueError(
+ 'argparse_flags.ArgumentParser only supports "-" as the prefix '
+ 'character, found "{}".'.format(prefix_chars))
+
+ # Remove inherited_absl_flags before calling super.
+ self._inherited_absl_flags = kwargs.pop('inherited_absl_flags', flags.FLAGS)
+ # Now call super to initialize argparse.ArgumentParser before calling
+ # add_argument in _define_absl_flags.
+ super(ArgumentParser, self).__init__(**kwargs)
+
+ if self.add_help:
+ # -h and --help are defined in super.
+ # Also add the --helpshort and --helpfull flags.
+ self.add_argument(
+ # Action 'help' defines a similar flag to -h/--help.
+ '--helpshort', action='help',
+ default=argparse.SUPPRESS, help=argparse.SUPPRESS)
+ self.add_argument(
+ '--helpfull', action=_HelpFullAction,
+ default=argparse.SUPPRESS, help='show full help message and exit')
+
+ if self._inherited_absl_flags:
+ self.add_argument(
+ '--undefok', default=argparse.SUPPRESS, help=argparse.SUPPRESS)
+ self._define_absl_flags(self._inherited_absl_flags)
+
+ def parse_known_args(self, args=None, namespace=None):
+ if args is None:
+ args = sys.argv[1:]
+ if self._inherited_absl_flags:
+ # Handle --flagfile.
+ # Explicitly specify force_gnu=True, since argparse behaves like
+ # gnu_getopt: flags can be specified after positional arguments.
+ args = self._inherited_absl_flags.read_flags_from_files(
+ args, force_gnu=True)
+
+ undefok_missing = object()
+ undefok = getattr(namespace, 'undefok', undefok_missing)
+
+ namespace, args = super(ArgumentParser, self).parse_known_args(
+ args, namespace)
+
+ # For Python <= 2.7.8: https://bugs.python.org/issue9351, a bug where
+ # sub-parsers don't preserve existing namespace attributes.
+ # Restore the undefok attribute if a sub-parser dropped it.
+ if undefok is not undefok_missing:
+ namespace.undefok = undefok
+
+ if self._inherited_absl_flags:
+ # Handle --undefok. At this point, `args` only contains unknown flags,
+ # so it won't strip defined flags that are also specified with --undefok.
+ # For Python <= 2.7.8: https://bugs.python.org/issue9351, a bug where
+ # sub-parsers don't preserve existing namespace attributes. The undefok
+ # attribute might not exist because a subparser dropped it.
+ if hasattr(namespace, 'undefok'):
+ args = _strip_undefok_args(namespace.undefok, args)
+ # absl flags are not exposed in the Namespace object. See Namespace:
+ # https://docs.python.org/3/library/argparse.html#argparse.Namespace.
+ del namespace.undefok
+ self._inherited_absl_flags.mark_as_parsed()
+ try:
+ self._inherited_absl_flags.validate_all_flags()
+ except flags.IllegalFlagValueError as e:
+ self.error(str(e))
+
+ return namespace, args
+
+ def _define_absl_flags(self, absl_flags):
+ """Defines flags from absl_flags."""
+ key_flags = set(absl_flags.get_key_flags_for_module(sys.argv[0]))
+ for name in absl_flags:
+ if name in _BUILT_IN_FLAGS:
+ # Do not inherit built-in flags.
+ continue
+ flag_instance = absl_flags[name]
+ # Each flags with short_name appears in FLAGS twice, so only define
+ # when the dictionary key is equal to the regular name.
+ if name == flag_instance.name:
+ # Suppress the flag in the help short message if it's not a main
+ # module's key flag.
+ suppress = flag_instance not in key_flags
+ self._define_absl_flag(flag_instance, suppress)
+
+ def _define_absl_flag(self, flag_instance, suppress):
+ """Defines a flag from the flag_instance."""
+ flag_name = flag_instance.name
+ short_name = flag_instance.short_name
+ argument_names = ['--' + flag_name]
+ if short_name:
+ argument_names.insert(0, '-' + short_name)
+ if suppress:
+ helptext = argparse.SUPPRESS
+ else:
+ # argparse help string uses %-formatting. Escape the literal %'s.
+ helptext = flag_instance.help.replace('%', '%%')
+ if flag_instance.boolean:
+ # Only add the `no` form to the long name.
+ argument_names.append('--no' + flag_name)
+ self.add_argument(
+ *argument_names, action=_BooleanFlagAction, help=helptext,
+ metavar=flag_instance.name.upper(),
+ flag_instance=flag_instance)
+ else:
+ self.add_argument(
+ *argument_names, action=_FlagAction, help=helptext,
+ metavar=flag_instance.name.upper(),
+ flag_instance=flag_instance)
+
+
+class _FlagAction(argparse.Action):
+ """Action class for Abseil non-boolean flags."""
+
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ help, # pylint: disable=redefined-builtin
+ metavar,
+ flag_instance,
+ default=argparse.SUPPRESS):
+ """Initializes _FlagAction.
+
+ Args:
+ option_strings: See argparse.Action.
+ dest: Ignored. The flag is always defined with dest=argparse.SUPPRESS.
+ help: See argparse.Action.
+ metavar: See argparse.Action.
+ flag_instance: absl.flags.Flag, the absl flag instance.
+ default: Ignored. The flag always uses dest=argparse.SUPPRESS so it
+ doesn't affect the parsing result.
+ """
+ del dest
+ self._flag_instance = flag_instance
+ super(_FlagAction, self).__init__(
+ option_strings=option_strings,
+ dest=argparse.SUPPRESS,
+ help=help,
+ metavar=metavar)
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ """See https://docs.python.org/3/library/argparse.html#action-classes."""
+ self._flag_instance.parse(values)
+ self._flag_instance.using_default_value = False
+
+
+class _BooleanFlagAction(argparse.Action):
+ """Action class for Abseil boolean flags."""
+
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ help, # pylint: disable=redefined-builtin
+ metavar,
+ flag_instance,
+ default=argparse.SUPPRESS):
+ """Initializes _BooleanFlagAction.
+
+ Args:
+ option_strings: See argparse.Action.
+ dest: Ignored. The flag is always defined with dest=argparse.SUPPRESS.
+ help: See argparse.Action.
+ metavar: See argparse.Action.
+ flag_instance: absl.flags.Flag, the absl flag instance.
+ default: Ignored. The flag always uses dest=argparse.SUPPRESS so it
+ doesn't affect the parsing result.
+ """
+ del dest, default
+ self._flag_instance = flag_instance
+ flag_names = [self._flag_instance.name]
+ if self._flag_instance.short_name:
+ flag_names.append(self._flag_instance.short_name)
+ self._flag_names = frozenset(flag_names)
+ super(_BooleanFlagAction, self).__init__(
+ option_strings=option_strings,
+ dest=argparse.SUPPRESS,
+ nargs=0, # Does not accept values, only `--bool` or `--nobool`.
+ help=help,
+ metavar=metavar)
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ """See https://docs.python.org/3/library/argparse.html#action-classes."""
+ if not isinstance(values, list) or values:
+ raise ValueError('values must be an empty list.')
+ if option_string.startswith('--'):
+ option = option_string[2:]
+ else:
+ option = option_string[1:]
+ if option in self._flag_names:
+ self._flag_instance.parse('true')
+ else:
+ if not option.startswith('no') or option[2:] not in self._flag_names:
+ raise ValueError('invalid option_string: ' + option_string)
+ self._flag_instance.parse('false')
+ self._flag_instance.using_default_value = False
+
+
+class _HelpFullAction(argparse.Action):
+ """Action class for --helpfull flag."""
+
+ def __init__(self, option_strings, dest, default, help): # pylint: disable=redefined-builtin
+ """Initializes _HelpFullAction.
+
+ Args:
+ option_strings: See argparse.Action.
+ dest: Ignored. The flag is always defined with dest=argparse.SUPPRESS.
+ default: Ignored.
+ help: See argparse.Action.
+ """
+ del dest, default
+ super(_HelpFullAction, self).__init__(
+ option_strings=option_strings,
+ dest=argparse.SUPPRESS,
+ default=argparse.SUPPRESS,
+ nargs=0,
+ help=help)
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ """See https://docs.python.org/3/library/argparse.html#action-classes."""
+ # This only prints flags when help is not argparse.SUPPRESS.
+ # It includes user defined argparse flags, as well as main module's
+ # key absl flags. Other absl flags use argparse.SUPPRESS, so they aren't
+ # printed here.
+ parser.print_help()
+
+ absl_flags = parser._inherited_absl_flags # pylint: disable=protected-access
+ if absl_flags:
+ modules = sorted(absl_flags.flags_by_module_dict())
+ main_module = sys.argv[0]
+ if main_module in modules:
+ # The main module flags are already printed in parser.print_help().
+ modules.remove(main_module)
+ print(absl_flags._get_help_for_modules( # pylint: disable=protected-access
+ modules, prefix='', include_special_flags=True))
+ parser.exit()
+
+
+def _strip_undefok_args(undefok, args):
+ """Returns a new list of args after removing flags in --undefok."""
+ if undefok:
+ undefok_names = set(name.strip() for name in undefok.split(','))
+ undefok_names |= set('no' + name for name in undefok_names)
+ # Remove undefok flags.
+ args = [arg for arg in args if not _is_undefok(arg, undefok_names)]
+ return args
+
+
+def _is_undefok(arg, undefok_names):
+ """Returns whether we can ignore arg based on a set of undefok flag names."""
+ if not arg.startswith('-'):
+ return False
+ if arg.startswith('--'):
+ arg_without_dash = arg[2:]
+ else:
+ arg_without_dash = arg[1:]
+ if '=' in arg_without_dash:
+ name, _ = arg_without_dash.split('=', 1)
+ else:
+ name = arg_without_dash
+ if name in undefok_names:
+ return True
+ return False
diff --git a/absl/flags/tests/__init__.py b/absl/flags/tests/__init__.py
new file mode 100644
index 0000000..a3bd1cd
--- /dev/null
+++ b/absl/flags/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/absl/flags/tests/_argument_parser_test.py b/absl/flags/tests/_argument_parser_test.py
new file mode 100644
index 0000000..4281c3f
--- /dev/null
+++ b/absl/flags/tests/_argument_parser_test.py
@@ -0,0 +1,214 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Additional tests for flag argument parsers.
+
+Most of the argument parsers are covered in the flags_test.py.
+"""
+
+import enum
+
+from absl.flags import _argument_parser
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class ArgumentParserTest(absltest.TestCase):
+
+ def test_instance_cache(self):
+ parser1 = _argument_parser.FloatParser()
+ parser2 = _argument_parser.FloatParser()
+ self.assertIs(parser1, parser2)
+
+ def test_parse_wrong_type(self):
+ parser = _argument_parser.ArgumentParser()
+ with self.assertRaises(TypeError):
+ parser.parse(0)
+
+ if bytes is not str:
+ # In PY3, it does not accept bytes.
+ with self.assertRaises(TypeError):
+ parser.parse(b'')
+
+
+class BooleanParserTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.parser = _argument_parser.BooleanParser()
+
+ def test_parse_bytes(self):
+ with self.assertRaises(TypeError):
+ self.parser.parse(b'true')
+
+ def test_parse_str(self):
+ self.assertTrue(self.parser.parse('true'))
+
+ def test_parse_unicode(self):
+ self.assertTrue(self.parser.parse(u'true'))
+
+ def test_parse_wrong_type(self):
+ with self.assertRaises(TypeError):
+ self.parser.parse(1.234)
+
+ def test_parse_str_false(self):
+ self.assertFalse(self.parser.parse('false'))
+
+ def test_parse_integer(self):
+ self.assertTrue(self.parser.parse(1))
+
+ def test_parse_invalid_integer(self):
+ with self.assertRaises(ValueError):
+ self.parser.parse(-1)
+
+ def test_parse_invalid_str(self):
+ with self.assertRaises(ValueError):
+ self.parser.parse('nottrue')
+
+
+class FloatParserTest(absltest.TestCase):
+
+ def setUp(self):
+ self.parser = _argument_parser.FloatParser()
+
+ def test_parse_string(self):
+ self.assertEqual(1.5, self.parser.parse('1.5'))
+
+ def test_parse_wrong_type(self):
+ with self.assertRaises(TypeError):
+ self.parser.parse(False)
+
+
+class IntegerParserTest(absltest.TestCase):
+
+ def setUp(self):
+ self.parser = _argument_parser.IntegerParser()
+
+ def test_parse_string(self):
+ self.assertEqual(1, self.parser.parse('1'))
+
+ def test_parse_wrong_type(self):
+ with self.assertRaises(TypeError):
+ self.parser.parse(1e2)
+ with self.assertRaises(TypeError):
+ self.parser.parse(False)
+
+
+class EnumParserTest(absltest.TestCase):
+
+ def test_empty_values(self):
+ with self.assertRaises(ValueError):
+ _argument_parser.EnumParser([])
+
+ def test_parse(self):
+ parser = _argument_parser.EnumParser(['apple', 'banana'])
+ self.assertEqual('apple', parser.parse('apple'))
+
+ def test_parse_not_found(self):
+ parser = _argument_parser.EnumParser(['apple', 'banana'])
+ with self.assertRaises(ValueError):
+ parser.parse('orange')
+
+
+class Fruit(enum.Enum):
+ APPLE = 1
+ BANANA = 2
+
+
+class EmptyEnum(enum.Enum):
+ pass
+
+
+class MixedCaseEnum(enum.Enum):
+ APPLE = 1
+ BANANA = 2
+ apple = 3
+
+
+class EnumClassParserTest(parameterized.TestCase):
+
+ def test_requires_enum(self):
+ with self.assertRaises(TypeError):
+ _argument_parser.EnumClassParser(['apple', 'banana'])
+
+ def test_requires_non_empty_enum_class(self):
+ with self.assertRaises(ValueError):
+ _argument_parser.EnumClassParser(EmptyEnum)
+
+ def test_case_sensitive_rejects_duplicates(self):
+ unused_normal_parser = _argument_parser.EnumClassParser(MixedCaseEnum)
+ with self.assertRaisesRegex(ValueError, 'Duplicate.+apple'):
+ _argument_parser.EnumClassParser(MixedCaseEnum, case_sensitive=False)
+
+ def test_parse_string(self):
+ parser = _argument_parser.EnumClassParser(Fruit)
+ self.assertEqual(Fruit.APPLE, parser.parse('APPLE'))
+
+ def test_parse_string_case_sensitive(self):
+ parser = _argument_parser.EnumClassParser(Fruit)
+ with self.assertRaises(ValueError):
+ parser.parse('apple')
+
+ @parameterized.parameters('APPLE', 'apple', 'Apple')
+ def test_parse_string_case_insensitive(self, value):
+ parser = _argument_parser.EnumClassParser(Fruit, case_sensitive=False)
+ self.assertIs(Fruit.APPLE, parser.parse(value))
+
+ def test_parse_literal(self):
+ parser = _argument_parser.EnumClassParser(Fruit)
+ self.assertEqual(Fruit.APPLE, parser.parse(Fruit.APPLE))
+
+ def test_parse_not_found(self):
+ parser = _argument_parser.EnumClassParser(Fruit)
+ with self.assertRaises(ValueError):
+ parser.parse('ORANGE')
+
+ @parameterized.parameters((Fruit.BANANA, False, 'BANANA'),
+ (Fruit.BANANA, True, 'banana'))
+ def test_serialize_parse(self, value, lowercase, expected):
+ serializer = _argument_parser.EnumClassSerializer(lowercase=lowercase)
+ parser = _argument_parser.EnumClassParser(
+ Fruit, case_sensitive=not lowercase)
+ serialized = serializer.serialize(value)
+ self.assertEqual(serialized, expected)
+ self.assertEqual(value, parser.parse(expected))
+
+
+class SerializerTest(parameterized.TestCase):
+
+ def test_csv_serializer(self):
+ serializer = _argument_parser.CsvListSerializer('+')
+ self.assertEqual(serializer.serialize(['foo', 'bar']), 'foo+bar')
+
+ @parameterized.parameters([
+ dict(lowercase=False, expected='APPLE+BANANA'),
+ dict(lowercase=True, expected='apple+banana'),
+ ])
+ def test_enum_class_list_serializer(self, lowercase, expected):
+ values = [Fruit.APPLE, Fruit.BANANA]
+ serializer = _argument_parser.EnumClassListSerializer(
+ list_sep='+', lowercase=lowercase)
+ serialized = serializer.serialize(values)
+ self.assertEqual(expected, serialized)
+
+
+class HelperFunctionsTest(absltest.TestCase):
+
+ def test_is_integer_type(self):
+ self.assertTrue(_argument_parser._is_integer_type(1))
+ # Note that isinstance(False, int) == True.
+ self.assertFalse(_argument_parser._is_integer_type(False))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/_flag_test.py b/absl/flags/tests/_flag_test.py
new file mode 100644
index 0000000..492f117
--- /dev/null
+++ b/absl/flags/tests/_flag_test.py
@@ -0,0 +1,240 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Additional tests for Flag classes.
+
+Most of the Flag classes are covered in the flags_test.py.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import enum
+import pickle
+
+from absl.flags import _argument_parser
+from absl.flags import _exceptions
+from absl.flags import _flag
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class FlagTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.flag = _flag.Flag(
+ _argument_parser.ArgumentParser(),
+ _argument_parser.ArgumentSerializer(),
+ 'fruit', 'apple', 'help')
+
+ def test_default_unparsed(self):
+ flag = _flag.Flag(
+ _argument_parser.ArgumentParser(),
+ _argument_parser.ArgumentSerializer(),
+ 'fruit', 'apple', 'help')
+ self.assertEqual('apple', flag.default_unparsed)
+
+ flag = _flag.Flag(
+ _argument_parser.IntegerParser(),
+ _argument_parser.ArgumentSerializer(),
+ 'number', '1', 'help')
+ self.assertEqual('1', flag.default_unparsed)
+
+ flag = _flag.Flag(
+ _argument_parser.IntegerParser(),
+ _argument_parser.ArgumentSerializer(),
+ 'number', 1, 'help')
+ self.assertEqual(1, flag.default_unparsed)
+
+ def test_no_truthiness(self):
+ with self.assertRaises(TypeError):
+ if self.flag:
+ self.fail('Flag instances must raise rather than be truthy.')
+
+ def test_set_default_overrides_current_value(self):
+ self.assertEqual('apple', self.flag.value)
+ self.flag._set_default('orange')
+ self.assertEqual('orange', self.flag.value)
+
+ def test_set_default_overrides_current_value_when_not_using_default(self):
+ self.flag.using_default_value = False
+ self.assertEqual('apple', self.flag.value)
+ self.flag._set_default('orange')
+ self.assertEqual('apple', self.flag.value)
+
+ def test_pickle(self):
+ with self.assertRaisesRegex(TypeError, "can't pickle Flag objects"):
+ pickle.dumps(self.flag)
+
+ def test_copy(self):
+ self.flag.value = 'orange'
+
+ with self.assertRaisesRegex(TypeError,
+ 'Flag does not support shallow copies'):
+ copy.copy(self.flag)
+
+ flag2 = copy.deepcopy(self.flag)
+ self.assertEqual(flag2.value, 'orange')
+
+ flag2.value = 'mango'
+ self.assertEqual(flag2.value, 'mango')
+ self.assertEqual(self.flag.value, 'orange')
+
+
+class BooleanFlagTest(parameterized.TestCase):
+
+ @parameterized.parameters(('', '(no help available)'),
+ ('Is my test brilliant?', 'Is my test brilliant?'))
+ def test_help_text(self, helptext_input, helptext_output):
+ f = _flag.BooleanFlag('a_bool', False, helptext_input)
+ self.assertEqual(helptext_output, f.help)
+
+
+class EnumFlagTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ ('', '<apple|orange>: (no help available)'),
+ ('Type of fruit.', '<apple|orange>: Type of fruit.'))
+ def test_help_text(self, helptext_input, helptext_output):
+ f = _flag.EnumFlag('fruit', 'apple', helptext_input, ['apple', 'orange'])
+ self.assertEqual(helptext_output, f.help)
+
+ def test_empty_values(self):
+ with self.assertRaises(ValueError):
+ _flag.EnumFlag('fruit', None, 'help', [])
+
+
+class Fruit(enum.Enum):
+ APPLE = 1
+ ORANGE = 2
+
+
+class EmptyEnum(enum.Enum):
+ pass
+
+
+class EnumClassFlagTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ ('', '<apple|orange>: (no help available)'),
+ ('Type of fruit.', '<apple|orange>: Type of fruit.'))
+ def test_help_text_case_insensitive(self, helptext_input, helptext_output):
+ f = _flag.EnumClassFlag('fruit', None, helptext_input, Fruit)
+ self.assertEqual(helptext_output, f.help)
+
+ @parameterized.parameters(
+ ('', '<APPLE|ORANGE>: (no help available)'),
+ ('Type of fruit.', '<APPLE|ORANGE>: Type of fruit.'))
+ def test_help_text_case_sensitive(self, helptext_input, helptext_output):
+ f = _flag.EnumClassFlag(
+ 'fruit', None, helptext_input, Fruit, case_sensitive=True)
+ self.assertEqual(helptext_output, f.help)
+
+ def test_requires_enum(self):
+ with self.assertRaises(TypeError):
+ _flag.EnumClassFlag('fruit', None, 'help', ['apple', 'orange'])
+
+ def test_requires_non_empty_enum_class(self):
+ with self.assertRaises(ValueError):
+ _flag.EnumClassFlag('empty', None, 'help', EmptyEnum)
+
+ def test_accepts_literal_default(self):
+ f = _flag.EnumClassFlag('fruit', Fruit.APPLE, 'A sample enum flag.', Fruit)
+ self.assertEqual(Fruit.APPLE, f.value)
+
+ def test_accepts_string_default(self):
+ f = _flag.EnumClassFlag('fruit', 'ORANGE', 'A sample enum flag.', Fruit)
+ self.assertEqual(Fruit.ORANGE, f.value)
+
+ def test_case_sensitive_rejects_default_with_wrong_case(self):
+ with self.assertRaises(_exceptions.IllegalFlagValueError):
+ _flag.EnumClassFlag(
+ 'fruit', 'oranGe', 'A sample enum flag.', Fruit, case_sensitive=True)
+
+ def test_case_insensitive_accepts_string_default(self):
+ f = _flag.EnumClassFlag(
+ 'fruit', 'oranGe', 'A sample enum flag.', Fruit, case_sensitive=False)
+ self.assertEqual(Fruit.ORANGE, f.value)
+
+ def test_default_value_does_not_exist(self):
+ with self.assertRaises(_exceptions.IllegalFlagValueError):
+ _flag.EnumClassFlag('fruit', 'BANANA', 'help', Fruit)
+
+
+class MultiEnumClassFlagTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('NoHelpSupplied', '', '<apple|orange>: (no help available);\n ' +
+ 'repeat this option to specify a list of values', False),
+ ('WithHelpSupplied', 'Type of fruit.',
+ '<APPLE|ORANGE>: Type of fruit.;\n ' +
+ 'repeat this option to specify a list of values', True))
+ def test_help_text(self, helptext_input, helptext_output, case_sensitive):
+ f = _flag.MultiEnumClassFlag(
+ 'fruit', None, helptext_input, Fruit, case_sensitive=case_sensitive)
+ self.assertEqual(helptext_output, f.help)
+
+ def test_requires_enum(self):
+ with self.assertRaises(TypeError):
+ _flag.MultiEnumClassFlag('fruit', None, 'help', ['apple', 'orange'])
+
+ def test_requires_non_empty_enum_class(self):
+ with self.assertRaises(ValueError):
+ _flag.MultiEnumClassFlag('empty', None, 'help', EmptyEnum)
+
+ def test_rejects_wrong_case_when_case_sensitive(self):
+ with self.assertRaisesRegex(_exceptions.IllegalFlagValueError,
+ '<APPLE|ORANGE>'):
+ _flag.MultiEnumClassFlag(
+ 'fruit', ['APPLE', 'Orange'],
+ 'A sample enum flag.',
+ Fruit,
+ case_sensitive=True)
+
+ def test_accepts_case_insensitive(self):
+ f = _flag.MultiEnumClassFlag('fruit', ['apple', 'APPLE'],
+ 'A sample enum flag.', Fruit)
+ self.assertListEqual([Fruit.APPLE, Fruit.APPLE], f.value)
+
+ def test_accepts_literal_default(self):
+ f = _flag.MultiEnumClassFlag('fruit', Fruit.APPLE, 'A sample enum flag.',
+ Fruit)
+ self.assertListEqual([Fruit.APPLE], f.value)
+
+ def test_accepts_list_of_literal_default(self):
+ f = _flag.MultiEnumClassFlag('fruit', [Fruit.APPLE, Fruit.ORANGE],
+ 'A sample enum flag.', Fruit)
+ self.assertListEqual([Fruit.APPLE, Fruit.ORANGE], f.value)
+
+ def test_accepts_string_default(self):
+ f = _flag.MultiEnumClassFlag('fruit', 'ORANGE', 'A sample enum flag.',
+ Fruit)
+ self.assertListEqual([Fruit.ORANGE], f.value)
+
+ def test_accepts_list_of_string_default(self):
+ f = _flag.MultiEnumClassFlag('fruit', ['ORANGE', 'APPLE'],
+ 'A sample enum flag.', Fruit)
+ self.assertListEqual([Fruit.ORANGE, Fruit.APPLE], f.value)
+
+ def test_default_value_does_not_exist(self):
+ with self.assertRaisesRegex(_exceptions.IllegalFlagValueError,
+ '<apple|banana>'):
+ _flag.MultiEnumClassFlag('fruit', 'BANANA', 'help', Fruit)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py
new file mode 100644
index 0000000..46639f2
--- /dev/null
+++ b/absl/flags/tests/_flagvalues_test.py
@@ -0,0 +1,929 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for flags.FlagValues class."""
+
+import collections
+import copy
+import pickle
+import types
+from unittest import mock
+
+from absl import logging
+from absl.flags import _defines
+from absl.flags import _exceptions
+from absl.flags import _flagvalues
+from absl.flags import _helpers
+from absl.flags import _validators
+from absl.flags.tests import module_foo
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class FlagValuesTest(absltest.TestCase):
+
+ def test_bool_flags(self):
+ for arg, expected in (('--nothing', True),
+ ('--nothing=true', True),
+ ('--nothing=false', False),
+ ('--nonothing', False)):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_boolean('nothing', None, '', flag_values=fv)
+ fv(('./program', arg))
+ self.assertIs(expected, fv.nothing)
+
+ for arg in ('--nonothing=true', '--nonothing=false'):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_boolean('nothing', None, '', flag_values=fv)
+ with self.assertRaises(ValueError):
+ fv(('./program', arg))
+
+ def test_boolean_flag_parser_gets_string_argument(self):
+ for arg, expected in (('--nothing', 'true'),
+ ('--nothing=true', 'true'),
+ ('--nothing=false', 'false'),
+ ('--nonothing', 'false')):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_boolean('nothing', None, '', flag_values=fv)
+ with mock.patch.object(fv['nothing'].parser, 'parse') as mock_parse:
+ fv(('./program', arg))
+ mock_parse.assert_called_once_with(expected)
+
+ def test_unregistered_flags_are_cleaned_up(self):
+ fv = _flagvalues.FlagValues()
+ module, module_name = _helpers.get_calling_module_object_and_name()
+
+ # Define first flag.
+ _defines.DEFINE_integer('cores', 4, '', flag_values=fv, short_name='c')
+ old_cores_flag = fv['cores']
+ fv.register_key_flag_for_module(module_name, old_cores_flag)
+ self.assertEqual(fv.flags_by_module_dict(),
+ {module_name: [old_cores_flag]})
+ self.assertEqual(fv.flags_by_module_id_dict(),
+ {id(module): [old_cores_flag]})
+ self.assertEqual(fv.key_flags_by_module_dict(),
+ {module_name: [old_cores_flag]})
+
+ # Redefine the same flag.
+ _defines.DEFINE_integer(
+ 'cores', 4, '', flag_values=fv, short_name='c', allow_override=True)
+ new_cores_flag = fv['cores']
+ self.assertNotEqual(old_cores_flag, new_cores_flag)
+ self.assertEqual(fv.flags_by_module_dict(),
+ {module_name: [new_cores_flag]})
+ self.assertEqual(fv.flags_by_module_id_dict(),
+ {id(module): [new_cores_flag]})
+ # old_cores_flag is removed from key flags, and the new_cores_flag is
+ # not automatically added because it must be registered explicitly.
+ self.assertEqual(fv.key_flags_by_module_dict(), {module_name: []})
+
+ # Define a new flag but with the same short_name.
+ _defines.DEFINE_integer(
+ 'changelist',
+ 0,
+ '',
+ flag_values=fv,
+ short_name='c',
+ allow_override=True)
+ old_changelist_flag = fv['changelist']
+ fv.register_key_flag_for_module(module_name, old_changelist_flag)
+ # The short named flag -c is overridden to be the old_changelist_flag.
+ self.assertEqual(fv['c'], old_changelist_flag)
+ self.assertNotEqual(fv['c'], new_cores_flag)
+ self.assertEqual(fv.flags_by_module_dict(),
+ {module_name: [new_cores_flag, old_changelist_flag]})
+ self.assertEqual(fv.flags_by_module_id_dict(),
+ {id(module): [new_cores_flag, old_changelist_flag]})
+ self.assertEqual(fv.key_flags_by_module_dict(),
+ {module_name: [old_changelist_flag]})
+
+ # Define a flag only with the same long name.
+ _defines.DEFINE_integer(
+ 'changelist',
+ 0,
+ '',
+ flag_values=fv,
+ short_name='l',
+ allow_override=True)
+ new_changelist_flag = fv['changelist']
+ self.assertNotEqual(old_changelist_flag, new_changelist_flag)
+ self.assertEqual(fv.flags_by_module_dict(),
+ {module_name: [new_cores_flag,
+ old_changelist_flag,
+ new_changelist_flag]})
+ self.assertEqual(fv.flags_by_module_id_dict(),
+ {id(module): [new_cores_flag,
+ old_changelist_flag,
+ new_changelist_flag]})
+ self.assertEqual(fv.key_flags_by_module_dict(),
+ {module_name: [old_changelist_flag]})
+
+ # Delete the new changelist's long name, it should still be registered
+ # because of its short name.
+ del fv.changelist
+ self.assertNotIn('changelist', fv)
+ self.assertEqual(fv.flags_by_module_dict(),
+ {module_name: [new_cores_flag,
+ old_changelist_flag,
+ new_changelist_flag]})
+ self.assertEqual(fv.flags_by_module_id_dict(),
+ {id(module): [new_cores_flag,
+ old_changelist_flag,
+ new_changelist_flag]})
+ self.assertEqual(fv.key_flags_by_module_dict(),
+ {module_name: [old_changelist_flag]})
+
+ # Delete the new changelist's short name, it should be removed.
+ del fv.l
+ self.assertNotIn('l', fv)
+ self.assertEqual(fv.flags_by_module_dict(),
+ {module_name: [new_cores_flag,
+ old_changelist_flag]})
+ self.assertEqual(fv.flags_by_module_id_dict(),
+ {id(module): [new_cores_flag,
+ old_changelist_flag]})
+ self.assertEqual(fv.key_flags_by_module_dict(),
+ {module_name: [old_changelist_flag]})
+
+ def _test_find_module_or_id_defining_flag(self, test_id):
+ """Tests for find_module_defining_flag and find_module_id_defining_flag.
+
+ Args:
+ test_id: True to test find_module_id_defining_flag, False to test
+ find_module_defining_flag.
+ """
+ fv = _flagvalues.FlagValues()
+ current_module, current_module_name = (
+ _helpers.get_calling_module_object_and_name())
+ alt_module_name = _flagvalues.__name__
+
+ if test_id:
+ current_module_or_id = id(current_module)
+ alt_module_or_id = id(_flagvalues)
+ testing_fn = fv.find_module_id_defining_flag
+ else:
+ current_module_or_id = current_module_name
+ alt_module_or_id = alt_module_name
+ testing_fn = fv.find_module_defining_flag
+
+ # Define first flag.
+ _defines.DEFINE_integer('cores', 4, '', flag_values=fv, short_name='c')
+ module_or_id_cores = testing_fn('cores')
+ self.assertEqual(module_or_id_cores, current_module_or_id)
+ module_or_id_c = testing_fn('c')
+ self.assertEqual(module_or_id_c, current_module_or_id)
+
+ # Redefine the same flag in another module.
+ _defines.DEFINE_integer(
+ 'cores',
+ 4,
+ '',
+ flag_values=fv,
+ module_name=alt_module_name,
+ short_name='c',
+ allow_override=True)
+ module_or_id_cores = testing_fn('cores')
+ self.assertEqual(module_or_id_cores, alt_module_or_id)
+ module_or_id_c = testing_fn('c')
+ self.assertEqual(module_or_id_c, alt_module_or_id)
+
+ # Define a new flag but with the same short_name.
+ _defines.DEFINE_integer(
+ 'changelist',
+ 0,
+ '',
+ flag_values=fv,
+ short_name='c',
+ allow_override=True)
+ module_or_id_cores = testing_fn('cores')
+ self.assertEqual(module_or_id_cores, alt_module_or_id)
+ module_or_id_changelist = testing_fn('changelist')
+ self.assertEqual(module_or_id_changelist, current_module_or_id)
+ module_or_id_c = testing_fn('c')
+ self.assertEqual(module_or_id_c, current_module_or_id)
+
+ # Define a flag in another module only with the same long name.
+ _defines.DEFINE_integer(
+ 'changelist',
+ 0,
+ '',
+ flag_values=fv,
+ module_name=alt_module_name,
+ short_name='l',
+ allow_override=True)
+ module_or_id_cores = testing_fn('cores')
+ self.assertEqual(module_or_id_cores, alt_module_or_id)
+ module_or_id_changelist = testing_fn('changelist')
+ self.assertEqual(module_or_id_changelist, alt_module_or_id)
+ module_or_id_c = testing_fn('c')
+ self.assertEqual(module_or_id_c, current_module_or_id)
+ module_or_id_l = testing_fn('l')
+ self.assertEqual(module_or_id_l, alt_module_or_id)
+
+ # Delete the changelist flag, its short name should still be registered.
+ del fv.changelist
+ module_or_id_changelist = testing_fn('changelist')
+ self.assertIsNone(module_or_id_changelist)
+ module_or_id_c = testing_fn('c')
+ self.assertEqual(module_or_id_c, current_module_or_id)
+ module_or_id_l = testing_fn('l')
+ self.assertEqual(module_or_id_l, alt_module_or_id)
+
+ def test_find_module_defining_flag(self):
+ self._test_find_module_or_id_defining_flag(test_id=False)
+
+ def test_find_module_id_defining_flag(self):
+ self._test_find_module_or_id_defining_flag(test_id=True)
+
+ def test_set_default(self):
+ fv = _flagvalues.FlagValues()
+ fv.mark_as_parsed()
+ with self.assertRaises(_exceptions.UnrecognizedFlagError):
+ fv.set_default('changelist', 1)
+ _defines.DEFINE_integer('changelist', 0, 'help', flag_values=fv)
+ self.assertEqual(0, fv.changelist)
+ fv.set_default('changelist', 2)
+ self.assertEqual(2, fv.changelist)
+
+ def test_default_gnu_getopt_value(self):
+ self.assertTrue(_flagvalues.FlagValues().is_gnu_getopt())
+
+ def test_known_only_flags_in_gnustyle(self):
+
+ def run_test(argv, defined_py_flags, expected_argv):
+ fv = _flagvalues.FlagValues()
+ fv.set_gnu_getopt(True)
+ for f in defined_py_flags:
+ if f.startswith('b'):
+ _defines.DEFINE_boolean(f, False, 'help', flag_values=fv)
+ else:
+ _defines.DEFINE_string(f, 'default', 'help', flag_values=fv)
+ output_argv = fv(argv, known_only=True)
+ self.assertEqual(expected_argv, output_argv)
+
+ run_test(
+ argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
+ defined_py_flags=[],
+ expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '))
+ run_test(
+ argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
+ defined_py_flags=['f1'],
+ expected_argv='0 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '))
+ run_test(
+ argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
+ defined_py_flags=['f2'],
+ expected_argv='0 --f1=v1 cmd --b1 --f3 v3 --nob2'.split(' '))
+ run_test(
+ argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
+ defined_py_flags=['b1'],
+ expected_argv='0 --f1=v1 cmd --f2 v2 --f3 v3 --nob2'.split(' '))
+ run_test(
+ argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
+ defined_py_flags=['f3'],
+ expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --nob2'.split(' '))
+ run_test(
+ argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '),
+ defined_py_flags=['b2'],
+ expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3'.split(' '))
+ run_test(
+ argv=('0 --f1=v1 cmd --undefok=f1 --f2 v2 --b1 '
+ '--f3 v3 --nob2').split(' '),
+ defined_py_flags=['b2'],
+ expected_argv='0 cmd --f2 v2 --b1 --f3 v3'.split(' '))
+ run_test(
+ argv=('0 --f1=v1 cmd --undefok f1,f2 --f2 v2 --b1 '
+ '--f3 v3 --nob2').split(' '),
+ defined_py_flags=['b2'],
+ # Note v2 is preserved here, since undefok requires the flag being
+ # specified in the form of --flag=value.
+ expected_argv='0 cmd v2 --b1 --f3 v3'.split(' '))
+
+ def test_invalid_flag_name(self):
+ with self.assertRaises(_exceptions.Error):
+ _defines.DEFINE_boolean('test ', 0, '')
+
+ with self.assertRaises(_exceptions.Error):
+ _defines.DEFINE_boolean(' test', 0, '')
+
+ with self.assertRaises(_exceptions.Error):
+ _defines.DEFINE_boolean('te st', 0, '')
+
+ with self.assertRaises(_exceptions.Error):
+ _defines.DEFINE_boolean('', 0, '')
+
+ with self.assertRaises(_exceptions.Error):
+ _defines.DEFINE_boolean(1, 0, '')
+
+ def test_len(self):
+ fv = _flagvalues.FlagValues()
+ self.assertEmpty(fv)
+ self.assertFalse(fv)
+
+ _defines.DEFINE_boolean('boolean', False, 'help', flag_values=fv)
+ self.assertLen(fv, 1)
+ self.assertTrue(fv)
+
+ _defines.DEFINE_boolean(
+ 'bool', False, 'help', short_name='b', flag_values=fv)
+ self.assertLen(fv, 3)
+ self.assertTrue(fv)
+
+ def test_pickle(self):
+ fv = _flagvalues.FlagValues()
+ with self.assertRaisesRegex(TypeError, "can't pickle FlagValues"):
+ pickle.dumps(fv)
+
+ def test_copy(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_integer('answer', 0, 'help', flag_values=fv)
+ fv(['', '--answer=1'])
+
+ with self.assertRaisesRegex(TypeError,
+ 'FlagValues does not support shallow copies'):
+ copy.copy(fv)
+
+ fv2 = copy.deepcopy(fv)
+ self.assertEqual(fv2.answer, 1)
+
+ fv2.answer = 42
+ self.assertEqual(fv2.answer, 42)
+ self.assertEqual(fv.answer, 1)
+
+ def test_conflicting_flags(self):
+ fv = _flagvalues.FlagValues()
+ with self.assertRaises(_exceptions.FlagNameConflictsWithMethodError):
+ _defines.DEFINE_boolean('is_gnu_getopt', False, 'help', flag_values=fv)
+ _defines.DEFINE_boolean(
+ 'is_gnu_getopt',
+ False,
+ 'help',
+ flag_values=fv,
+ allow_using_method_names=True)
+ self.assertFalse(fv['is_gnu_getopt'].value)
+ self.assertIsInstance(fv.is_gnu_getopt, types.MethodType)
+
+ def test_get_flags_for_module(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_string('foo', None, 'help', flag_values=fv)
+ module_foo.define_flags(fv)
+ flags = fv.get_flags_for_module('__main__')
+
+ self.assertEqual({'foo'}, {flag.name for flag in flags})
+
+ flags = fv.get_flags_for_module(module_foo)
+ self.assertEqual({'tmod_foo_bool', 'tmod_foo_int', 'tmod_foo_str'},
+ {flag.name for flag in flags})
+
+ def test_get_help(self):
+ fv = _flagvalues.FlagValues()
+ self.assertMultiLineEqual('''\
+--flagfile: Insert flag definitions from the given file into the command line.
+ (default: '')
+--undefok: comma-separated list of flag names that it is okay to specify on the
+ command line even if the program does not define a flag with that name.
+ IMPORTANT: flags in this list that have arguments MUST use the --flag=value
+ format.
+ (default: '')''', fv.get_help())
+
+ module_foo.define_flags(fv)
+ self.assertMultiLineEqual('''
+absl.flags.tests.module_bar:
+ --tmod_bar_t: Sample int flag.
+ (default: '4')
+ (an integer)
+ --tmod_bar_u: Sample int flag.
+ (default: '5')
+ (an integer)
+ --tmod_bar_v: Sample int flag.
+ (default: '6')
+ (an integer)
+ --[no]tmod_bar_x: Boolean flag.
+ (default: 'true')
+ --tmod_bar_y: String flag.
+ (default: 'default')
+ --[no]tmod_bar_z: Another boolean flag from module bar.
+ (default: 'false')
+
+absl.flags.tests.module_foo:
+ --[no]tmod_foo_bool: Boolean flag from module foo.
+ (default: 'true')
+ --tmod_foo_int: Sample int flag.
+ (default: '3')
+ (an integer)
+ --tmod_foo_str: String flag.
+ (default: 'default')
+
+absl.flags:
+ --flagfile: Insert flag definitions from the given file into the command line.
+ (default: '')
+ --undefok: comma-separated list of flag names that it is okay to specify on
+ the command line even if the program does not define a flag with that name.
+ IMPORTANT: flags in this list that have arguments MUST use the --flag=value
+ format.
+ (default: '')''', fv.get_help())
+
+ self.assertMultiLineEqual('''
+xxxxabsl.flags.tests.module_bar:
+xxxx --tmod_bar_t: Sample int flag.
+xxxx (default: '4')
+xxxx (an integer)
+xxxx --tmod_bar_u: Sample int flag.
+xxxx (default: '5')
+xxxx (an integer)
+xxxx --tmod_bar_v: Sample int flag.
+xxxx (default: '6')
+xxxx (an integer)
+xxxx --[no]tmod_bar_x: Boolean flag.
+xxxx (default: 'true')
+xxxx --tmod_bar_y: String flag.
+xxxx (default: 'default')
+xxxx --[no]tmod_bar_z: Another boolean flag from module bar.
+xxxx (default: 'false')
+
+xxxxabsl.flags.tests.module_foo:
+xxxx --[no]tmod_foo_bool: Boolean flag from module foo.
+xxxx (default: 'true')
+xxxx --tmod_foo_int: Sample int flag.
+xxxx (default: '3')
+xxxx (an integer)
+xxxx --tmod_foo_str: String flag.
+xxxx (default: 'default')
+
+xxxxabsl.flags:
+xxxx --flagfile: Insert flag definitions from the given file into the command
+xxxx line.
+xxxx (default: '')
+xxxx --undefok: comma-separated list of flag names that it is okay to specify
+xxxx on the command line even if the program does not define a flag with that
+xxxx name. IMPORTANT: flags in this list that have arguments MUST use the
+xxxx --flag=value format.
+xxxx (default: '')''', fv.get_help(prefix='xxxx'))
+
+ self.assertMultiLineEqual('''
+absl.flags.tests.module_bar:
+ --tmod_bar_t: Sample int flag.
+ (default: '4')
+ (an integer)
+ --tmod_bar_u: Sample int flag.
+ (default: '5')
+ (an integer)
+ --tmod_bar_v: Sample int flag.
+ (default: '6')
+ (an integer)
+ --[no]tmod_bar_x: Boolean flag.
+ (default: 'true')
+ --tmod_bar_y: String flag.
+ (default: 'default')
+ --[no]tmod_bar_z: Another boolean flag from module bar.
+ (default: 'false')
+
+absl.flags.tests.module_foo:
+ --[no]tmod_foo_bool: Boolean flag from module foo.
+ (default: 'true')
+ --tmod_foo_int: Sample int flag.
+ (default: '3')
+ (an integer)
+ --tmod_foo_str: String flag.
+ (default: 'default')''', fv.get_help(include_special_flags=False))
+
+ def test_str(self):
+ fv = _flagvalues.FlagValues()
+ self.assertEqual(str(fv), fv.get_help())
+ module_foo.define_flags(fv)
+ self.assertEqual(str(fv), fv.get_help())
+
+ def test_empty_argv(self):
+ fv = _flagvalues.FlagValues()
+ with self.assertRaises(ValueError):
+ fv([])
+
+ def test_invalid_argv(self):
+ fv = _flagvalues.FlagValues()
+ with self.assertRaises(TypeError):
+ fv('./program')
+ with self.assertRaises(TypeError):
+ fv(b'./program')
+ with self.assertRaises(TypeError):
+ fv(u'./program')
+
+ def test_flags_dir(self):
+ flag_values = _flagvalues.FlagValues()
+ flag_name1 = 'bool_flag'
+ flag_name2 = 'string_flag'
+ flag_name3 = 'float_flag'
+ description = 'Description'
+ _defines.DEFINE_boolean(
+ flag_name1, None, description, flag_values=flag_values)
+ _defines.DEFINE_string(
+ flag_name2, None, description, flag_values=flag_values)
+ self.assertEqual(sorted([flag_name1, flag_name2]), dir(flag_values))
+
+ _defines.DEFINE_float(
+ flag_name3, None, description, flag_values=flag_values)
+ self.assertEqual(
+ sorted([flag_name1, flag_name2, flag_name3]), dir(flag_values))
+
+ def test_flags_into_string_deterministic(self):
+ flag_values = _flagvalues.FlagValues()
+ _defines.DEFINE_string(
+ 'fa', 'x', '', flag_values=flag_values, module_name='mb')
+ _defines.DEFINE_string(
+ 'fb', 'x', '', flag_values=flag_values, module_name='mb')
+ _defines.DEFINE_string(
+ 'fc', 'x', '', flag_values=flag_values, module_name='ma')
+ _defines.DEFINE_string(
+ 'fd', 'x', '', flag_values=flag_values, module_name='ma')
+
+ expected = ('--fc=x\n'
+ '--fd=x\n'
+ '--fa=x\n'
+ '--fb=x\n')
+
+ flags_by_module_items = sorted(
+ flag_values.flags_by_module_dict().items(), reverse=True)
+ for _, module_flags in flags_by_module_items:
+ module_flags.sort(reverse=True)
+
+ flag_values.__dict__['__flags_by_module'] = collections.OrderedDict(
+ flags_by_module_items)
+
+ actual = flag_values.flags_into_string()
+ self.assertEqual(expected, actual)
+
+ def test_validate_all_flags(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_string('name', None, '', flag_values=fv)
+ _validators.mark_flag_as_required('name', flag_values=fv)
+ with self.assertRaises(_exceptions.IllegalFlagValueError):
+ fv.validate_all_flags()
+ fv.name = 'test'
+ fv.validate_all_flags()
+
+
+class FlagValuesLoggingTest(absltest.TestCase):
+ """Test to make sure logging.* functions won't recurse.
+
+ Logging may and does happen before flags initialization. We need to make
+ sure that any warnings trown by flagvalues do not result in unlimited
+ recursion.
+ """
+
+ def test_logging_do_not_recurse(self):
+ logging.info('test info')
+ try:
+ raise ValueError('test exception')
+ except ValueError:
+ logging.exception('test message')
+
+
+class FlagSubstrMatchingTests(parameterized.TestCase):
+ """Tests related to flag substring matching."""
+
+ def _get_test_flag_values(self):
+ """Get a _flagvalues.FlagValues() instance, set up for tests."""
+ flag_values = _flagvalues.FlagValues()
+
+ _defines.DEFINE_string('strf', '', '', flag_values=flag_values)
+ _defines.DEFINE_boolean('boolf', 0, '', flag_values=flag_values)
+
+ return flag_values
+
+ # Test cases that should always make parsing raise an error.
+ # Tuples of strings with the argv to use.
+ FAIL_TEST_CASES = [
+ ('./program', '--boo', '0'),
+ ('./program', '--boo=true', '0'),
+ ('./program', '--boo=0'),
+ ('./program', '--noboo'),
+ ('./program', '--st=blah'),
+ ('./program', '--st=de'),
+ ('./program', '--st=blah', '--boo'),
+ ('./program', '--st=blah', 'unused'),
+ ('./program', '--st=--blah'),
+ ('./program', '--st', '--blah'),
+ ]
+
+ @parameterized.parameters(FAIL_TEST_CASES)
+ def test_raise(self, *argv):
+ """Test that raising works."""
+ fv = self._get_test_flag_values()
+ with self.assertRaises(_exceptions.UnrecognizedFlagError):
+ fv(argv)
+
+ @parameterized.parameters(
+ FAIL_TEST_CASES + [('./program', 'unused', '--st=blah')])
+ def test_gnu_getopt_raise(self, *argv):
+ """Test that raising works when combined with GNU-style getopt."""
+ fv = self._get_test_flag_values()
+ fv.set_gnu_getopt()
+ with self.assertRaises(_exceptions.UnrecognizedFlagError):
+ fv(argv)
+
+
+class SettingUnknownFlagTest(absltest.TestCase):
+
+ def setUp(self):
+ super(SettingUnknownFlagTest, self).setUp()
+ self.setter_called = 0
+
+ def set_undef(self, unused_name, unused_val):
+ self.setter_called += 1
+
+ def test_raise_on_undefined(self):
+ new_flags = _flagvalues.FlagValues()
+ with self.assertRaises(_exceptions.UnrecognizedFlagError):
+ new_flags.undefined_flag = 0
+
+ def test_not_raise(self):
+ new_flags = _flagvalues.FlagValues()
+ new_flags._register_unknown_flag_setter(self.set_undef)
+ new_flags.undefined_flag = 0
+ self.assertEqual(self.setter_called, 1)
+
+ def test_not_raise_on_undefined_if_undefok(self):
+ new_flags = _flagvalues.FlagValues()
+ args = ['0', '--foo', '--bar=1', '--undefok=foo,bar']
+ unparsed = new_flags(args, known_only=True)
+ self.assertEqual(['0'], unparsed)
+
+ def test_re_raise_undefined(self):
+ def setter(unused_name, unused_val):
+ raise NameError()
+ new_flags = _flagvalues.FlagValues()
+ new_flags._register_unknown_flag_setter(setter)
+ with self.assertRaises(_exceptions.UnrecognizedFlagError):
+ new_flags.undefined_flag = 0
+
+ def test_re_raise_invalid(self):
+ def setter(unused_name, unused_val):
+ raise ValueError()
+ new_flags = _flagvalues.FlagValues()
+ new_flags._register_unknown_flag_setter(setter)
+ with self.assertRaises(_exceptions.IllegalFlagValueError):
+ new_flags.undefined_flag = 0
+
+
+class SetAttributesTest(absltest.TestCase):
+
+ def setUp(self):
+ super(SetAttributesTest, self).setUp()
+ self.new_flags = _flagvalues.FlagValues()
+ _defines.DEFINE_boolean(
+ 'defined_flag', None, '', flag_values=self.new_flags)
+ _defines.DEFINE_boolean(
+ 'another_defined_flag', None, '', flag_values=self.new_flags)
+ self.setter_called = 0
+
+ def set_undef(self, unused_name, unused_val):
+ self.setter_called += 1
+
+ def test_two_defined_flags(self):
+ self.new_flags._set_attributes(
+ defined_flag=False, another_defined_flag=False)
+ self.assertEqual(self.setter_called, 0)
+
+ def test_one_defined_one_undefined_flag(self):
+ with self.assertRaises(_exceptions.UnrecognizedFlagError):
+ self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)
+
+ def test_register_unknown_flag_setter(self):
+ self.new_flags._register_unknown_flag_setter(self.set_undef)
+ self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)
+ self.assertEqual(self.setter_called, 1)
+
+
+class FlagsDashSyntaxTest(absltest.TestCase):
+
+ def setUp(self):
+ super(FlagsDashSyntaxTest, self).setUp()
+ self.fv = _flagvalues.FlagValues()
+ _defines.DEFINE_string(
+ 'long_name', 'default', 'help', flag_values=self.fv, short_name='s')
+
+ def test_long_name_one_dash(self):
+ self.fv(['./program', '-long_name=new'])
+ self.assertEqual('new', self.fv.long_name)
+
+ def test_long_name_two_dashes(self):
+ self.fv(['./program', '--long_name=new'])
+ self.assertEqual('new', self.fv.long_name)
+
+ def test_long_name_three_dashes(self):
+ with self.assertRaises(_exceptions.UnrecognizedFlagError):
+ self.fv(['./program', '---long_name=new'])
+
+ def test_short_name_one_dash(self):
+ self.fv(['./program', '-s=new'])
+ self.assertEqual('new', self.fv.s)
+
+ def test_short_name_two_dashes(self):
+ self.fv(['./program', '--s=new'])
+ self.assertEqual('new', self.fv.s)
+
+ def test_short_name_three_dashes(self):
+ with self.assertRaises(_exceptions.UnrecognizedFlagError):
+ self.fv(['./program', '---s=new'])
+
+
+class UnparseFlagsTest(absltest.TestCase):
+
+ def test_using_default_value_none(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_string('default_none', None, 'help', flag_values=fv)
+ self.assertTrue(fv['default_none'].using_default_value)
+ fv(['', '--default_none=notNone'])
+ self.assertFalse(fv['default_none'].using_default_value)
+ fv.unparse_flags()
+ self.assertTrue(fv['default_none'].using_default_value)
+ fv(['', '--default_none=alsoNotNone'])
+ self.assertFalse(fv['default_none'].using_default_value)
+ fv.unparse_flags()
+ self.assertTrue(fv['default_none'].using_default_value)
+
+ def test_using_default_value_not_none(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_string('default_foo', 'foo', 'help', flag_values=fv)
+
+ fv.mark_as_parsed()
+ self.assertTrue(fv['default_foo'].using_default_value)
+
+ fv(['', '--default_foo=foo'])
+ self.assertFalse(fv['default_foo'].using_default_value)
+
+ fv(['', '--default_foo=notFoo'])
+ self.assertFalse(fv['default_foo'].using_default_value)
+
+ fv.unparse_flags()
+ self.assertTrue(fv['default_foo'].using_default_value)
+
+ fv(['', '--default_foo=alsoNotFoo'])
+ self.assertFalse(fv['default_foo'].using_default_value)
+
+ def test_allow_overwrite_false(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_string(
+ 'default_none', None, 'help', allow_overwrite=False, flag_values=fv)
+ _defines.DEFINE_string(
+ 'default_foo', 'foo', 'help', allow_overwrite=False, flag_values=fv)
+
+ fv.mark_as_parsed()
+ self.assertEqual('foo', fv.default_foo)
+ self.assertIsNone(fv.default_none)
+
+ fv(['', '--default_foo=notFoo', '--default_none=notNone'])
+ self.assertEqual('notFoo', fv.default_foo)
+ self.assertEqual('notNone', fv.default_none)
+
+ fv.unparse_flags()
+ self.assertEqual('foo', fv['default_foo'].value)
+ self.assertIsNone(fv['default_none'].value)
+
+ fv(['', '--default_foo=alsoNotFoo', '--default_none=alsoNotNone'])
+ self.assertEqual('alsoNotFoo', fv.default_foo)
+ self.assertEqual('alsoNotNone', fv.default_none)
+
+ def test_multi_string_default_none(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_multi_string('foo', None, 'help', flag_values=fv)
+ fv.mark_as_parsed()
+ self.assertIsNone(fv.foo)
+ fv(['', '--foo=aa'])
+ self.assertEqual(['aa'], fv.foo)
+ fv.unparse_flags()
+ self.assertIsNone(fv['foo'].value)
+ fv(['', '--foo=bb', '--foo=cc'])
+ self.assertEqual(['bb', 'cc'], fv.foo)
+ fv.unparse_flags()
+ self.assertIsNone(fv['foo'].value)
+
+ def test_multi_string_default_string(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_multi_string('foo', 'xyz', 'help', flag_values=fv)
+ expected_default = ['xyz']
+ fv.mark_as_parsed()
+ self.assertEqual(expected_default, fv.foo)
+ fv(['', '--foo=aa'])
+ self.assertEqual(['aa'], fv.foo)
+ fv.unparse_flags()
+ self.assertEqual(expected_default, fv['foo'].value)
+ fv(['', '--foo=bb', '--foo=cc'])
+ self.assertEqual(['bb', 'cc'], fv['foo'].value)
+ fv.unparse_flags()
+ self.assertEqual(expected_default, fv['foo'].value)
+
+ def test_multi_string_default_list(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_multi_string(
+ 'foo', ['xx', 'yy', 'zz'], 'help', flag_values=fv)
+ expected_default = ['xx', 'yy', 'zz']
+ fv.mark_as_parsed()
+ self.assertEqual(expected_default, fv.foo)
+ fv(['', '--foo=aa'])
+ self.assertEqual(['aa'], fv.foo)
+ fv.unparse_flags()
+ self.assertEqual(expected_default, fv['foo'].value)
+ fv(['', '--foo=bb', '--foo=cc'])
+ self.assertEqual(['bb', 'cc'], fv.foo)
+ fv.unparse_flags()
+ self.assertEqual(expected_default, fv['foo'].value)
+
+
+class UnparsedFlagAccessTest(absltest.TestCase):
+
+ def test_unparsed_flag_access(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_string('name', 'default', 'help', flag_values=fv)
+ with self.assertRaises(_exceptions.UnparsedFlagAccessError):
+ _ = fv.name
+
+ def test_hasattr_raises_in_py3(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_string('name', 'default', 'help', flag_values=fv)
+ with self.assertRaises(_exceptions.UnparsedFlagAccessError):
+ _ = hasattr(fv, 'name')
+
+ def test_unparsed_flags_access_raises_after_unparse_flags(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_string('a_str', 'default_value', 'help', flag_values=fv)
+ fv.mark_as_parsed()
+ self.assertEqual(fv.a_str, 'default_value')
+ fv.unparse_flags()
+ with self.assertRaises(_exceptions.UnparsedFlagAccessError):
+ _ = fv.a_str
+
+
+class FlagHolderTest(absltest.TestCase):
+
+ def setUp(self):
+ super(FlagHolderTest, self).setUp()
+ self.fv = _flagvalues.FlagValues()
+ self.name_flag = _defines.DEFINE_string(
+ 'name', 'default', 'help', flag_values=self.fv)
+
+ def parse_flags(self, *argv):
+ self.fv.unparse_flags()
+ self.fv(['binary_name'] + list(argv))
+
+ def test_name(self):
+ self.assertEqual('name', self.name_flag.name)
+
+ def test_value_before_flag_parsing(self):
+ with self.assertRaises(_exceptions.UnparsedFlagAccessError):
+ _ = self.name_flag.value
+
+ def test_value_returns_default_value_if_not_explicitly_set(self):
+ self.parse_flags()
+ self.assertEqual('default', self.name_flag.value)
+
+ def test_value_returns_explicitly_set_value(self):
+ self.parse_flags('--name=new_value')
+ self.assertEqual('new_value', self.name_flag.value)
+
+ def test_present_returns_false_before_flag_parsing(self):
+ self.assertFalse(self.name_flag.present)
+
+ def test_present_returns_false_if_not_explicitly_set(self):
+ self.parse_flags()
+ self.assertFalse(self.name_flag.present)
+
+ def test_present_returns_true_if_explicitly_set(self):
+ self.parse_flags('--name=new_value')
+ self.assertTrue(self.name_flag.present)
+
+ def test_allow_override(self):
+ first = _defines.DEFINE_integer(
+ 'int_flag', 1, 'help', flag_values=self.fv, allow_override=1)
+ second = _defines.DEFINE_integer(
+ 'int_flag', 2, 'help', flag_values=self.fv, allow_override=1)
+ self.parse_flags('--int_flag=3')
+ self.assertEqual(3, first.value)
+ self.assertEqual(3, second.value)
+ self.assertTrue(first.present)
+ self.assertTrue(second.present)
+
+ def test_eq(self):
+ with self.assertRaises(TypeError):
+ self.name_flag == 'value' # pylint: disable=pointless-statement
+
+ def test_eq_reflection(self):
+ with self.assertRaises(TypeError):
+ 'value' == self.name_flag # pylint: disable=pointless-statement
+
+ def test_bool(self):
+ with self.assertRaises(TypeError):
+ bool(self.name_flag)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/_helpers_test.py b/absl/flags/tests/_helpers_test.py
new file mode 100644
index 0000000..4746a79
--- /dev/null
+++ b/absl/flags/tests/_helpers_test.py
@@ -0,0 +1,173 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unittests for helpers module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from absl.flags import _helpers
+from absl.flags.tests import module_bar
+from absl.flags.tests import module_foo
+from absl.testing import absltest
+
+
+class FlagSuggestionTest(absltest.TestCase):
+
+ def setUp(self):
+ self.longopts = [
+ 'fsplit-ivs-in-unroller=',
+ 'fsplit-wide-types=',
+ 'fstack-protector=',
+ 'fstack-protector-all=',
+ 'fstrict-aliasing=',
+ 'fstrict-overflow=',
+ 'fthread-jumps=',
+ 'ftracer',
+ 'ftree-bit-ccp',
+ 'ftree-builtin-call-dce',
+ 'ftree-ccp',
+ 'ftree-ch']
+
+ def test_damerau_levenshtein_id(self):
+ self.assertEqual(0, _helpers._damerau_levenshtein('asdf', 'asdf'))
+
+ def test_damerau_levenshtein_empty(self):
+ self.assertEqual(5, _helpers._damerau_levenshtein('', 'kites'))
+ self.assertEqual(6, _helpers._damerau_levenshtein('kitten', ''))
+
+ def test_damerau_levenshtein_commutative(self):
+ self.assertEqual(2, _helpers._damerau_levenshtein('kitten', 'kites'))
+ self.assertEqual(2, _helpers._damerau_levenshtein('kites', 'kitten'))
+
+ def test_damerau_levenshtein_transposition(self):
+ self.assertEqual(1, _helpers._damerau_levenshtein('kitten', 'ktiten'))
+
+ def test_mispelled_suggestions(self):
+ suggestions = _helpers.get_flag_suggestions('fstack_protector_all',
+ self.longopts)
+ self.assertEqual(['fstack-protector-all'], suggestions)
+
+ def test_ambiguous_prefix_suggestion(self):
+ suggestions = _helpers.get_flag_suggestions('fstack', self.longopts)
+ self.assertEqual(['fstack-protector', 'fstack-protector-all'], suggestions)
+
+ def test_misspelled_ambiguous_prefix_suggestion(self):
+ suggestions = _helpers.get_flag_suggestions('stack', self.longopts)
+ self.assertEqual(['fstack-protector', 'fstack-protector-all'], suggestions)
+
+ def test_crazy_suggestion(self):
+ suggestions = _helpers.get_flag_suggestions('asdfasdgasdfa', self.longopts)
+ self.assertEqual([], suggestions)
+
+ def test_suggestions_are_sorted(self):
+ sorted_flags = sorted(['aab', 'aac', 'aad'])
+ misspelt_flag = 'aaa'
+ suggestions = _helpers.get_flag_suggestions(misspelt_flag,
+ reversed(sorted_flags))
+ self.assertEqual(sorted_flags, suggestions)
+
+
+class GetCallingModuleTest(absltest.TestCase):
+ """Test whether we correctly determine the module which defines the flag."""
+
+ def test_get_calling_module(self):
+ self.assertEqual(_helpers.get_calling_module(), sys.argv[0])
+ self.assertEqual(module_foo.get_module_name(),
+ 'absl.flags.tests.module_foo')
+ self.assertEqual(module_bar.get_module_name(),
+ 'absl.flags.tests.module_bar')
+
+ # We execute the following exec statements for their side-effect
+ # (i.e., not raising an error). They emphasize the case that not
+ # all code resides in one of the imported modules: Python is a
+ # really dynamic language, where we can dynamically construct some
+ # code and execute it.
+ code = ('from absl.flags import _helpers\n'
+ 'module_name = _helpers.get_calling_module()')
+ exec(code) # pylint: disable=exec-used
+
+ # Next two exec statements executes code with a global environment
+ # that is different from the global environment of any imported
+ # module.
+ exec(code, {}) # pylint: disable=exec-used
+ # vars(self) returns a dictionary corresponding to the symbol
+ # table of the self object. dict(...) makes a distinct copy of
+ # this dictionary, such that any new symbol definition by the
+ # exec-ed code (e.g., import flags, module_name = ...) does not
+ # affect the symbol table of self.
+ exec(code, dict(vars(self))) # pylint: disable=exec-used
+
+ # Next test is actually more involved: it checks not only that
+ # get_calling_module does not crash inside exec code, it also checks
+ # that it returns the expected value: the code executed via exec
+ # code is treated as being executed by the current module. We
+ # check it twice: first time by executing exec from the main
+ # module, second time by executing it from module_bar.
+ global_dict = {}
+ exec(code, global_dict) # pylint: disable=exec-used
+ self.assertEqual(global_dict['module_name'],
+ sys.argv[0])
+
+ global_dict = {}
+ module_bar.execute_code(code, global_dict)
+ self.assertEqual(global_dict['module_name'],
+ 'absl.flags.tests.module_bar')
+
+ def test_get_calling_module_with_iteritems_error(self):
+ # This test checks that get_calling_module is using
+ # sys.modules.items(), instead of .iteritems().
+ orig_sys_modules = sys.modules
+
+ # Mock sys.modules: simulates error produced by importing a module
+ # in parallel with our iteration over sys.modules.iteritems().
+ class SysModulesMock(dict):
+
+ def __init__(self, original_content):
+ dict.__init__(self, original_content)
+
+ def iteritems(self):
+ # Any dictionary method is fine, but not .iteritems().
+ raise RuntimeError('dictionary changed size during iteration')
+
+ sys.modules = SysModulesMock(orig_sys_modules)
+ try:
+ # _get_calling_module should still work as expected:
+ self.assertEqual(_helpers.get_calling_module(), sys.argv[0])
+ self.assertEqual(module_foo.get_module_name(),
+ 'absl.flags.tests.module_foo')
+ finally:
+ sys.modules = orig_sys_modules
+
+
+class IsBytesOrString(absltest.TestCase):
+
+ def test_bytes(self):
+ self.assertTrue(_helpers.is_bytes_or_string(b'bytes'))
+
+ def test_str(self):
+ self.assertTrue(_helpers.is_bytes_or_string('str'))
+
+ def test_unicode(self):
+ self.assertTrue(_helpers.is_bytes_or_string(u'unicode'))
+
+ def test_list(self):
+ self.assertFalse(_helpers.is_bytes_or_string(['str']))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py
new file mode 100644
index 0000000..f724813
--- /dev/null
+++ b/absl/flags/tests/_validators_test.py
@@ -0,0 +1,744 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Testing that flags validators framework does work.
+
+This file tests that each flag validator called when it should be, and that
+failed validator will throw an exception, etc.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+
+from absl.flags import _defines
+from absl.flags import _exceptions
+from absl.flags import _flagvalues
+from absl.flags import _validators
+from absl.testing import absltest
+
+
+class SingleFlagValidatorTest(absltest.TestCase):
+ """Testing _validators.register_validator() method."""
+
+ def setUp(self):
+ super(SingleFlagValidatorTest, self).setUp()
+ self.flag_values = _flagvalues.FlagValues()
+ self.call_args = []
+
+ def test_success(self):
+ def checker(x):
+ self.call_args.append(x)
+ return True
+ _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'test_flag',
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.test_flag)
+ self.flag_values.test_flag = 2
+ self.assertEqual(2, self.flag_values.test_flag)
+ self.assertEqual([None, 2], self.call_args)
+
+ def test_default_value_not_used_success(self):
+ def checker(x):
+ self.call_args.append(x)
+ return True
+ _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'test_flag',
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program', '--test_flag=1')
+ self.flag_values(argv)
+ self.assertEqual(1, self.flag_values.test_flag)
+ self.assertEqual([1], self.call_args)
+
+ def test_validator_not_called_when_other_flag_is_changed(self):
+ def checker(x):
+ self.call_args.append(x)
+ return True
+ _defines.DEFINE_integer(
+ 'test_flag', 1, 'Usual integer flag', flag_values=self.flag_values)
+ _defines.DEFINE_integer(
+ 'other_flag', 2, 'Other integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'test_flag',
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertEqual(1, self.flag_values.test_flag)
+ self.flag_values.other_flag = 3
+ self.assertEqual([1], self.call_args)
+
+ def test_exception_raised_if_checker_fails(self):
+ def checker(x):
+ self.call_args.append(x)
+ return x == 1
+ _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'test_flag',
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program', '--test_flag=1')
+ self.flag_values(argv)
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values.test_flag = 2
+ self.assertEqual('flag --test_flag=2: Errors happen', str(cm.exception))
+ self.assertEqual([1, 2], self.call_args)
+
+ def test_exception_raised_if_checker_raises_exception(self):
+ def checker(x):
+ self.call_args.append(x)
+ if x == 1:
+ return True
+ raise _exceptions.ValidationError('Specific message')
+
+ _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'test_flag',
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program', '--test_flag=1')
+ self.flag_values(argv)
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values.test_flag = 2
+ self.assertEqual('flag --test_flag=2: Specific message', str(cm.exception))
+ self.assertEqual([1, 2], self.call_args)
+
+ def test_error_message_when_checker_returns_false_on_start(self):
+ def checker(x):
+ self.call_args.append(x)
+ return False
+ _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'test_flag',
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program', '--test_flag=1')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual('flag --test_flag=1: Errors happen', str(cm.exception))
+ self.assertEqual([1], self.call_args)
+
+ def test_error_message_when_checker_raises_exception_on_start(self):
+ def checker(x):
+ self.call_args.append(x)
+ raise _exceptions.ValidationError('Specific message')
+
+ _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'test_flag',
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program', '--test_flag=1')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual('flag --test_flag=1: Specific message', str(cm.exception))
+ self.assertEqual([1], self.call_args)
+
+ def test_validators_checked_in_order(self):
+
+ def required(x):
+ self.calls.append('required')
+ return x is not None
+
+ def even(x):
+ self.calls.append('even')
+ return x % 2 == 0
+
+ self.calls = []
+ self._define_flag_and_validators(required, even)
+ self.assertEqual(['required', 'even'], self.calls)
+
+ self.calls = []
+ self._define_flag_and_validators(even, required)
+ self.assertEqual(['even', 'required'], self.calls)
+
+ def _define_flag_and_validators(self, first_validator, second_validator):
+ local_flags = _flagvalues.FlagValues()
+ _defines.DEFINE_integer(
+ 'test_flag', 2, 'test flag', flag_values=local_flags)
+ _validators.register_validator(
+ 'test_flag', first_validator, message='', flag_values=local_flags)
+ _validators.register_validator(
+ 'test_flag', second_validator, message='', flag_values=local_flags)
+ argv = ('./program',)
+ local_flags(argv)
+
+ def test_validator_as_decorator(self):
+ _defines.DEFINE_integer(
+ 'test_flag', None, 'Simple integer flag', flag_values=self.flag_values)
+
+ @_validators.validator('test_flag', flag_values=self.flag_values)
+ def checker(x):
+ self.call_args.append(x)
+ return True
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.test_flag)
+ self.flag_values.test_flag = 2
+ self.assertEqual(2, self.flag_values.test_flag)
+ self.assertEqual([None, 2], self.call_args)
+ # Check that 'Checker' is still a function and has not been replaced.
+ self.assertTrue(checker(3))
+ self.assertEqual([None, 2, 3], self.call_args)
+
+
+class MultiFlagsValidatorTest(absltest.TestCase):
+ """Test flags multi-flag validators."""
+
+ def setUp(self):
+ super(MultiFlagsValidatorTest, self).setUp()
+ self.flag_values = _flagvalues.FlagValues()
+ self.call_args = []
+ _defines.DEFINE_integer(
+ 'foo', 1, 'Usual integer flag', flag_values=self.flag_values)
+ _defines.DEFINE_integer(
+ 'bar', 2, 'Usual integer flag', flag_values=self.flag_values)
+
+ def test_success(self):
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ return True
+ _validators.register_multi_flags_validator(
+ ['foo', 'bar'], checker, flag_values=self.flag_values)
+
+ argv = ('./program', '--bar=2')
+ self.flag_values(argv)
+ self.assertEqual(1, self.flag_values.foo)
+ self.assertEqual(2, self.flag_values.bar)
+ self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
+ self.flag_values.foo = 3
+ self.assertEqual(3, self.flag_values.foo)
+ self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}],
+ self.call_args)
+
+ def test_validator_not_called_when_other_flag_is_changed(self):
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ return True
+ _defines.DEFINE_integer(
+ 'other_flag', 3, 'Other integer flag', flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(
+ ['foo', 'bar'], checker, flag_values=self.flag_values)
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.flag_values.other_flag = 3
+ self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
+
+ def test_exception_raised_if_checker_fails(self):
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ values = flags_dict.values()
+ # Make sure all the flags have different values.
+ return len(set(values)) == len(values)
+ _validators.register_multi_flags_validator(
+ ['foo', 'bar'],
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values.bar = 1
+ self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception))
+ self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
+ self.call_args)
+
+ def test_exception_raised_if_checker_raises_exception(self):
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ values = flags_dict.values()
+ # Make sure all the flags have different values.
+ if len(set(values)) != len(values):
+ raise _exceptions.ValidationError('Specific message')
+ return True
+
+ _validators.register_multi_flags_validator(
+ ['foo', 'bar'],
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values.bar = 1
+ self.assertEqual('flags foo=1, bar=1: Specific message', str(cm.exception))
+ self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
+ self.call_args)
+
+ def test_decorator(self):
+ @_validators.multi_flags_validator(
+ ['foo', 'bar'], message='Errors happen', flag_values=self.flag_values)
+ def checker(flags_dict): # pylint: disable=unused-variable
+ self.call_args.append(flags_dict)
+ values = flags_dict.values()
+ # Make sure all the flags have different values.
+ return len(set(values)) == len(values)
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values.bar = 1
+ self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception))
+ self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
+ self.call_args)
+
+
+class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
+
+ def setUp(self):
+ super(MarkFlagsAsMutualExclusiveTest, self).setUp()
+ self.flag_values = _flagvalues.FlagValues()
+
+ _defines.DEFINE_string(
+ 'flag_one', None, 'flag one', flag_values=self.flag_values)
+ _defines.DEFINE_string(
+ 'flag_two', None, 'flag two', flag_values=self.flag_values)
+ _defines.DEFINE_string(
+ 'flag_three', None, 'flag three', flag_values=self.flag_values)
+ _defines.DEFINE_integer(
+ 'int_flag_one', None, 'int flag one', flag_values=self.flag_values)
+ _defines.DEFINE_integer(
+ 'int_flag_two', None, 'int flag two', flag_values=self.flag_values)
+ _defines.DEFINE_multi_string(
+ 'multi_flag_one', None, 'multi flag one', flag_values=self.flag_values)
+ _defines.DEFINE_multi_string(
+ 'multi_flag_two', None, 'multi flag two', flag_values=self.flag_values)
+ _defines.DEFINE_boolean(
+ 'flag_not_none', False, 'false default', flag_values=self.flag_values)
+
+ def _mark_flags_as_mutually_exclusive(self, flag_names, required):
+ _validators.mark_flags_as_mutual_exclusive(
+ flag_names, required=required, flag_values=self.flag_values)
+
+ def test_no_flags_present(self):
+ self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], False)
+ argv = ('./program',)
+
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.flag_one)
+ self.assertIsNone(self.flag_values.flag_two)
+
+ def test_no_flags_present_required(self):
+ self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
+ argv = ('./program',)
+ expected = (
+ 'flags flag_one=None, flag_two=None: '
+ 'Exactly one of (flag_one, flag_two) must have a value other than '
+ 'None.')
+
+ self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
+ expected, self.flag_values, argv)
+
+ def test_one_flag_present(self):
+ self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], False)
+ self.flag_values(('./program', '--flag_one=1'))
+ self.assertEqual('1', self.flag_values.flag_one)
+
+ def test_one_flag_present_required(self):
+ self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
+ self.flag_values(('./program', '--flag_two=2'))
+ self.assertEqual('2', self.flag_values.flag_two)
+
+ def test_one_flag_zero_required(self):
+ self._mark_flags_as_mutually_exclusive(
+ ['int_flag_one', 'int_flag_two'], True)
+ self.flag_values(('./program', '--int_flag_one=0'))
+ self.assertEqual(0, self.flag_values.int_flag_one)
+
+ def test_mutual_exclusion_with_extra_flags(self):
+ self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
+ argv = ('./program', '--flag_two=2', '--flag_three=3')
+
+ self.flag_values(argv)
+ self.assertEqual('2', self.flag_values.flag_two)
+ self.assertEqual('3', self.flag_values.flag_three)
+
+ def test_mutual_exclusion_with_zero(self):
+ self._mark_flags_as_mutually_exclusive(
+ ['int_flag_one', 'int_flag_two'], False)
+ argv = ('./program', '--int_flag_one=0', '--int_flag_two=0')
+ expected = (
+ 'flags int_flag_one=0, int_flag_two=0: '
+ 'At most one of (int_flag_one, int_flag_two) must have a value other '
+ 'than None.')
+
+ self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
+ expected, self.flag_values, argv)
+
+ def test_multiple_flags_present(self):
+ self._mark_flags_as_mutually_exclusive(
+ ['flag_one', 'flag_two', 'flag_three'], False)
+ argv = ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3')
+ expected = (
+ 'flags flag_one=1, flag_two=2, flag_three=3: '
+ 'At most one of (flag_one, flag_two, flag_three) must have a value '
+ 'other than None.')
+
+ self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
+ expected, self.flag_values, argv)
+
+ def test_multiple_flags_present_required(self):
+ self._mark_flags_as_mutually_exclusive(
+ ['flag_one', 'flag_two', 'flag_three'], True)
+ argv = ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3')
+ expected = (
+ 'flags flag_one=1, flag_two=2, flag_three=3: '
+ 'Exactly one of (flag_one, flag_two, flag_three) must have a value '
+ 'other than None.')
+
+ self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
+ expected, self.flag_values, argv)
+
+ def test_no_multiflags_present(self):
+ self._mark_flags_as_mutually_exclusive(
+ ['multi_flag_one', 'multi_flag_two'], False)
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.multi_flag_one)
+ self.assertIsNone(self.flag_values.multi_flag_two)
+
+ def test_no_multistring_flags_present_required(self):
+ self._mark_flags_as_mutually_exclusive(
+ ['multi_flag_one', 'multi_flag_two'], True)
+ argv = ('./program',)
+ expected = (
+ 'flags multi_flag_one=None, multi_flag_two=None: '
+ 'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
+ 'other than None.')
+
+ self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
+ expected, self.flag_values, argv)
+
+ def test_one_multiflag_present(self):
+ self._mark_flags_as_mutually_exclusive(
+ ['multi_flag_one', 'multi_flag_two'], True)
+ self.flag_values(('./program', '--multi_flag_one=1'))
+ self.assertEqual(['1'], self.flag_values.multi_flag_one)
+
+ def test_one_multiflag_present_repeated(self):
+ self._mark_flags_as_mutually_exclusive(
+ ['multi_flag_one', 'multi_flag_two'], True)
+ self.flag_values(('./program', '--multi_flag_one=1', '--multi_flag_one=1b'))
+ self.assertEqual(['1', '1b'], self.flag_values.multi_flag_one)
+
+ def test_multiple_multiflags_present(self):
+ self._mark_flags_as_mutually_exclusive(
+ ['multi_flag_one', 'multi_flag_two'], False)
+ argv = ('./program', '--multi_flag_one=1', '--multi_flag_two=2')
+ expected = (
+ "flags multi_flag_one=['1'], multi_flag_two=['2']: "
+ 'At most one of (multi_flag_one, multi_flag_two) must have a value '
+ 'other than None.')
+
+ self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
+ expected, self.flag_values, argv)
+
+ def test_multiple_multiflags_present_required(self):
+ self._mark_flags_as_mutually_exclusive(
+ ['multi_flag_one', 'multi_flag_two'], True)
+ argv = ('./program', '--multi_flag_one=1', '--multi_flag_two=2')
+ expected = (
+ "flags multi_flag_one=['1'], multi_flag_two=['2']: "
+ 'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
+ 'other than None.')
+
+ self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
+ expected, self.flag_values, argv)
+
+ def test_flag_default_not_none_warning(self):
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ warnings.simplefilter('always')
+ self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_not_none'],
+ False)
+ self.assertLen(caught_warnings, 1)
+ self.assertIn('--flag_not_none has a non-None default value',
+ str(caught_warnings[0].message))
+
+
+class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
+
+ def setUp(self):
+ super(MarkBoolFlagsAsMutualExclusiveTest, self).setUp()
+ self.flag_values = _flagvalues.FlagValues()
+
+ _defines.DEFINE_boolean(
+ 'false_1', False, 'default false 1', flag_values=self.flag_values)
+ _defines.DEFINE_boolean(
+ 'false_2', False, 'default false 2', flag_values=self.flag_values)
+ _defines.DEFINE_boolean(
+ 'true_1', True, 'default true 1', flag_values=self.flag_values)
+ _defines.DEFINE_integer(
+ 'non_bool', None, 'non bool', flag_values=self.flag_values)
+
+ def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required):
+ _validators.mark_bool_flags_as_mutual_exclusive(
+ flag_names, required=required, flag_values=self.flag_values)
+
+ def test_no_flags_present(self):
+ self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False)
+ self.flag_values(('./program',))
+ self.assertEqual(False, self.flag_values.false_1)
+ self.assertEqual(False, self.flag_values.false_2)
+
+ def test_no_flags_present_required(self):
+ self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True)
+ argv = ('./program',)
+ expected = (
+ 'flags false_1=False, false_2=False: '
+ 'Exactly one of (false_1, false_2) must be True.')
+
+ self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
+ expected, self.flag_values, argv)
+
+ def test_no_flags_present_with_default_true_required(self):
+ self._mark_bool_flags_as_mutually_exclusive(['false_1', 'true_1'], True)
+ self.flag_values(('./program',))
+ self.assertEqual(False, self.flag_values.false_1)
+ self.assertEqual(True, self.flag_values.true_1)
+
+ def test_two_flags_true(self):
+ self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False)
+ argv = ('./program', '--false_1', '--false_2')
+ expected = (
+ 'flags false_1=True, false_2=True: At most one of (false_1, '
+ 'false_2) must be True.')
+
+ self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
+ expected, self.flag_values, argv)
+
+ def test_non_bool_flag(self):
+ expected = ('Flag --non_bool is not Boolean, which is required for flags '
+ 'used in mark_bool_flags_as_mutual_exclusive.')
+ with self.assertRaisesWithLiteralMatch(_exceptions.ValidationError,
+ expected):
+ self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'],
+ False)
+
+
+class MarkFlagAsRequiredTest(absltest.TestCase):
+
+ def setUp(self):
+ super(MarkFlagAsRequiredTest, self).setUp()
+ self.flag_values = _flagvalues.FlagValues()
+
+ def test_success(self):
+ _defines.DEFINE_string(
+ 'string_flag', None, 'string flag', flag_values=self.flag_values)
+ _validators.mark_flag_as_required(
+ 'string_flag', flag_values=self.flag_values)
+ argv = ('./program', '--string_flag=value')
+ self.flag_values(argv)
+ self.assertEqual('value', self.flag_values.string_flag)
+
+ def test_catch_none_as_default(self):
+ _defines.DEFINE_string(
+ 'string_flag', None, 'string flag', flag_values=self.flag_values)
+ _validators.mark_flag_as_required(
+ 'string_flag', flag_values=self.flag_values)
+ argv = ('./program',)
+ expected = (
+ r'flag --string_flag=None: Flag --string_flag must have a value other '
+ r'than None\.')
+ with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, expected):
+ self.flag_values(argv)
+
+ def test_catch_setting_none_after_program_start(self):
+ _defines.DEFINE_string(
+ 'string_flag', 'value', 'string flag', flag_values=self.flag_values)
+ _validators.mark_flag_as_required(
+ 'string_flag', flag_values=self.flag_values)
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertEqual('value', self.flag_values.string_flag)
+ expected = ('flag --string_flag=None: Flag --string_flag must have a value '
+ 'other than None.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values.string_flag = None
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_flag_default_not_none_warning(self):
+ _defines.DEFINE_string(
+ 'flag_not_none', '', 'empty default', flag_values=self.flag_values)
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ warnings.simplefilter('always')
+ _validators.mark_flag_as_required(
+ 'flag_not_none', flag_values=self.flag_values)
+
+ self.assertLen(caught_warnings, 1)
+ self.assertIn('--flag_not_none has a non-None default value',
+ str(caught_warnings[0].message))
+
+
+class MarkFlagsAsRequiredTest(absltest.TestCase):
+
+ def setUp(self):
+ super(MarkFlagsAsRequiredTest, self).setUp()
+ self.flag_values = _flagvalues.FlagValues()
+
+ def test_success(self):
+ _defines.DEFINE_string(
+ 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
+ _defines.DEFINE_string(
+ 'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
+ flag_names = ['string_flag_1', 'string_flag_2']
+ _validators.mark_flags_as_required(flag_names, flag_values=self.flag_values)
+ argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
+ self.flag_values(argv)
+ self.assertEqual('value_1', self.flag_values.string_flag_1)
+ self.assertEqual('value_2', self.flag_values.string_flag_2)
+
+ def test_catch_none_as_default(self):
+ _defines.DEFINE_string(
+ 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
+ _defines.DEFINE_string(
+ 'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
+ _validators.mark_flags_as_required(
+ ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
+ argv = ('./program', '--string_flag_1=value_1')
+ expected = (
+ r'flag --string_flag_2=None: Flag --string_flag_2 must have a value '
+ r'other than None\.')
+ with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, expected):
+ self.flag_values(argv)
+
+ def test_catch_setting_none_after_program_start(self):
+ _defines.DEFINE_string(
+ 'string_flag_1',
+ 'value_1',
+ 'string flag 1',
+ flag_values=self.flag_values)
+ _defines.DEFINE_string(
+ 'string_flag_2',
+ 'value_2',
+ 'string flag 2',
+ flag_values=self.flag_values)
+ _validators.mark_flags_as_required(
+ ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
+ argv = ('./program', '--string_flag_1=value_1')
+ self.flag_values(argv)
+ self.assertEqual('value_1', self.flag_values.string_flag_1)
+ expected = (
+ 'flag --string_flag_1=None: Flag --string_flag_1 must have a value '
+ 'other than None.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values.string_flag_1 = None
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_catch_multiple_flags_as_none_at_program_start(self):
+ _defines.DEFINE_float(
+ 'float_flag_1',
+ None,
+ 'string flag 1',
+ flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'float_flag_2',
+ None,
+ 'string flag 2',
+ flag_values=self.flag_values)
+ _validators.mark_flags_as_required(
+ ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = (
+ 'flag --float_flag_1=None: Flag --float_flag_1 must have a value '
+ 'other than None.\n'
+ 'flag --float_flag_2=None: Flag --float_flag_2 must have a value '
+ 'other than None.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_fail_fast_single_flag_and_skip_remaining_validators(self):
+ def raise_unexpected_error(x):
+ del x
+ raise _exceptions.ValidationError('Should not be raised.')
+ _defines.DEFINE_float(
+ 'flag_1', None, 'flag 1', flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'flag_2', 4.2, 'flag 2', flag_values=self.flag_values)
+ _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_1', raise_unexpected_error, flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
+ raise_unexpected_error,
+ flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = (
+ 'flag --flag_1=None: Flag --flag_1 must have a value other than None.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
+ def test_fail_fast_multi_flag_and_skip_remaining_validators(self):
+ def raise_expected_error(x):
+ del x
+ raise _exceptions.ValidationError('Expected error.')
+ def raise_unexpected_error(x):
+ del x
+ raise _exceptions.ValidationError('Got unexpected error.')
+ _defines.DEFINE_float(
+ 'flag_1', 5.1, 'flag 1', flag_values=self.flag_values)
+ _defines.DEFINE_float(
+ 'flag_2', 10.0, 'flag 2', flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_1', 'flag_2'],
+ raise_expected_error,
+ flag_values=self.flag_values)
+ _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
+ raise_unexpected_error,
+ flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_1', raise_unexpected_error, flag_values=self.flag_values)
+ _validators.register_validator(
+ 'flag_2', raise_unexpected_error, flag_values=self.flag_values)
+ argv = ('./program', '')
+ expected = ('flags flag_1=5.1, flag_2=10.0: Expected error.')
+ with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
+ self.flag_values(argv)
+ self.assertEqual(expected, str(cm.exception))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/argparse_flags_test.py b/absl/flags/tests/argparse_flags_test.py
new file mode 100644
index 0000000..5e6f49a
--- /dev/null
+++ b/absl/flags/tests/argparse_flags_test.py
@@ -0,0 +1,447 @@
+# Copyright 2018 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for absl.flags.argparse_flags."""
+
+import io
+import os
+import subprocess
+import sys
+import tempfile
+from unittest import mock
+
+from absl import flags
+from absl import logging
+from absl.flags import argparse_flags
+from absl.testing import _bazelize_command
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class ArgparseFlagsTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._absl_flags = flags.FlagValues()
+ flags.DEFINE_bool(
+ 'absl_bool', None, 'help for --absl_bool.',
+ short_name='b', flag_values=self._absl_flags)
+ # Add a boolean flag that starts with "no", to verify it can correctly
+ # handle the "no" prefixes in boolean flags.
+ flags.DEFINE_bool(
+ 'notice', None, 'help for --notice.',
+ flag_values=self._absl_flags)
+ flags.DEFINE_string(
+ 'absl_string', 'default', 'help for --absl_string=%.',
+ short_name='s', flag_values=self._absl_flags)
+ flags.DEFINE_integer(
+ 'absl_integer', 1, 'help for --absl_integer.',
+ flag_values=self._absl_flags)
+ flags.DEFINE_float(
+ 'absl_float', 1, 'help for --absl_integer.',
+ flag_values=self._absl_flags)
+ flags.DEFINE_enum(
+ 'absl_enum', 'apple', ['apple', 'orange'], 'help for --absl_enum.',
+ flag_values=self._absl_flags)
+
+ def test_dash_as_prefix_char_only(self):
+ with self.assertRaises(ValueError):
+ argparse_flags.ArgumentParser(prefix_chars='/')
+
+ def test_default_inherited_absl_flags_value(self):
+ parser = argparse_flags.ArgumentParser()
+ self.assertIs(parser._inherited_absl_flags, flags.FLAGS)
+
+ def test_parse_absl_flags(self):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ self.assertFalse(self._absl_flags.is_parsed())
+ self.assertTrue(self._absl_flags['absl_string'].using_default_value)
+ self.assertTrue(self._absl_flags['absl_integer'].using_default_value)
+ self.assertTrue(self._absl_flags['absl_float'].using_default_value)
+ self.assertTrue(self._absl_flags['absl_enum'].using_default_value)
+
+ parser.parse_args(
+ ['--absl_string=new_string', '--absl_integer', '2'])
+ self.assertEqual(self._absl_flags.absl_string, 'new_string')
+ self.assertEqual(self._absl_flags.absl_integer, 2)
+ self.assertTrue(self._absl_flags.is_parsed())
+ self.assertFalse(self._absl_flags['absl_string'].using_default_value)
+ self.assertFalse(self._absl_flags['absl_integer'].using_default_value)
+ self.assertTrue(self._absl_flags['absl_float'].using_default_value)
+ self.assertTrue(self._absl_flags['absl_enum'].using_default_value)
+
+ @parameterized.named_parameters(
+ ('true', ['--absl_bool'], True),
+ ('false', ['--noabsl_bool'], False),
+ ('does_not_accept_equal_value', ['--absl_bool=true'], SystemExit),
+ ('does_not_accept_space_value', ['--absl_bool', 'true'], SystemExit),
+ ('long_name_single_dash', ['-absl_bool'], SystemExit),
+ ('short_name', ['-b'], True),
+ ('short_name_false', ['-nob'], SystemExit),
+ ('short_name_double_dash', ['--b'], SystemExit),
+ ('short_name_double_dash_false', ['--nob'], SystemExit),
+ )
+ def test_parse_boolean_flags(self, args, expected):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ self.assertIsNone(self._absl_flags['absl_bool'].value)
+ self.assertIsNone(self._absl_flags['b'].value)
+ if isinstance(expected, bool):
+ parser.parse_args(args)
+ self.assertEqual(expected, self._absl_flags.absl_bool)
+ self.assertEqual(expected, self._absl_flags.b)
+ else:
+ with self.assertRaises(expected):
+ parser.parse_args(args)
+
+ @parameterized.named_parameters(
+ ('true', ['--notice'], True),
+ ('false', ['--nonotice'], False),
+ )
+ def test_parse_boolean_existing_no_prefix(self, args, expected):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ self.assertIsNone(self._absl_flags['notice'].value)
+ parser.parse_args(args)
+ self.assertEqual(expected, self._absl_flags.notice)
+
+ def test_unrecognized_flag(self):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ with self.assertRaises(SystemExit):
+ parser.parse_args(['--unknown_flag=what'])
+
+ def test_absl_validators(self):
+
+ @flags.validator('absl_integer', flag_values=self._absl_flags)
+ def ensure_positive(value):
+ return value > 0
+
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ with self.assertRaises(SystemExit):
+ parser.parse_args(['--absl_integer', '-2'])
+
+ del ensure_positive
+
+ @parameterized.named_parameters(
+ ('regular_name_double_dash', '--absl_string=new_string', 'new_string'),
+ ('regular_name_single_dash', '-absl_string=new_string', SystemExit),
+ ('short_name_double_dash', '--s=new_string', SystemExit),
+ ('short_name_single_dash', '-s=new_string', 'new_string'),
+ )
+ def test_dashes(self, argument, expected):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ if isinstance(expected, str):
+ parser.parse_args([argument])
+ self.assertEqual(self._absl_flags.absl_string, expected)
+ else:
+ with self.assertRaises(expected):
+ parser.parse_args([argument])
+
+ def test_absl_flags_not_added_to_namespace(self):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ args = parser.parse_args(['--absl_string=new_string'])
+ self.assertIsNone(getattr(args, 'absl_string', None))
+
+ def test_mixed_flags_and_positional(self):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ parser.add_argument('--header', help='Header message to print.')
+ parser.add_argument('integers', metavar='N', type=int, nargs='+',
+ help='an integer for the accumulator')
+
+ args = parser.parse_args(
+ ['--absl_string=new_string', '--header=HEADER', '--absl_integer',
+ '2', '3', '4'])
+ self.assertEqual(self._absl_flags.absl_string, 'new_string')
+ self.assertEqual(self._absl_flags.absl_integer, 2)
+ self.assertEqual(args.header, 'HEADER')
+ self.assertListEqual(args.integers, [3, 4])
+
+ def test_subparsers(self):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ parser.add_argument('--header', help='Header message to print.')
+ subparsers = parser.add_subparsers(help='The command to execute.')
+
+ sub_parser = subparsers.add_parser(
+ 'sub_cmd', help='Sub command.', inherited_absl_flags=self._absl_flags)
+ sub_parser.add_argument('--sub_flag', help='Sub command flag.')
+
+ def sub_command_func():
+ pass
+
+ sub_parser.set_defaults(command=sub_command_func)
+
+ args = parser.parse_args([
+ '--header=HEADER', '--absl_string=new_value', 'sub_cmd',
+ '--absl_integer=2', '--sub_flag=new_sub_flag_value'])
+
+ self.assertEqual(args.header, 'HEADER')
+ self.assertEqual(self._absl_flags.absl_string, 'new_value')
+ self.assertEqual(args.command, sub_command_func)
+ self.assertEqual(self._absl_flags.absl_integer, 2)
+ self.assertEqual(args.sub_flag, 'new_sub_flag_value')
+
+ def test_subparsers_no_inherit_in_subparser(self):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ subparsers = parser.add_subparsers(help='The command to execute.')
+
+ subparsers.add_parser(
+ 'sub_cmd', help='Sub command.',
+ # Do not inherit absl flags in the subparser.
+ # This is the behavior that this test exercises.
+ inherited_absl_flags=None)
+
+ with self.assertRaises(SystemExit):
+ parser.parse_args(['sub_cmd', '--absl_string=new_value'])
+
+ def test_help_main_module_flags(self):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ help_message = parser.format_help()
+
+ # Only the short name is shown in the usage string.
+ self.assertIn('[-s ABSL_STRING]', help_message)
+ # Both names are included in the options section.
+ self.assertIn('-s ABSL_STRING, --absl_string ABSL_STRING', help_message)
+ # Verify help messages.
+ self.assertIn('help for --absl_string=%.', help_message)
+ self.assertIn('<apple|orange>: help for --absl_enum.', help_message)
+
+ def test_help_non_main_module_flags(self):
+ flags.DEFINE_string(
+ 'non_main_module_flag', 'default', 'help',
+ module_name='other.module', flag_values=self._absl_flags)
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ help_message = parser.format_help()
+
+ # Non main module key flags are not printed in the help message.
+ self.assertNotIn('non_main_module_flag', help_message)
+
+ def test_help_non_main_module_key_flags(self):
+ flags.DEFINE_string(
+ 'non_main_module_flag', 'default', 'help',
+ module_name='other.module', flag_values=self._absl_flags)
+ flags.declare_key_flag('non_main_module_flag', flag_values=self._absl_flags)
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ help_message = parser.format_help()
+
+ # Main module key fags are printed in the help message, even if the flag
+ # is defined in another module.
+ self.assertIn('non_main_module_flag', help_message)
+
+ @parameterized.named_parameters(
+ ('h', ['-h']),
+ ('help', ['--help']),
+ ('helpshort', ['--helpshort']),
+ ('helpfull', ['--helpfull']),
+ )
+ def test_help_flags(self, args):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ with self.assertRaises(SystemExit):
+ parser.parse_args(args)
+
+ @parameterized.named_parameters(
+ ('h', ['-h']),
+ ('help', ['--help']),
+ ('helpshort', ['--helpshort']),
+ ('helpfull', ['--helpfull']),
+ )
+ def test_no_help_flags(self, args):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags, add_help=False)
+ with mock.patch.object(parser, 'print_help'):
+ with self.assertRaises(SystemExit):
+ parser.parse_args(args)
+ parser.print_help.assert_not_called()
+
+ def test_helpfull_message(self):
+ flags.DEFINE_string(
+ 'non_main_module_flag', 'default', 'help',
+ module_name='other.module', flag_values=self._absl_flags)
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ with self.assertRaises(SystemExit),\
+ mock.patch.object(sys, 'stdout', new=io.StringIO()) as mock_stdout:
+ parser.parse_args(['--helpfull'])
+ stdout_message = mock_stdout.getvalue()
+ logging.info('captured stdout message:\n%s', stdout_message)
+ self.assertIn('--non_main_module_flag', stdout_message)
+ self.assertIn('other.module', stdout_message)
+ # Make sure the main module is not included.
+ self.assertNotIn(sys.argv[0], stdout_message)
+ # Special flags defined in absl.flags.
+ self.assertIn('absl.flags:', stdout_message)
+ self.assertIn('--flagfile', stdout_message)
+ self.assertIn('--undefok', stdout_message)
+
+ @parameterized.named_parameters(
+ ('at_end',
+ ('1', '--absl_string=value_from_cmd', '--flagfile='),
+ 'value_from_file'),
+ ('at_beginning',
+ ('--flagfile=', '1', '--absl_string=value_from_cmd'),
+ 'value_from_cmd'),
+ )
+ def test_flagfile(self, cmd_args, expected_absl_string_value):
+ # Set gnu_getopt to False, to verify it's ignored by argparse_flags.
+ self._absl_flags.set_gnu_getopt(False)
+
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ parser.add_argument('--header', help='Header message to print.')
+ parser.add_argument('integers', metavar='N', type=int, nargs='+',
+ help='an integer for the accumulator')
+ flagfile = tempfile.NamedTemporaryFile(
+ dir=absltest.TEST_TMPDIR.value, delete=False)
+ self.addCleanup(os.unlink, flagfile.name)
+ with flagfile:
+ flagfile.write(b'''
+# The flag file.
+--absl_string=value_from_file
+--absl_integer=1
+--header=header_from_file
+''')
+
+ expand_flagfile = lambda x: x + flagfile.name if x == '--flagfile=' else x
+ cmd_args = [expand_flagfile(x) for x in cmd_args]
+ args = parser.parse_args(cmd_args)
+
+ self.assertEqual([1], args.integers)
+ self.assertEqual('header_from_file', args.header)
+ self.assertEqual(expected_absl_string_value, self._absl_flags.absl_string)
+
+ @parameterized.parameters(
+ ('positional', {'positional'}, False),
+ ('--not_existed', {'existed'}, False),
+ ('--empty', set(), False),
+ ('-single_dash', {'single_dash'}, True),
+ ('--double_dash', {'double_dash'}, True),
+ ('--with_value=value', {'with_value'}, True),
+ )
+ def test_is_undefok(self, arg, undefok_names, is_undefok):
+ self.assertEqual(is_undefok, argparse_flags._is_undefok(arg, undefok_names))
+
+ @parameterized.named_parameters(
+ ('single', 'single', ['--single'], []),
+ ('multiple', 'first,second', ['--first', '--second'], []),
+ ('single_dash', 'dash', ['-dash'], []),
+ ('mixed_dash', 'mixed', ['-mixed', '--mixed'], []),
+ ('value', 'name', ['--name=value'], []),
+ ('boolean_positive', 'bool', ['--bool'], []),
+ ('boolean_negative', 'bool', ['--nobool'], []),
+ ('left_over', 'strip', ['--first', '--strip', '--last'],
+ ['--first', '--last']),
+ )
+ def test_strip_undefok_args(self, undefok, args, expected_args):
+ actual_args = argparse_flags._strip_undefok_args(undefok, args)
+ self.assertListEqual(expected_args, actual_args)
+
+ @parameterized.named_parameters(
+ ('at_end', ['--unknown', '--undefok=unknown']),
+ ('at_beginning', ['--undefok=unknown', '--unknown']),
+ ('multiple', ['--unknown', '--undefok=unknown,another_unknown']),
+ ('with_value', ['--unknown=value', '--undefok=unknown']),
+ ('maybe_boolean', ['--nounknown', '--undefok=unknown']),
+ ('with_space', ['--unknown', '--undefok', 'unknown']),
+ )
+ def test_undefok_flag_correct_use(self, cmd_args):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ args = parser.parse_args(cmd_args) # Make sure it doesn't raise.
+ # Make sure `undefok` is not exposed in namespace.
+ sentinel = object()
+ self.assertIs(sentinel, getattr(args, 'undefok', sentinel))
+
+ def test_undefok_flag_existing(self):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ parser.parse_args(
+ ['--absl_string=new_value', '--undefok=absl_string'])
+ self.assertEqual('new_value', self._absl_flags.absl_string)
+
+ @parameterized.named_parameters(
+ ('no_equal', ['--unknown', 'value', '--undefok=unknown']),
+ ('single_dash', ['--unknown', '-undefok=unknown']),
+ )
+ def test_undefok_flag_incorrect_use(self, cmd_args):
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags)
+ with self.assertRaises(SystemExit):
+ parser.parse_args(cmd_args)
+
+ def test_argument_default(self):
+ # Regression test for https://github.com/abseil/abseil-py/issues/171.
+ parser = argparse_flags.ArgumentParser(
+ inherited_absl_flags=self._absl_flags, argument_default=23)
+ parser.add_argument(
+ '--magic_number', type=int, help='The magic number to use.')
+ args = parser.parse_args([])
+ self.assertEqual(args.magic_number, 23)
+
+
+class ArgparseWithAppRunTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('simple',
+ 'main_simple', 'parse_flags_simple',
+ ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.'],
+ ['I am argparse.', 'I am absl.']),
+ ('subcommand_roll_dice',
+ 'main_subcommands', 'parse_flags_subcommands',
+ ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.',
+ 'roll_dice', '--num_faces=12'],
+ ['I am argparse.', 'I am absl.', 'Rolled a dice: ']),
+ ('subcommand_shuffle',
+ 'main_subcommands', 'parse_flags_subcommands',
+ ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.',
+ 'shuffle', 'a', 'b', 'c'],
+ ['I am argparse.', 'I am absl.', 'Shuffled: ']),
+ )
+ def test_argparse_with_app_run(
+ self, main_func_name, flags_parser_func_name, args, output_strings):
+ env = os.environ.copy()
+ env['MAIN_FUNC'] = main_func_name
+ env['FLAGS_PARSER_FUNC'] = flags_parser_func_name
+ helper = _bazelize_command.get_executable_path(
+ 'absl/flags/tests/argparse_flags_test_helper')
+ try:
+ stdout = subprocess.check_output(
+ [helper] + args, env=env, universal_newlines=True)
+ except subprocess.CalledProcessError as e:
+ error_info = ('ERROR: argparse_helper failed\n'
+ 'Command: {}\n'
+ 'Exit code: {}\n'
+ '----- output -----\n{}'
+ '------------------')
+ error_info = error_info.format(e.cmd, e.returncode,
+ e.output + '\n' if e.output else '<empty>')
+ print(error_info, file=sys.stderr)
+ raise
+
+ for output_string in output_strings:
+ self.assertIn(output_string, stdout)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/argparse_flags_test_helper.py b/absl/flags/tests/argparse_flags_test_helper.py
new file mode 100644
index 0000000..8cf42e6
--- /dev/null
+++ b/absl/flags/tests/argparse_flags_test_helper.py
@@ -0,0 +1,89 @@
+# Copyright 2018 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test helper for argparse_flags_test."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+
+from absl import app
+from absl import flags
+from absl.flags import argparse_flags
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('absl_echo', None, 'The echo message from absl.flags.')
+
+
+def parse_flags_simple(argv):
+ """Simple example for absl.flags + argparse."""
+ parser = argparse_flags.ArgumentParser(
+ description='A simple example of argparse_flags.')
+ parser.add_argument(
+ '--argparse_echo', help='The echo message from argparse_flags')
+ return parser.parse_args(argv[1:])
+
+
+def main_simple(args):
+ print('--absl_echo is', FLAGS.absl_echo)
+ print('--argparse_echo is', args.argparse_echo)
+
+
+def roll_dice(args):
+ print('Rolled a dice:', random.randint(1, args.num_faces))
+
+
+def shuffle(args):
+ inputs = list(args.inputs)
+ random.shuffle(inputs)
+ print('Shuffled:', ' '.join(inputs))
+
+
+def parse_flags_subcommands(argv):
+ """Subcommands example for absl.flags + argparse."""
+ parser = argparse_flags.ArgumentParser(
+ description='A subcommands example of argparse_flags.')
+ parser.add_argument('--argparse_echo',
+ help='The echo message from argparse_flags')
+
+ subparsers = parser.add_subparsers(help='The command to execute.')
+
+ roll_dice_parser = subparsers.add_parser(
+ 'roll_dice', help='Roll a dice.')
+ roll_dice_parser.add_argument('--num_faces', type=int, default=6)
+ roll_dice_parser.set_defaults(command=roll_dice)
+
+ shuffle_parser = subparsers.add_parser(
+ 'shuffle', help='Shuffle inputs.')
+ shuffle_parser.add_argument(
+ 'inputs', metavar='I', nargs='+', help='Inputs to shuffle.')
+ shuffle_parser.set_defaults(command=shuffle)
+
+ return parser.parse_args(argv[1:])
+
+
+def main_subcommands(args):
+ main_simple(args)
+ args.command(args)
+
+
+if __name__ == '__main__':
+ main_func_name = os.environ['MAIN_FUNC']
+ flags_parser_func_name = os.environ['FLAGS_PARSER_FUNC']
+ app.run(main=globals()[main_func_name],
+ flags_parser=globals()[flags_parser_func_name])
diff --git a/absl/flags/tests/flags_formatting_test.py b/absl/flags/tests/flags_formatting_test.py
new file mode 100644
index 0000000..bb547ce
--- /dev/null
+++ b/absl/flags/tests/flags_formatting_test.py
@@ -0,0 +1,217 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+from absl.flags import _helpers
+from absl.testing import absltest
+
+FLAGS = flags.FLAGS
+
+
+class FlagsUnitTest(absltest.TestCase):
+ """Flags formatting Unit Test."""
+
+ def test_get_help_width(self):
+ """Verify that get_help_width() reflects _help_width."""
+ default_help_width = _helpers._DEFAULT_HELP_WIDTH # Save.
+ self.assertEqual(80, _helpers._DEFAULT_HELP_WIDTH)
+ self.assertEqual(_helpers._DEFAULT_HELP_WIDTH, flags.get_help_width())
+ _helpers._DEFAULT_HELP_WIDTH = 10
+ self.assertEqual(_helpers._DEFAULT_HELP_WIDTH, flags.get_help_width())
+ _helpers._DEFAULT_HELP_WIDTH = default_help_width # restore
+
+ def test_text_wrap(self):
+ """Test that wrapping works as expected.
+
+ Also tests that it is using global flags._help_width by default.
+ """
+ default_help_width = _helpers._DEFAULT_HELP_WIDTH
+ _helpers._DEFAULT_HELP_WIDTH = 10
+
+ # Generate a string with length 40, no spaces
+ text = ''
+ expect = []
+ for n in range(4):
+ line = str(n)
+ line += '123456789'
+ text += line
+ expect.append(line)
+
+ # Verify we still break
+ wrapped = flags.text_wrap(text).split('\n')
+ self.assertEqual(4, len(wrapped))
+ self.assertEqual(expect, wrapped)
+
+ wrapped = flags.text_wrap(text, 80).split('\n')
+ self.assertEqual(1, len(wrapped))
+ self.assertEqual([text], wrapped)
+
+ # Normal case, breaking at word boundaries and rewriting new lines
+ input_value = 'a b c d e f g h'
+ expect = {1: ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'],
+ 2: ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'],
+ 3: ['a b', 'c d', 'e f', 'g h'],
+ 4: ['a b', 'c d', 'e f', 'g h'],
+ 5: ['a b c', 'd e f', 'g h'],
+ 6: ['a b c', 'd e f', 'g h'],
+ 7: ['a b c d', 'e f g h'],
+ 8: ['a b c d', 'e f g h'],
+ 9: ['a b c d e', 'f g h'],
+ 10: ['a b c d e', 'f g h'],
+ 11: ['a b c d e f', 'g h'],
+ 12: ['a b c d e f', 'g h'],
+ 13: ['a b c d e f g', 'h'],
+ 14: ['a b c d e f g', 'h'],
+ 15: ['a b c d e f g h']}
+ for width, exp in expect.items():
+ self.assertEqual(exp, flags.text_wrap(input_value, width).split('\n'))
+
+ # We turn lines with only whitespace into empty lines
+ # We strip from the right up to the first new line
+ self.assertEqual('', flags.text_wrap(' '))
+ self.assertEqual('\n', flags.text_wrap(' \n '))
+ self.assertEqual('\n', flags.text_wrap('\n\n'))
+ self.assertEqual('\n\n', flags.text_wrap('\n\n\n'))
+ self.assertEqual('\n', flags.text_wrap('\n '))
+ self.assertEqual('a\n\nb', flags.text_wrap('a\n \nb'))
+ self.assertEqual('a\n\n\nb', flags.text_wrap('a\n \n \nb'))
+ self.assertEqual('a\nb', flags.text_wrap(' a\nb '))
+ self.assertEqual('\na\nb', flags.text_wrap('\na\nb\n'))
+ self.assertEqual('\na\nb\n', flags.text_wrap(' \na\nb\n '))
+ self.assertEqual('\na\nb\n', flags.text_wrap(' \na\nb\n\n'))
+
+ # Double newline.
+ self.assertEqual('a\n\nb', flags.text_wrap(' a\n\n b'))
+
+ # We respect prefix
+ self.assertEqual(' a\n b\n c', flags.text_wrap('a\nb\nc', 80, ' '))
+ self.assertEqual('a\n b\n c', flags.text_wrap('a\nb\nc', 80, ' ', ''))
+
+ # tabs
+ self.assertEqual('a\n b c',
+ flags.text_wrap('a\nb\tc', 80, ' ', ''))
+ self.assertEqual('a\n bb c',
+ flags.text_wrap('a\nbb\tc', 80, ' ', ''))
+ self.assertEqual('a\n bbb c',
+ flags.text_wrap('a\nbbb\tc', 80, ' ', ''))
+ self.assertEqual('a\n bbbb c',
+ flags.text_wrap('a\nbbbb\tc', 80, ' ', ''))
+ self.assertEqual('a\n b\n c\n d',
+ flags.text_wrap('a\nb\tc\td', 3, ' ', ''))
+ self.assertEqual('a\n b\n c\n d',
+ flags.text_wrap('a\nb\tc\td', 4, ' ', ''))
+ self.assertEqual('a\n b\n c\n d',
+ flags.text_wrap('a\nb\tc\td', 5, ' ', ''))
+ self.assertEqual('a\n b c\n d',
+ flags.text_wrap('a\nb\tc\td', 6, ' ', ''))
+ self.assertEqual('a\n b c\n d',
+ flags.text_wrap('a\nb\tc\td', 7, ' ', ''))
+ self.assertEqual('a\n b c\n d',
+ flags.text_wrap('a\nb\tc\td', 8, ' ', ''))
+ self.assertEqual('a\n b c\n d',
+ flags.text_wrap('a\nb\tc\td', 9, ' ', ''))
+ self.assertEqual('a\n b c d',
+ flags.text_wrap('a\nb\tc\td', 10, ' ', ''))
+
+ # multiple tabs
+ self.assertEqual('a c',
+ flags.text_wrap('a\t\tc', 80, ' ', ''))
+
+ _helpers._DEFAULT_HELP_WIDTH = default_help_width # restore
+
+ def test_doc_to_help(self):
+ self.assertEqual('', flags.doc_to_help(' '))
+ self.assertEqual('', flags.doc_to_help(' \n '))
+ self.assertEqual('a\n\nb', flags.doc_to_help('a\n \nb'))
+ self.assertEqual('a\n\n\nb', flags.doc_to_help('a\n \n \nb'))
+ self.assertEqual('a b', flags.doc_to_help(' a\nb '))
+ self.assertEqual('a b', flags.doc_to_help('\na\nb\n'))
+ self.assertEqual('a\n\nb', flags.doc_to_help('\na\n\nb\n'))
+ self.assertEqual('a b', flags.doc_to_help(' \na\nb\n '))
+ # Different first line, one line empty - erm double new line.
+ self.assertEqual('a b c\n\nd', flags.doc_to_help('a\n b\n c\n\n d'))
+ self.assertEqual('a b\n c d', flags.doc_to_help('a\n b\n \tc\n d'))
+ self.assertEqual('a b\n c\n d',
+ flags.doc_to_help('a\n b\n \tc\n \td'))
+
+ def test_doc_to_help_flag_values(self):
+ # !!!!!!!!!!!!!!!!!!!!
+ # The following doc string is taken as is directly from flags.py:FlagValues
+ # The intention of this test is to verify 'live' performance
+ # !!!!!!!!!!!!!!!!!!!!
+ """Used as a registry for 'Flag' objects.
+
+ A 'FlagValues' can then scan command line arguments, passing flag
+ arguments through to the 'Flag' objects that it owns. It also
+ provides easy access to the flag values. Typically only one
+ 'FlagValues' object is needed by an application: flags.FLAGS
+
+ This class is heavily overloaded:
+
+ 'Flag' objects are registered via __setitem__:
+ FLAGS['longname'] = x # register a new flag
+
+ The .value member of the registered 'Flag' objects can be accessed as
+ members of this 'FlagValues' object, through __getattr__. Both the
+ long and short name of the original 'Flag' objects can be used to
+ access its value:
+ FLAGS.longname # parsed flag value
+ FLAGS.x # parsed flag value (short name)
+
+ Command line arguments are scanned and passed to the registered 'Flag'
+ objects through the __call__ method. Unparsed arguments, including
+ argv[0] (e.g. the program name) are returned.
+ argv = FLAGS(sys.argv) # scan command line arguments
+
+ The original registered Flag objects can be retrieved through the use
+ """
+ doc = flags.doc_to_help(self.test_doc_to_help_flag_values.__doc__)
+ # Test the general outline of the converted docs
+ lines = doc.splitlines()
+ self.assertEqual(17, len(lines))
+ empty_lines = [index for index in range(len(lines)) if not lines[index]]
+ self.assertEqual([1, 3, 5, 8, 12, 15], empty_lines)
+ # test that some starting prefix is kept
+ flags_lines = [index for index in range(len(lines))
+ if lines[index].startswith(' FLAGS')]
+ self.assertEqual([7, 10, 11], flags_lines)
+ # but other, especially common space has been removed
+ space_lines = [index for index in range(len(lines))
+ if lines[index] and lines[index][0].isspace()]
+ self.assertEqual([7, 10, 11, 14], space_lines)
+ # No right space was kept
+ rspace_lines = [index for index in range(len(lines))
+ if lines[index] != lines[index].rstrip()]
+ self.assertEqual([], rspace_lines)
+ # test double spaces are kept
+ self.assertEqual(True, lines[2].endswith('application: flags.FLAGS'))
+
+ def test_text_wrap_raises_on_excessive_indent(self):
+ """Ensure an indent longer than line length raises."""
+ self.assertRaises(ValueError,
+ flags.text_wrap, 'dummy', length=10, indent=' ' * 10)
+
+ def test_text_wrap_raises_on_excessive_first_line(self):
+ """Ensure a first line indent longer than line length raises."""
+ self.assertRaises(
+ ValueError,
+ flags.text_wrap, 'dummy', length=80, firstline_indent=' ' * 80)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/flags_helpxml_test.py b/absl/flags/tests/flags_helpxml_test.py
new file mode 100644
index 0000000..e2168ac
--- /dev/null
+++ b/absl/flags/tests/flags_helpxml_test.py
@@ -0,0 +1,659 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for the XML-format help generated by the flags.py module."""
+
+import enum
+import io
+import os
+import string
+import sys
+import xml.dom.minidom
+import xml.sax.saxutils
+
+from absl import flags
+from absl.flags import _helpers
+from absl.flags.tests import module_bar
+from absl.testing import absltest
+
+
+class CreateXMLDOMElement(absltest.TestCase):
+
+ def _check(self, name, value, expected_output):
+ doc = xml.dom.minidom.Document()
+ node = _helpers.create_xml_dom_element(doc, name, value)
+ output = node.toprettyxml(' ', encoding='utf-8')
+ self.assertEqual(expected_output, output)
+
+ def test_create_xml_dom_element(self):
+ self._check('tag', '', b'<tag></tag>\n')
+ self._check('tag', 'plain text', b'<tag>plain text</tag>\n')
+ self._check('tag', '(x < y) && (a >= b)',
+ b'<tag>(x &lt; y) &amp;&amp; (a &gt;= b)</tag>\n')
+
+ # If the value is bytes with invalid unicode:
+ bytes_with_invalid_unicodes = b'\x81\xff'
+ # In python 3 the string representation is "b'\x81\xff'" so they are kept
+ # as "b'\x81\xff'".
+ self._check('tag', bytes_with_invalid_unicodes,
+ b"<tag>b'\\x81\\xff'</tag>\n")
+
+ # Some unicode chars are illegal in xml
+ # (http://www.w3.org/TR/REC-xml/#charsets):
+ self._check('tag', u'\x0b\x02\x08\ufffe', b'<tag></tag>\n')
+
+ # Valid unicode will be encoded:
+ self._check('tag', u'\xff', b'<tag>\xc3\xbf</tag>\n')
+
+
+def _list_separators_in_xmlformat(separators, indent=''):
+ """Generates XML encoding of a list of list separators.
+
+ Args:
+ separators: A list of list separators. Usually, this should be a
+ string whose characters are the valid list separators, e.g., ','
+ means that both comma (',') and space (' ') are valid list
+ separators.
+ indent: A string that is added at the beginning of each generated
+ XML element.
+
+ Returns:
+ A string.
+ """
+ result = ''
+ separators = list(separators)
+ separators.sort()
+ for sep_char in separators:
+ result += ('%s<list_separator>%s</list_separator>\n' %
+ (indent, repr(sep_char)))
+ return result
+
+
+class FlagCreateXMLDOMElement(absltest.TestCase):
+ """Test the create_xml_dom_element method for a single flag at a time.
+
+ There is one test* method for each kind of DEFINE_* declaration.
+ """
+
+ def setUp(self):
+ # self.fv is a FlagValues object, just like flags.FLAGS. Each
+ # test registers one flag with this FlagValues.
+ self.fv = flags.FlagValues()
+
+ def _check_flag_help_in_xml(self, flag_name, module_name,
+ expected_output, is_key=False):
+ flag_obj = self.fv[flag_name]
+ doc = xml.dom.minidom.Document()
+ element = flag_obj._create_xml_dom_element(doc, module_name, is_key=is_key)
+ output = element.toprettyxml(indent=' ')
+ self.assertMultiLineEqual(expected_output, output)
+
+ def test_flag_help_in_xml_int(self):
+ flags.DEFINE_integer('index', 17, 'An integer flag', flag_values=self.fv)
+ expected_output_pattern = (
+ '<flag>\n'
+ ' <file>module.name</file>\n'
+ ' <name>index</name>\n'
+ ' <meaning>An integer flag</meaning>\n'
+ ' <default>17</default>\n'
+ ' <current>%d</current>\n'
+ ' <type>int</type>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('index', 'module.name',
+ expected_output_pattern % 17)
+ # Check that the output is correct even when the current value of
+ # a flag is different from the default one.
+ self.fv['index'].value = 20
+ self._check_flag_help_in_xml('index', 'module.name',
+ expected_output_pattern % 20)
+
+ def test_flag_help_in_xml_int_with_bounds(self):
+ flags.DEFINE_integer('nb_iters', 17, 'An integer flag',
+ lower_bound=5, upper_bound=27,
+ flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <key>yes</key>\n'
+ ' <file>module.name</file>\n'
+ ' <name>nb_iters</name>\n'
+ ' <meaning>An integer flag</meaning>\n'
+ ' <default>17</default>\n'
+ ' <current>17</current>\n'
+ ' <type>int</type>\n'
+ ' <lower_bound>5</lower_bound>\n'
+ ' <upper_bound>27</upper_bound>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('nb_iters', 'module.name', expected_output,
+ is_key=True)
+
+ def test_flag_help_in_xml_string(self):
+ flags.DEFINE_string('file_path', '/path/to/my/dir', 'A test string flag.',
+ flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <file>simple_module</file>\n'
+ ' <name>file_path</name>\n'
+ ' <meaning>A test string flag.</meaning>\n'
+ ' <default>/path/to/my/dir</default>\n'
+ ' <current>/path/to/my/dir</current>\n'
+ ' <type>string</type>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('file_path', 'simple_module', expected_output)
+
+ def test_flag_help_in_xml_string_with_xmlillegal_chars(self):
+ flags.DEFINE_string('file_path', '/path/to/\x08my/dir',
+ 'A test string flag.', flag_values=self.fv)
+ # '\x08' is not a legal character in XML 1.0 documents. Our
+ # current code purges such characters from the generated XML.
+ expected_output = (
+ '<flag>\n'
+ ' <file>simple_module</file>\n'
+ ' <name>file_path</name>\n'
+ ' <meaning>A test string flag.</meaning>\n'
+ ' <default>/path/to/my/dir</default>\n'
+ ' <current>/path/to/my/dir</current>\n'
+ ' <type>string</type>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('file_path', 'simple_module', expected_output)
+
+ def test_flag_help_in_xml_boolean(self):
+ flags.DEFINE_boolean('use_gpu', False, 'Use gpu for performance.',
+ flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <key>yes</key>\n'
+ ' <file>a_module</file>\n'
+ ' <name>use_gpu</name>\n'
+ ' <meaning>Use gpu for performance.</meaning>\n'
+ ' <default>false</default>\n'
+ ' <current>false</current>\n'
+ ' <type>bool</type>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('use_gpu', 'a_module', expected_output,
+ is_key=True)
+
+ def test_flag_help_in_xml_enum(self):
+ flags.DEFINE_enum('cc_version', 'stable', ['stable', 'experimental'],
+ 'Compiler version to use.', flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>cc_version</name>\n'
+ ' <meaning>&lt;stable|experimental&gt;: '
+ 'Compiler version to use.</meaning>\n'
+ ' <default>stable</default>\n'
+ ' <current>stable</current>\n'
+ ' <type>string enum</type>\n'
+ ' <enum_value>stable</enum_value>\n'
+ ' <enum_value>experimental</enum_value>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('cc_version', 'tool', expected_output)
+
+ def test_flag_help_in_xml_enum_class(self):
+ class Version(enum.Enum):
+ STABLE = 0
+ EXPERIMENTAL = 1
+
+ flags.DEFINE_enum_class('cc_version', 'STABLE', Version,
+ 'Compiler version to use.', flag_values=self.fv)
+ expected_output = ('<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>cc_version</name>\n'
+ ' <meaning>&lt;stable|experimental&gt;: '
+ 'Compiler version to use.</meaning>\n'
+ ' <default>stable</default>\n'
+ ' <current>Version.STABLE</current>\n'
+ ' <type>enum class</type>\n'
+ ' <enum_value>STABLE</enum_value>\n'
+ ' <enum_value>EXPERIMENTAL</enum_value>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('cc_version', 'tool', expected_output)
+
+ def test_flag_help_in_xml_comma_separated_list(self):
+ flags.DEFINE_list('files', 'a.cc,a.h,archive/old.zip',
+ 'Files to process.', flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>files</name>\n'
+ ' <meaning>Files to process.</meaning>\n'
+ ' <default>a.cc,a.h,archive/old.zip</default>\n'
+ ' <current>[\'a.cc\', \'a.h\', \'archive/old.zip\']</current>\n'
+ ' <type>comma separated list of strings</type>\n'
+ ' <list_separator>\',\'</list_separator>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('files', 'tool', expected_output)
+
+ def test_list_as_default_argument_comma_separated_list(self):
+ flags.DEFINE_list('allow_users', ['alice', 'bob'],
+ 'Users with access.', flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>allow_users</name>\n'
+ ' <meaning>Users with access.</meaning>\n'
+ ' <default>alice,bob</default>\n'
+ ' <current>[\'alice\', \'bob\']</current>\n'
+ ' <type>comma separated list of strings</type>\n'
+ ' <list_separator>\',\'</list_separator>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('allow_users', 'tool', expected_output)
+
+ def test_none_as_default_arguments_comma_separated_list(self):
+ flags.DEFINE_list('allow_users', None,
+ 'Users with access.', flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>allow_users</name>\n'
+ ' <meaning>Users with access.</meaning>\n'
+ ' <default></default>\n'
+ ' <current>None</current>\n'
+ ' <type>comma separated list of strings</type>\n'
+ ' <list_separator>\',\'</list_separator>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('allow_users', 'tool', expected_output)
+
+ def test_flag_help_in_xml_space_separated_list(self):
+ flags.DEFINE_spaceseplist('dirs', 'src libs bin',
+ 'Directories to search.', flag_values=self.fv)
+ expected_separators = sorted(string.whitespace)
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>dirs</name>\n'
+ ' <meaning>Directories to search.</meaning>\n'
+ ' <default>src libs bin</default>\n'
+ ' <current>[\'src\', \'libs\', \'bin\']</current>\n'
+ ' <type>whitespace separated list of strings</type>\n'
+ 'LIST_SEPARATORS'
+ '</flag>\n').replace('LIST_SEPARATORS',
+ _list_separators_in_xmlformat(expected_separators,
+ indent=' '))
+ self._check_flag_help_in_xml('dirs', 'tool', expected_output)
+
+ def test_flag_help_in_xml_space_separated_list_with_comma_compat(self):
+ flags.DEFINE_spaceseplist('dirs', 'src libs,bin',
+ 'Directories to search.', comma_compat=True,
+ flag_values=self.fv)
+ expected_separators = sorted(string.whitespace + ',')
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>dirs</name>\n'
+ ' <meaning>Directories to search.</meaning>\n'
+ ' <default>src libs bin</default>\n'
+ ' <current>[\'src\', \'libs\', \'bin\']</current>\n'
+ ' <type>whitespace or comma separated list of strings</type>\n'
+ 'LIST_SEPARATORS'
+ '</flag>\n').replace('LIST_SEPARATORS',
+ _list_separators_in_xmlformat(expected_separators,
+ indent=' '))
+ self._check_flag_help_in_xml('dirs', 'tool', expected_output)
+
+ def test_flag_help_in_xml_multi_string(self):
+ flags.DEFINE_multi_string('to_delete', ['a.cc', 'b.h'],
+ 'Files to delete', flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>to_delete</name>\n'
+ ' <meaning>Files to delete;\n'
+ ' repeat this option to specify a list of values</meaning>\n'
+ ' <default>[\'a.cc\', \'b.h\']</default>\n'
+ ' <current>[\'a.cc\', \'b.h\']</current>\n'
+ ' <type>multi string</type>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('to_delete', 'tool', expected_output)
+
+ def test_flag_help_in_xml_multi_int(self):
+ flags.DEFINE_multi_integer('cols', [5, 7, 23],
+ 'Columns to select', flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>cols</name>\n'
+ ' <meaning>Columns to select;\n '
+ 'repeat this option to specify a list of values</meaning>\n'
+ ' <default>[5, 7, 23]</default>\n'
+ ' <current>[5, 7, 23]</current>\n'
+ ' <type>multi int</type>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('cols', 'tool', expected_output)
+
+ def test_flag_help_in_xml_multi_enum(self):
+ flags.DEFINE_multi_enum('flavours', ['APPLE', 'BANANA'],
+ ['APPLE', 'BANANA', 'CHERRY'],
+ 'Compilation flavour.', flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>flavours</name>\n'
+ ' <meaning>&lt;APPLE|BANANA|CHERRY&gt;: Compilation flavour.;\n'
+ ' repeat this option to specify a list of values</meaning>\n'
+ ' <default>[\'APPLE\', \'BANANA\']</default>\n'
+ ' <current>[\'APPLE\', \'BANANA\']</current>\n'
+ ' <type>multi string enum</type>\n'
+ ' <enum_value>APPLE</enum_value>\n'
+ ' <enum_value>BANANA</enum_value>\n'
+ ' <enum_value>CHERRY</enum_value>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('flavours', 'tool', expected_output)
+
+ def test_flag_help_in_xml_multi_enum_class_singleton_default(self):
+ class Fruit(enum.Enum):
+ ORANGE = 0
+ BANANA = 1
+
+ flags.DEFINE_multi_enum_class('fruit', ['ORANGE'],
+ Fruit,
+ 'The fruit flag.', flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>fruit</name>\n'
+ ' <meaning>&lt;orange|banana&gt;: The fruit flag.;\n'
+ ' repeat this option to specify a list of values</meaning>\n'
+ ' <default>orange</default>\n'
+ ' <current>orange</current>\n'
+ ' <type>multi enum class</type>\n'
+ ' <enum_value>ORANGE</enum_value>\n'
+ ' <enum_value>BANANA</enum_value>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('fruit', 'tool', expected_output)
+
+ def test_flag_help_in_xml_multi_enum_class_list_default(self):
+ class Fruit(enum.Enum):
+ ORANGE = 0
+ BANANA = 1
+
+ flags.DEFINE_multi_enum_class('fruit', ['ORANGE', 'BANANA'],
+ Fruit,
+ 'The fruit flag.', flag_values=self.fv)
+ expected_output = (
+ '<flag>\n'
+ ' <file>tool</file>\n'
+ ' <name>fruit</name>\n'
+ ' <meaning>&lt;orange|banana&gt;: The fruit flag.;\n'
+ ' repeat this option to specify a list of values</meaning>\n'
+ ' <default>orange,banana</default>\n'
+ ' <current>orange,banana</current>\n'
+ ' <type>multi enum class</type>\n'
+ ' <enum_value>ORANGE</enum_value>\n'
+ ' <enum_value>BANANA</enum_value>\n'
+ '</flag>\n')
+ self._check_flag_help_in_xml('fruit', 'tool', expected_output)
+
+# The next EXPECTED_HELP_XML_* constants are parts of a template for
+# the expected XML output from WriteHelpInXMLFormatTest below. When
+# we assemble these parts into a single big string, we'll take into
+# account the ordering between the name of the main module and the
+# name of module_bar. Next, we'll fill in the docstring for this
+# module (%(usage_doc)s), the name of the main module
+# (%(main_module_name)s) and the name of the module module_bar
+# (%(module_bar_name)s). See WriteHelpInXMLFormatTest below.
+EXPECTED_HELP_XML_START = """\
+<?xml version="1.0" encoding="utf-8"?>
+<AllFlags>
+ <program>%(basename_of_argv0)s</program>
+ <usage>%(usage_doc)s</usage>
+"""
+
+EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE = """\
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>allow_users</name>
+ <meaning>Users with access.</meaning>
+ <default>alice,bob</default>
+ <current>['alice', 'bob']</current>
+ <type>comma separated list of strings</type>
+ <list_separator>','</list_separator>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>cc_version</name>
+ <meaning>&lt;stable|experimental&gt;: Compiler version to use.</meaning>
+ <default>stable</default>
+ <current>stable</current>
+ <type>string enum</type>
+ <enum_value>stable</enum_value>
+ <enum_value>experimental</enum_value>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>cols</name>
+ <meaning>Columns to select;
+ repeat this option to specify a list of values</meaning>
+ <default>[5, 7, 23]</default>
+ <current>[5, 7, 23]</current>
+ <type>multi int</type>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>dirs</name>
+ <meaning>Directories to create.</meaning>
+ <default>src libs bins</default>
+ <current>['src', 'libs', 'bins']</current>
+ <type>whitespace separated list of strings</type>
+%(whitespace_separators)s </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>file_path</name>
+ <meaning>A test string flag.</meaning>
+ <default>/path/to/my/dir</default>
+ <current>/path/to/my/dir</current>
+ <type>string</type>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>files</name>
+ <meaning>Files to process.</meaning>
+ <default>a.cc,a.h,archive/old.zip</default>
+ <current>['a.cc', 'a.h', 'archive/old.zip']</current>
+ <type>comma separated list of strings</type>
+ <list_separator>\',\'</list_separator>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>flavours</name>
+ <meaning>&lt;APPLE|BANANA|CHERRY&gt;: Compilation flavour.;
+ repeat this option to specify a list of values</meaning>
+ <default>['APPLE', 'BANANA']</default>
+ <current>['APPLE', 'BANANA']</current>
+ <type>multi string enum</type>
+ <enum_value>APPLE</enum_value>
+ <enum_value>BANANA</enum_value>
+ <enum_value>CHERRY</enum_value>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>index</name>
+ <meaning>An integer flag</meaning>
+ <default>17</default>
+ <current>17</current>
+ <type>int</type>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>nb_iters</name>
+ <meaning>An integer flag</meaning>
+ <default>17</default>
+ <current>17</current>
+ <type>int</type>
+ <lower_bound>5</lower_bound>
+ <upper_bound>27</upper_bound>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>to_delete</name>
+ <meaning>Files to delete;
+ repeat this option to specify a list of values</meaning>
+ <default>['a.cc', 'b.h']</default>
+ <current>['a.cc', 'b.h']</current>
+ <type>multi string</type>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(main_module_name)s</file>
+ <name>use_gpu</name>
+ <meaning>Use gpu for performance.</meaning>
+ <default>false</default>
+ <current>false</current>
+ <type>bool</type>
+ </flag>
+"""
+
+EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR = """\
+ <flag>
+ <file>%(module_bar_name)s</file>
+ <name>tmod_bar_t</name>
+ <meaning>Sample int flag.</meaning>
+ <default>4</default>
+ <current>4</current>
+ <type>int</type>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(module_bar_name)s</file>
+ <name>tmod_bar_u</name>
+ <meaning>Sample int flag.</meaning>
+ <default>5</default>
+ <current>5</current>
+ <type>int</type>
+ </flag>
+ <flag>
+ <file>%(module_bar_name)s</file>
+ <name>tmod_bar_v</name>
+ <meaning>Sample int flag.</meaning>
+ <default>6</default>
+ <current>6</current>
+ <type>int</type>
+ </flag>
+ <flag>
+ <file>%(module_bar_name)s</file>
+ <name>tmod_bar_x</name>
+ <meaning>Boolean flag.</meaning>
+ <default>true</default>
+ <current>true</current>
+ <type>bool</type>
+ </flag>
+ <flag>
+ <file>%(module_bar_name)s</file>
+ <name>tmod_bar_y</name>
+ <meaning>String flag.</meaning>
+ <default>default</default>
+ <current>default</current>
+ <type>string</type>
+ </flag>
+ <flag>
+ <key>yes</key>
+ <file>%(module_bar_name)s</file>
+ <name>tmod_bar_z</name>
+ <meaning>Another boolean flag from module bar.</meaning>
+ <default>false</default>
+ <current>false</current>
+ <type>bool</type>
+ </flag>
+"""
+
+EXPECTED_HELP_XML_END = """\
+</AllFlags>
+"""
+
+
+class WriteHelpInXMLFormatTest(absltest.TestCase):
+ """Big test of FlagValues.write_help_in_xml_format, with several flags."""
+
+ def test_write_help_in_xmlformat(self):
+ fv = flags.FlagValues()
+ # Since these flags are defined by the top module, they are all key.
+ flags.DEFINE_integer('index', 17, 'An integer flag', flag_values=fv)
+ flags.DEFINE_integer('nb_iters', 17, 'An integer flag',
+ lower_bound=5, upper_bound=27, flag_values=fv)
+ flags.DEFINE_string('file_path', '/path/to/my/dir', 'A test string flag.',
+ flag_values=fv)
+ flags.DEFINE_boolean('use_gpu', False, 'Use gpu for performance.',
+ flag_values=fv)
+ flags.DEFINE_enum('cc_version', 'stable', ['stable', 'experimental'],
+ 'Compiler version to use.', flag_values=fv)
+ flags.DEFINE_list('files', 'a.cc,a.h,archive/old.zip',
+ 'Files to process.', flag_values=fv)
+ flags.DEFINE_list('allow_users', ['alice', 'bob'],
+ 'Users with access.', flag_values=fv)
+ flags.DEFINE_spaceseplist('dirs', 'src libs bins',
+ 'Directories to create.', flag_values=fv)
+ flags.DEFINE_multi_string('to_delete', ['a.cc', 'b.h'],
+ 'Files to delete', flag_values=fv)
+ flags.DEFINE_multi_integer('cols', [5, 7, 23],
+ 'Columns to select', flag_values=fv)
+ flags.DEFINE_multi_enum('flavours', ['APPLE', 'BANANA'],
+ ['APPLE', 'BANANA', 'CHERRY'],
+ 'Compilation flavour.', flag_values=fv)
+ # Define a few flags in a different module.
+ module_bar.define_flags(flag_values=fv)
+ # And declare only a few of them to be key. This way, we have
+ # different kinds of flags, defined in different modules, and not
+ # all of them are key flags.
+ flags.declare_key_flag('tmod_bar_z', flag_values=fv)
+ flags.declare_key_flag('tmod_bar_u', flag_values=fv)
+
+ # Generate flag help in XML format in the StringIO sio.
+ sio = io.StringIO()
+ fv.write_help_in_xml_format(sio)
+
+ # Check that we got the expected result.
+ expected_output_template = EXPECTED_HELP_XML_START
+ main_module_name = sys.argv[0]
+ module_bar_name = module_bar.__name__
+
+ if main_module_name < module_bar_name:
+ expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE
+ expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR
+ else:
+ expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR
+ expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE
+
+ expected_output_template += EXPECTED_HELP_XML_END
+
+ # XML representation of the whitespace list separators.
+ whitespace_separators = _list_separators_in_xmlformat(string.whitespace,
+ indent=' ')
+ expected_output = (
+ expected_output_template %
+ {'basename_of_argv0': os.path.basename(sys.argv[0]),
+ 'usage_doc': sys.modules['__main__'].__doc__,
+ 'main_module_name': main_module_name,
+ 'module_bar_name': module_bar_name,
+ 'whitespace_separators': whitespace_separators})
+
+ actual_output = sio.getvalue()
+ self.assertMultiLineEqual(expected_output, actual_output)
+
+ # Also check that our result is valid XML. minidom.parseString
+ # throws an xml.parsers.expat.ExpatError in case of an error.
+ xml.dom.minidom.parseString(actual_output)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/flags_numeric_bounds_test.py b/absl/flags/tests/flags_numeric_bounds_test.py
new file mode 100644
index 0000000..d3c2a95
--- /dev/null
+++ b/absl/flags/tests/flags_numeric_bounds_test.py
@@ -0,0 +1,105 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for lower/upper bounds validators for numeric flags."""
+
+from unittest import mock
+from absl import flags
+from absl.flags import _validators
+from absl.testing import absltest
+
+
+class NumericFlagBoundsTest(absltest.TestCase):
+
+ def setUp(self):
+ super(NumericFlagBoundsTest, self).setUp()
+ self.flag_values = flags.FlagValues()
+
+ def test_no_validator_if_no_bounds(self):
+ """Validator is not registered if lower and upper bound are None."""
+ with mock.patch.object(_validators, 'register_validator'
+ ) as register_validator:
+ flags.DEFINE_integer('positive_flag', None, 'positive int',
+ lower_bound=0, flag_values=self.flag_values)
+ register_validator.assert_called_once_with(
+ 'positive_flag', mock.ANY, flag_values=self.flag_values)
+ with mock.patch.object(_validators, 'register_validator'
+ ) as register_validator:
+ flags.DEFINE_integer('int_flag', None, 'just int',
+ flag_values=self.flag_values)
+ register_validator.assert_not_called()
+
+ def test_success(self):
+ flags.DEFINE_integer('int_flag', 5, 'Just integer',
+ flag_values=self.flag_values)
+ argv = ('./program', '--int_flag=13')
+ self.flag_values(argv)
+ self.assertEqual(13, self.flag_values.int_flag)
+ self.flag_values.int_flag = 25
+ self.assertEqual(25, self.flag_values.int_flag)
+
+ def test_success_if_none(self):
+ flags.DEFINE_integer('int_flag', None, '',
+ lower_bound=0, upper_bound=5,
+ flag_values=self.flag_values)
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.int_flag)
+
+ def test_success_if_exactly_equals(self):
+ flags.DEFINE_float('float_flag', None, '',
+ lower_bound=1, upper_bound=1,
+ flag_values=self.flag_values)
+ argv = ('./program', '--float_flag=1')
+ self.flag_values(argv)
+ self.assertEqual(1, self.flag_values.float_flag)
+
+ def test_exception_if_smaller(self):
+ flags.DEFINE_integer('int_flag', None, '',
+ lower_bound=0, upper_bound=5,
+ flag_values=self.flag_values)
+ argv = ('./program', '--int_flag=-1')
+ try:
+ self.flag_values(argv)
+ except flags.IllegalFlagValueError as e:
+ text = 'flag --int_flag=-1: -1 is not an integer in the range [0, 5]'
+ self.assertEqual(text, str(e))
+
+
+class SettingFlagAfterStartTest(absltest.TestCase):
+
+ def setUp(self):
+ self.flag_values = flags.FlagValues()
+
+ def test_success(self):
+ flags.DEFINE_integer('int_flag', None, 'Just integer',
+ flag_values=self.flag_values)
+ argv = ('./program', '--int_flag=13')
+ self.flag_values(argv)
+ self.assertEqual(13, self.flag_values.int_flag)
+ self.flag_values.int_flag = 25
+ self.assertEqual(25, self.flag_values.int_flag)
+
+ def test_exception_if_setting_integer_flag_outside_bounds(self):
+ flags.DEFINE_integer('int_flag', None, 'Just integer', lower_bound=0,
+ flag_values=self.flag_values)
+ argv = ('./program', '--int_flag=13')
+ self.flag_values(argv)
+ self.assertEqual(13, self.flag_values.int_flag)
+ with self.assertRaises(flags.IllegalFlagValueError):
+ self.flag_values.int_flag = -2
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py
new file mode 100644
index 0000000..8a42bc9
--- /dev/null
+++ b/absl/flags/tests/flags_test.py
@@ -0,0 +1,2922 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for absl.flags used as a package."""
+
+import contextlib
+import enum
+import io
+import os
+import shutil
+import sys
+import tempfile
+import unittest
+
+from absl import flags
+from absl.flags import _exceptions
+from absl.flags import _helpers
+from absl.flags.tests import module_bar
+from absl.flags.tests import module_baz
+from absl.flags.tests import module_foo
+from absl.testing import absltest
+
+FLAGS = flags.FLAGS
+
+
+@contextlib.contextmanager
+def _use_gnu_getopt(flag_values, use_gnu_get_opt):
+ old_use_gnu_get_opt = flag_values.is_gnu_getopt()
+ flag_values.set_gnu_getopt(use_gnu_get_opt)
+ yield
+ flag_values.set_gnu_getopt(old_use_gnu_get_opt)
+
+
+class FlagDictToArgsTest(absltest.TestCase):
+
+ def test_flatten_google_flag_map(self):
+ arg_dict = {
+ 'week-end': None,
+ 'estudia': False,
+ 'trabaja': False,
+ 'party': True,
+ 'monday': 'party',
+ 'score': 42,
+ 'loadthatstuff': [42, 'hello', 'goodbye'],
+ }
+ self.assertSameElements(
+ ('--week-end', '--noestudia', '--notrabaja', '--party',
+ '--monday=party', '--score=42', '--loadthatstuff=42,hello,goodbye'),
+ flags.flag_dict_to_args(arg_dict))
+
+ def test_flatten_google_flag_map_with_multi_flag(self):
+ arg_dict = {
+ 'some_list': ['value1', 'value2'],
+ 'some_multi_string': ['value3', 'value4'],
+ }
+ self.assertSameElements(
+ ('--some_list=value1,value2', '--some_multi_string=value3',
+ '--some_multi_string=value4'),
+ flags.flag_dict_to_args(arg_dict, multi_flags={'some_multi_string'}))
+
+
+class Fruit(enum.Enum):
+ APPLE = object()
+ ORANGE = object()
+
+
+class CaseSensitiveFruit(enum.Enum):
+ apple = 1
+ orange = 2
+ APPLE = 3
+
+
+class EmptyEnum(enum.Enum):
+ pass
+
+
+class AliasFlagsTest(absltest.TestCase):
+
+ def setUp(self):
+ super(AliasFlagsTest, self).setUp()
+ self.flags = flags.FlagValues()
+
+ @property
+ def alias(self):
+ return self.flags['alias']
+
+ @property
+ def aliased(self):
+ return self.flags['aliased']
+
+ def define_alias(self, *args, **kwargs):
+ flags.DEFINE_alias(*args, flag_values=self.flags, **kwargs)
+
+ def define_integer(self, *args, **kwargs):
+ flags.DEFINE_integer(*args, flag_values=self.flags, **kwargs)
+
+ def define_multi_integer(self, *args, **kwargs):
+ flags.DEFINE_multi_integer(*args, flag_values=self.flags, **kwargs)
+
+ def define_string(self, *args, **kwargs):
+ flags.DEFINE_string(*args, flag_values=self.flags, **kwargs)
+
+ def assert_alias_mirrors_aliased(self, alias, aliased, ignore_due_to_bug=()):
+ # A few sanity checks to avoid false success
+ self.assertIn('FlagAlias', alias.__class__.__qualname__)
+ self.assertIsNot(alias, aliased)
+ self.assertNotEqual(aliased.name, alias.name)
+
+ alias_state = {}
+ aliased_state = {}
+ attrs = {
+ 'allow_hide_cpp',
+ 'allow_override',
+ 'allow_override_cpp',
+ 'allow_overwrite',
+ 'allow_using_method_names',
+ 'boolean',
+ 'default',
+ 'default_as_str',
+ 'default_unparsed',
+ # TODO(rlevasseur): This should match, but a bug prevents it from being
+ # in sync.
+ # 'using_default_value',
+ 'value',
+ }
+ attrs.difference_update(ignore_due_to_bug)
+
+ for attr in attrs:
+ alias_state[attr] = getattr(alias, attr)
+ aliased_state[attr] = getattr(aliased, attr)
+
+ self.assertEqual(aliased_state, alias_state, 'LHS is aliased; RHS is alias')
+
+ def test_serialize_multi(self):
+ self.define_multi_integer('aliased', [0, 1], '')
+ self.define_alias('alias', 'aliased')
+
+ actual = self.alias.serialize()
+ # TODO(rlevasseur): This should check for --alias=0\n--alias=1, but
+ # a bug causes it to serialize incorrectly.
+ self.assertEqual('--alias=[0, 1]', actual)
+
+ def test_allow_overwrite_false(self):
+ self.define_integer('aliased', None, 'help', allow_overwrite=False)
+ self.define_alias('alias', 'aliased')
+
+ with self.assertRaisesRegex(flags.IllegalFlagValueError, 'already defined'):
+ self.flags(['./program', '--alias=1', '--aliased=2'])
+
+ self.assertEqual(1, self.alias.value)
+ self.assertEqual(1, self.aliased.value)
+
+ def test_aliasing_multi_no_default(self):
+
+ def define_flags():
+ self.flags = flags.FlagValues()
+ self.define_multi_integer('aliased', None, 'help')
+ self.define_alias('alias', 'aliased')
+
+ with self.subTest('after defining'):
+ define_flags()
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+ self.assertIsNone(self.alias.value)
+
+ with self.subTest('set alias'):
+ define_flags()
+ self.flags(['./program', '--alias=1', '--alias=2'])
+ self.assertEqual([1, 2], self.alias.value)
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+
+ with self.subTest('set aliased'):
+ define_flags()
+ self.flags(['./program', '--aliased=1', '--aliased=2'])
+ self.assertEqual([1, 2], self.alias.value)
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+
+ with self.subTest('not setting anything'):
+ define_flags()
+ self.flags(['./program'])
+ self.assertEqual(None, self.alias.value)
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+
+ def test_aliasing_multi_with_default(self):
+
+ def define_flags():
+ self.flags = flags.FlagValues()
+ self.define_multi_integer('aliased', [0], 'help')
+ self.define_alias('alias', 'aliased')
+
+ with self.subTest('after defining'):
+ define_flags()
+ self.assertEqual([0], self.alias.default)
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+
+ with self.subTest('set alias'):
+ define_flags()
+ self.flags(['./program', '--alias=1', '--alias=2'])
+ self.assertEqual([1, 2], self.alias.value)
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+
+ self.assertEqual(2, self.alias.present)
+ # TODO(rlevasseur): This should assert 0, but a bug with aliases and
+ # MultiFlag causes the alias to increment aliased's present counter.
+ self.assertEqual(2, self.aliased.present)
+
+ with self.subTest('set aliased'):
+ define_flags()
+ self.flags(['./program', '--aliased=1', '--aliased=2'])
+ self.assertEqual([1, 2], self.alias.value)
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+ self.assertEqual(0, self.alias.present)
+
+ # TODO(rlevasseur): This should assert 0, but a bug with aliases and
+ # MultiFlag causes the alias to increment aliased present counter.
+ self.assertEqual(2, self.aliased.present)
+
+ with self.subTest('not setting anything'):
+ define_flags()
+ self.flags(['./program'])
+ self.assertEqual([0], self.alias.value)
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+ self.assertEqual(0, self.alias.present)
+ self.assertEqual(0, self.aliased.present)
+
+ def test_aliasing_regular(self):
+
+ def define_flags():
+ self.flags = flags.FlagValues()
+ self.define_string('aliased', '', 'help')
+ self.define_alias('alias', 'aliased')
+
+ define_flags()
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+
+ self.flags(['./program', '--alias=1'])
+ self.assertEqual('1', self.alias.value)
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+ self.assertEqual(1, self.alias.present)
+ self.assertEqual('--alias=1', self.alias.serialize())
+ self.assertEqual(1, self.aliased.present)
+
+ define_flags()
+ self.flags(['./program', '--aliased=2'])
+ self.assertEqual('2', self.alias.value)
+ self.assert_alias_mirrors_aliased(self.alias, self.aliased)
+ self.assertEqual(0, self.alias.present)
+ self.assertEqual('--alias=2', self.alias.serialize())
+ self.assertEqual(1, self.aliased.present)
+
+ def test_defining_alias_doesnt_affect_aliased_state_regular(self):
+ self.define_string('aliased', 'default', 'help')
+ self.define_alias('alias', 'aliased')
+
+ self.assertEqual(0, self.aliased.present)
+ self.assertEqual(0, self.alias.present)
+
+ def test_defining_alias_doesnt_affect_aliased_state_multi(self):
+ self.define_multi_integer('aliased', [0], 'help')
+ self.define_alias('alias', 'aliased')
+
+ self.assertEqual([0], self.aliased.value)
+ self.assertEqual([0], self.aliased.default)
+ self.assertEqual(0, self.aliased.present)
+
+ self.assertEqual([0], self.aliased.value)
+ self.assertEqual([0], self.aliased.default)
+ self.assertEqual(0, self.alias.present)
+
+
+class FlagsUnitTest(absltest.TestCase):
+ """Flags Unit Test."""
+
+ maxDiff = None
+
+ def test_flags(self):
+ """Test normal usage with no (expected) errors."""
+ # Define flags
+ number_test_framework_flags = len(FLAGS)
+ repeat_help = 'how many times to repeat (0-5)'
+ flags.DEFINE_integer(
+ 'repeat', 4, repeat_help, lower_bound=0, short_name='r')
+ flags.DEFINE_string('name', 'Bob', 'namehelp')
+ flags.DEFINE_boolean('debug', 0, 'debughelp')
+ flags.DEFINE_boolean('q', 1, 'quiet mode')
+ flags.DEFINE_boolean('quack', 0, "superstring of 'q'")
+ flags.DEFINE_boolean('noexec', 1, 'boolean flag with no as prefix')
+ flags.DEFINE_float('float', 3.14, 'using floats')
+ flags.DEFINE_integer('octal', '0o666', 'using octals')
+ flags.DEFINE_integer('decimal', '666', 'using decimals')
+ flags.DEFINE_integer('hexadecimal', '0x666', 'using hexadecimals')
+ flags.DEFINE_integer('x', 3, 'how eXtreme to be')
+ flags.DEFINE_integer('l', 0x7fffffff00000000, 'how long to be')
+ flags.DEFINE_list('args', 'v=1,"vmodule=a=0,b=2"', 'a list of arguments')
+ flags.DEFINE_list('letters', 'a,b,c', 'a list of letters')
+ flags.DEFINE_list('numbers', [1, 2, 3], 'a list of numbers')
+ flags.DEFINE_enum('kwery', None, ['who', 'what', 'Why', 'where', 'when'],
+ '?')
+ flags.DEFINE_enum(
+ 'sense', None, ['Case', 'case', 'CASE'], '?', case_sensitive=True)
+ flags.DEFINE_enum(
+ 'cases',
+ None, ['UPPER', 'lower', 'Initial', 'Ot_HeR'],
+ '?',
+ case_sensitive=False)
+ flags.DEFINE_enum(
+ 'funny',
+ None, ['Joke', 'ha', 'ha', 'ha', 'ha'],
+ '?',
+ case_sensitive=True)
+ flags.DEFINE_enum(
+ 'blah',
+ None, ['bla', 'Blah', 'BLAH', 'blah'],
+ '?',
+ case_sensitive=False)
+ flags.DEFINE_string(
+ 'only_once', None, 'test only sets this once', allow_overwrite=False)
+ flags.DEFINE_string(
+ 'universe',
+ None,
+ 'test tries to set this three times',
+ allow_overwrite=False)
+
+ # Specify number of flags defined above. The short_name defined
+ # for 'repeat' counts as an extra flag.
+ number_defined_flags = 22 + 1
+ self.assertLen(FLAGS, number_defined_flags + number_test_framework_flags)
+
+ self.assertEqual(FLAGS.repeat, 4)
+ self.assertEqual(FLAGS.name, 'Bob')
+ self.assertEqual(FLAGS.debug, 0)
+ self.assertEqual(FLAGS.q, 1)
+ self.assertEqual(FLAGS.octal, 0o666)
+ self.assertEqual(FLAGS.decimal, 666)
+ self.assertEqual(FLAGS.hexadecimal, 0x666)
+ self.assertEqual(FLAGS.x, 3)
+ self.assertEqual(FLAGS.l, 0x7fffffff00000000)
+ self.assertEqual(FLAGS.args, ['v=1', 'vmodule=a=0,b=2'])
+ self.assertEqual(FLAGS.letters, ['a', 'b', 'c'])
+ self.assertEqual(FLAGS.numbers, [1, 2, 3])
+ self.assertIsNone(FLAGS.kwery)
+ self.assertIsNone(FLAGS.sense)
+ self.assertIsNone(FLAGS.cases)
+ self.assertIsNone(FLAGS.funny)
+ self.assertIsNone(FLAGS.blah)
+
+ flag_values = FLAGS.flag_values_dict()
+ self.assertEqual(flag_values['repeat'], 4)
+ self.assertEqual(flag_values['name'], 'Bob')
+ self.assertEqual(flag_values['debug'], 0)
+ self.assertEqual(flag_values['r'], 4) # Short for repeat.
+ self.assertEqual(flag_values['q'], 1)
+ self.assertEqual(flag_values['quack'], 0)
+ self.assertEqual(flag_values['x'], 3)
+ self.assertEqual(flag_values['l'], 0x7fffffff00000000)
+ self.assertEqual(flag_values['args'], ['v=1', 'vmodule=a=0,b=2'])
+ self.assertEqual(flag_values['letters'], ['a', 'b', 'c'])
+ self.assertEqual(flag_values['numbers'], [1, 2, 3])
+ self.assertIsNone(flag_values['kwery'])
+ self.assertIsNone(flag_values['sense'])
+ self.assertIsNone(flag_values['cases'])
+ self.assertIsNone(flag_values['funny'])
+ self.assertIsNone(flag_values['blah'])
+
+ # Verify string form of defaults
+ self.assertEqual(FLAGS['repeat'].default_as_str, "'4'")
+ self.assertEqual(FLAGS['name'].default_as_str, "'Bob'")
+ self.assertEqual(FLAGS['debug'].default_as_str, "'false'")
+ self.assertEqual(FLAGS['q'].default_as_str, "'true'")
+ self.assertEqual(FLAGS['quack'].default_as_str, "'false'")
+ self.assertEqual(FLAGS['noexec'].default_as_str, "'true'")
+ self.assertEqual(FLAGS['x'].default_as_str, "'3'")
+ self.assertEqual(FLAGS['l'].default_as_str, "'9223372032559808512'")
+ self.assertEqual(FLAGS['args'].default_as_str, '\'v=1,"vmodule=a=0,b=2"\'')
+ self.assertEqual(FLAGS['letters'].default_as_str, "'a,b,c'")
+ self.assertEqual(FLAGS['numbers'].default_as_str, "'1,2,3'")
+
+ # Verify that the iterator for flags yields all the keys
+ keys = list(FLAGS)
+ keys.sort()
+ reg_flags = list(FLAGS._flags())
+ reg_flags.sort()
+ self.assertEqual(keys, reg_flags)
+
+ # Parse flags
+ # .. empty command line
+ argv = ('./program',)
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+
+ # .. non-empty command line
+ argv = ('./program', '--debug', '--name=Bob', '-q', '--x=8')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(FLAGS['debug'].present, 1)
+ FLAGS['debug'].present = 0 # Reset
+ self.assertEqual(FLAGS['name'].present, 1)
+ FLAGS['name'].present = 0 # Reset
+ self.assertEqual(FLAGS['q'].present, 1)
+ FLAGS['q'].present = 0 # Reset
+ self.assertEqual(FLAGS['x'].present, 1)
+ FLAGS['x'].present = 0 # Reset
+
+ # Flags list.
+ self.assertLen(FLAGS, number_defined_flags + number_test_framework_flags)
+ self.assertIn('name', FLAGS)
+ self.assertIn('debug', FLAGS)
+ self.assertIn('repeat', FLAGS)
+ self.assertIn('r', FLAGS)
+ self.assertIn('q', FLAGS)
+ self.assertIn('quack', FLAGS)
+ self.assertIn('x', FLAGS)
+ self.assertIn('l', FLAGS)
+ self.assertIn('args', FLAGS)
+ self.assertIn('letters', FLAGS)
+ self.assertIn('numbers', FLAGS)
+
+ # __contains__
+ self.assertIn('name', FLAGS)
+ self.assertNotIn('name2', FLAGS)
+
+ # try deleting a flag
+ del FLAGS.r
+ self.assertLen(FLAGS,
+ number_defined_flags - 1 + number_test_framework_flags)
+ self.assertNotIn('r', FLAGS)
+
+ # .. command line with extra stuff
+ argv = ('./program', '--debug', '--name=Bob', 'extra')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 2, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
+ self.assertEqual(FLAGS['debug'].present, 1)
+ FLAGS['debug'].present = 0 # Reset
+ self.assertEqual(FLAGS['name'].present, 1)
+ FLAGS['name'].present = 0 # Reset
+
+ # Test reset
+ argv = ('./program', '--debug')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(FLAGS['debug'].present, 1)
+ self.assertTrue(FLAGS['debug'].value)
+ FLAGS.unparse_flags()
+ self.assertEqual(FLAGS['debug'].present, 0)
+ self.assertFalse(FLAGS['debug'].value)
+
+ # Test that reset restores default value when default value is None.
+ argv = ('./program', '--kwery=who')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(FLAGS['kwery'].present, 1)
+ self.assertEqual(FLAGS['kwery'].value, 'who')
+ FLAGS.unparse_flags()
+ argv = ('./program', '--kwery=Why')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(FLAGS['kwery'].present, 1)
+ self.assertEqual(FLAGS['kwery'].value, 'Why')
+ FLAGS.unparse_flags()
+ self.assertEqual(FLAGS['kwery'].present, 0)
+ self.assertIsNone(FLAGS['kwery'].value)
+
+ # Test case sensitive enum.
+ argv = ('./program', '--sense=CASE')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(FLAGS['sense'].present, 1)
+ self.assertEqual(FLAGS['sense'].value, 'CASE')
+ FLAGS.unparse_flags()
+ argv = ('./program', '--sense=Case')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(FLAGS['sense'].present, 1)
+ self.assertEqual(FLAGS['sense'].value, 'Case')
+ FLAGS.unparse_flags()
+
+ # Test case insensitive enum.
+ argv = ('./program', '--cases=upper')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(FLAGS['cases'].present, 1)
+ self.assertEqual(FLAGS['cases'].value, 'UPPER')
+ FLAGS.unparse_flags()
+
+ # Test case sensitive enum with duplicates.
+ argv = ('./program', '--funny=ha')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(FLAGS['funny'].present, 1)
+ self.assertEqual(FLAGS['funny'].value, 'ha')
+ FLAGS.unparse_flags()
+
+ # Test case insensitive enum with duplicates.
+ argv = ('./program', '--blah=bLah')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(FLAGS['blah'].present, 1)
+ self.assertEqual(FLAGS['blah'].value, 'Blah')
+ FLAGS.unparse_flags()
+ argv = ('./program', '--blah=BLAH')
+ argv = FLAGS(argv)
+ self.assertLen(argv, 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(FLAGS['blah'].present, 1)
+ self.assertEqual(FLAGS['blah'].value, 'Blah')
+ FLAGS.unparse_flags()
+
+ # Test integer argument passing
+ argv = ('./program', '--x', '0x12345')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.x, 0x12345)
+ self.assertEqual(type(FLAGS.x), int)
+
+ argv = ('./program', '--x', '0x1234567890ABCDEF1234567890ABCDEF')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.x, 0x1234567890ABCDEF1234567890ABCDEF)
+ self.assertIsInstance(FLAGS.x, int)
+
+ argv = ('./program', '--x', '0o12345')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.x, 0o12345)
+ self.assertEqual(type(FLAGS.x), int)
+
+ # Treat 0-prefixed parameters as base-10, not base-8
+ argv = ('./program', '--x', '012345')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.x, 12345)
+ self.assertEqual(type(FLAGS.x), int)
+
+ argv = ('./program', '--x', '0123459')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.x, 123459)
+ self.assertEqual(type(FLAGS.x), int)
+
+ argv = ('./program', '--x', '0x123efg')
+ with self.assertRaises(flags.IllegalFlagValueError):
+ argv = FLAGS(argv)
+
+ # Test boolean argument parsing
+ flags.DEFINE_boolean('test0', None, 'test boolean parsing')
+ argv = ('./program', '--notest0')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.test0, 0)
+
+ flags.DEFINE_boolean('test1', None, 'test boolean parsing')
+ argv = ('./program', '--test1')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.test1, 1)
+
+ FLAGS.test0 = None
+ argv = ('./program', '--test0=false')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.test0, 0)
+
+ FLAGS.test1 = None
+ argv = ('./program', '--test1=true')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.test1, 1)
+
+ FLAGS.test0 = None
+ argv = ('./program', '--test0=0')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.test0, 0)
+
+ FLAGS.test1 = None
+ argv = ('./program', '--test1=1')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.test1, 1)
+
+ # Test booleans that already have 'no' as a prefix
+ FLAGS.noexec = None
+ argv = ('./program', '--nonoexec', '--name', 'Bob')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.noexec, 0)
+
+ FLAGS.noexec = None
+ argv = ('./program', '--name', 'Bob', '--noexec')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.noexec, 1)
+
+ # Test unassigned booleans
+ flags.DEFINE_boolean('testnone', None, 'test boolean parsing')
+ argv = ('./program',)
+ argv = FLAGS(argv)
+ self.assertIsNone(FLAGS.testnone)
+
+ # Test get with default
+ flags.DEFINE_boolean('testget1', None, 'test parsing with defaults')
+ flags.DEFINE_boolean('testget2', None, 'test parsing with defaults')
+ flags.DEFINE_boolean('testget3', None, 'test parsing with defaults')
+ flags.DEFINE_integer('testget4', None, 'test parsing with defaults')
+ argv = ('./program', '--testget1', '--notestget2')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.get_flag_value('testget1', 'foo'), 1)
+ self.assertEqual(FLAGS.get_flag_value('testget2', 'foo'), 0)
+ self.assertEqual(FLAGS.get_flag_value('testget3', 'foo'), 'foo')
+ self.assertEqual(FLAGS.get_flag_value('testget4', 'foo'), 'foo')
+
+ # test list code
+ lists = [['hello', 'moo', 'boo', '1'], []]
+
+ flags.DEFINE_list('testcomma_list', '', 'test comma list parsing')
+ flags.DEFINE_spaceseplist('testspace_list', '', 'tests space list parsing')
+ flags.DEFINE_spaceseplist(
+ 'testspace_or_comma_list',
+ '',
+ 'tests space list parsing with comma compatibility',
+ comma_compat=True)
+
+ for name, sep in (('testcomma_list', ','), ('testspace_list',
+ ' '), ('testspace_list', '\n'),
+ ('testspace_or_comma_list',
+ ' '), ('testspace_or_comma_list',
+ '\n'), ('testspace_or_comma_list', ',')):
+ for lst in lists:
+ argv = ('./program', '--%s=%s' % (name, sep.join(lst)))
+ argv = FLAGS(argv)
+ self.assertEqual(getattr(FLAGS, name), lst)
+
+ # Test help text
+ flags_help = str(FLAGS)
+ self.assertNotEqual(
+ flags_help.find('repeat'), -1, 'cannot find flag in help')
+ self.assertNotEqual(
+ flags_help.find(repeat_help), -1, 'cannot find help string in help')
+
+ # Test flag specified twice
+ argv = ('./program', '--repeat=4', '--repeat=2', '--debug', '--nodebug')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.get_flag_value('repeat', None), 2)
+ self.assertEqual(FLAGS.get_flag_value('debug', None), 0)
+
+ # Test MultiFlag with single default value
+ flags.DEFINE_multi_string(
+ 's_str',
+ 'sing1',
+ 'string option that can occur multiple times',
+ short_name='s')
+ self.assertEqual(FLAGS.get_flag_value('s_str', None), ['sing1'])
+
+ # Test MultiFlag with list of default values
+ multi_string_defs = ['def1', 'def2']
+ flags.DEFINE_multi_string(
+ 'm_str',
+ multi_string_defs,
+ 'string option that can occur multiple times',
+ short_name='m')
+ self.assertEqual(FLAGS.get_flag_value('m_str', None), multi_string_defs)
+
+ # Test flag specified multiple times with a MultiFlag
+ argv = ('./program', '--m_str=str1', '-m', 'str2')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.get_flag_value('m_str', None), ['str1', 'str2'])
+
+ # A flag with allow_overwrite set to False should behave normally when it
+ # is only specified once
+ argv = ('./program', '--only_once=singlevalue')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.get_flag_value('only_once', None), 'singlevalue')
+
+ # A flag with allow_overwrite set to False should complain when it is
+ # specified more than once
+ argv = ('./program', '--universe=ptolemaic', '--universe=copernicean',
+ '--universe=euclidean')
+ self.assertRaisesWithLiteralMatch(
+ flags.IllegalFlagValueError,
+ 'flag --universe=copernicean: already defined as ptolemaic', FLAGS,
+ argv)
+
+ # Test single-letter flags; should support both single and double dash
+ argv = ('./program', '-q')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.get_flag_value('q', None), 1)
+
+ argv = ('./program', '--q', '--x', '9', '--noquack')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.get_flag_value('q', None), 1)
+ self.assertEqual(FLAGS.get_flag_value('x', None), 9)
+ self.assertEqual(FLAGS.get_flag_value('quack', None), 0)
+
+ argv = ('./program', '--noq', '--x=10', '--quack')
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.get_flag_value('q', None), 0)
+ self.assertEqual(FLAGS.get_flag_value('x', None), 10)
+ self.assertEqual(FLAGS.get_flag_value('quack', None), 1)
+
+ ####################################
+ # Test flag serialization code:
+
+ old_testcomma_list = FLAGS.testcomma_list
+ old_testspace_list = FLAGS.testspace_list
+ old_testspace_or_comma_list = FLAGS.testspace_or_comma_list
+
+ argv = ('./program', FLAGS['test0'].serialize(), FLAGS['test1'].serialize(),
+ FLAGS['s_str'].serialize())
+
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS['test0'].serialize(), '--notest0')
+ self.assertEqual(FLAGS['test1'].serialize(), '--test1')
+ self.assertEqual(FLAGS['s_str'].serialize(), '--s_str=sing1')
+
+ self.assertEqual(FLAGS['testnone'].serialize(), '')
+
+ testcomma_list1 = ['aa', 'bb']
+ testspace_list1 = ['aa', 'bb', 'cc']
+ testspace_or_comma_list1 = ['aa', 'bb', 'cc', 'dd']
+ FLAGS.testcomma_list = list(testcomma_list1)
+ FLAGS.testspace_list = list(testspace_list1)
+ FLAGS.testspace_or_comma_list = list(testspace_or_comma_list1)
+ argv = ('./program', FLAGS['testcomma_list'].serialize(),
+ FLAGS['testspace_list'].serialize(),
+ FLAGS['testspace_or_comma_list'].serialize())
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.testcomma_list, testcomma_list1)
+ self.assertEqual(FLAGS.testspace_list, testspace_list1)
+ self.assertEqual(FLAGS.testspace_or_comma_list, testspace_or_comma_list1)
+
+ testcomma_list1 = ['aa some spaces', 'bb']
+ testspace_list1 = ['aa', 'bb,some,commas,', 'cc']
+ testspace_or_comma_list1 = ['aa', 'bb,some,commas,', 'cc']
+ FLAGS.testcomma_list = list(testcomma_list1)
+ FLAGS.testspace_list = list(testspace_list1)
+ FLAGS.testspace_or_comma_list = list(testspace_or_comma_list1)
+ argv = ('./program', FLAGS['testcomma_list'].serialize(),
+ FLAGS['testspace_list'].serialize(),
+ FLAGS['testspace_or_comma_list'].serialize())
+ argv = FLAGS(argv)
+ self.assertEqual(FLAGS.testcomma_list, testcomma_list1)
+ self.assertEqual(FLAGS.testspace_list, testspace_list1)
+ # We don't expect idempotency when commas are placed in an item value and
+ # comma_compat is enabled.
+ self.assertEqual(FLAGS.testspace_or_comma_list,
+ ['aa', 'bb', 'some', 'commas', 'cc'])
+
+ FLAGS.testcomma_list = old_testcomma_list
+ FLAGS.testspace_list = old_testspace_list
+ FLAGS.testspace_or_comma_list = old_testspace_or_comma_list
+
+ ####################################
+ # Test flag-update:
+
+ def args_list():
+ # Exclude flags that have different default values based on the
+ # environment.
+ flags_to_exclude = {'log_dir', 'test_srcdir', 'test_tmpdir'}
+ flagnames = set(FLAGS) - flags_to_exclude
+
+ nonbool_flags = []
+ truebool_flags = []
+ falsebool_flags = []
+ for name in flagnames:
+ flag_value = FLAGS.get_flag_value(name, None)
+ if not isinstance(FLAGS[name], flags.BooleanFlag):
+ nonbool_flags.append('--%s %s' % (name, flag_value))
+ elif flag_value:
+ truebool_flags.append('--%s' % name)
+ else:
+ falsebool_flags.append('--no%s' % name)
+ all_flags = nonbool_flags + truebool_flags + falsebool_flags
+ all_flags.sort()
+ return all_flags
+
+ argv = ('./program', '--repeat=3', '--name=giants', '--nodebug')
+
+ FLAGS(argv)
+ self.assertEqual(FLAGS.get_flag_value('repeat', None), 3)
+ self.assertEqual(FLAGS.get_flag_value('name', None), 'giants')
+ self.assertEqual(FLAGS.get_flag_value('debug', None), 0)
+ self.assertListEqual([
+ '--alsologtostderr',
+ "--args ['v=1', 'vmodule=a=0,b=2']",
+ '--blah None',
+ '--cases None',
+ '--decimal 666',
+ '--float 3.14',
+ '--funny None',
+ '--hexadecimal 1638',
+ '--kwery None',
+ '--l 9223372032559808512',
+ "--letters ['a', 'b', 'c']",
+ '--logger_levels {}',
+ "--m ['str1', 'str2']",
+ "--m_str ['str1', 'str2']",
+ '--name giants',
+ '--no?',
+ '--nodebug',
+ '--noexec',
+ '--nohelp',
+ '--nohelpfull',
+ '--nohelpshort',
+ '--nohelpxml',
+ '--nologtostderr',
+ '--noonly_check_args',
+ '--nopdb_post_mortem',
+ '--noq',
+ '--norun_with_pdb',
+ '--norun_with_profiling',
+ '--notest0',
+ '--notestget2',
+ '--notestget3',
+ '--notestnone',
+ '--numbers [1, 2, 3]',
+ '--octal 438',
+ '--only_once singlevalue',
+ '--pdb False',
+ '--profile_file None',
+ '--quack',
+ '--repeat 3',
+ "--s ['sing1']",
+ "--s_str ['sing1']",
+ '--sense None',
+ '--showprefixforinfo',
+ '--stderrthreshold fatal',
+ '--test1',
+ '--test_random_seed 301',
+ '--test_randomize_ordering_seed ',
+ '--testcomma_list []',
+ '--testget1',
+ '--testget4 None',
+ '--testspace_list []',
+ '--testspace_or_comma_list []',
+ '--tmod_baz_x',
+ '--universe ptolemaic',
+ '--use_cprofile_for_profiling',
+ '--v -1',
+ '--verbosity -1',
+ '--x 10',
+ '--xml_output_file ',
+ ], args_list())
+
+ argv = ('./program', '--debug', '--m_str=upd1', '-s', 'upd2')
+ FLAGS(argv)
+ self.assertEqual(FLAGS.get_flag_value('repeat', None), 3)
+ self.assertEqual(FLAGS.get_flag_value('name', None), 'giants')
+ self.assertEqual(FLAGS.get_flag_value('debug', None), 1)
+
+ # items appended to existing non-default value lists for --m/--m_str
+ # new value overwrites default value (not appended to it) for --s/--s_str
+ self.assertListEqual([
+ '--alsologtostderr',
+ "--args ['v=1', 'vmodule=a=0,b=2']",
+ '--blah None',
+ '--cases None',
+ '--debug',
+ '--decimal 666',
+ '--float 3.14',
+ '--funny None',
+ '--hexadecimal 1638',
+ '--kwery None',
+ '--l 9223372032559808512',
+ "--letters ['a', 'b', 'c']",
+ '--logger_levels {}',
+ "--m ['str1', 'str2', 'upd1']",
+ "--m_str ['str1', 'str2', 'upd1']",
+ '--name giants',
+ '--no?',
+ '--noexec',
+ '--nohelp',
+ '--nohelpfull',
+ '--nohelpshort',
+ '--nohelpxml',
+ '--nologtostderr',
+ '--noonly_check_args',
+ '--nopdb_post_mortem',
+ '--noq',
+ '--norun_with_pdb',
+ '--norun_with_profiling',
+ '--notest0',
+ '--notestget2',
+ '--notestget3',
+ '--notestnone',
+ '--numbers [1, 2, 3]',
+ '--octal 438',
+ '--only_once singlevalue',
+ '--pdb False',
+ '--profile_file None',
+ '--quack',
+ '--repeat 3',
+ "--s ['sing1', 'upd2']",
+ "--s_str ['sing1', 'upd2']",
+ '--sense None',
+ '--showprefixforinfo',
+ '--stderrthreshold fatal',
+ '--test1',
+ '--test_random_seed 301',
+ '--test_randomize_ordering_seed ',
+ '--testcomma_list []',
+ '--testget1',
+ '--testget4 None',
+ '--testspace_list []',
+ '--testspace_or_comma_list []',
+ '--tmod_baz_x',
+ '--universe ptolemaic',
+ '--use_cprofile_for_profiling',
+ '--v -1',
+ '--verbosity -1',
+ '--x 10',
+ '--xml_output_file ',
+ ], args_list())
+
+ ####################################
+ # Test all kind of error conditions.
+
+ # Argument not in enum exception
+ argv = ('./program', '--kwery=WHEN')
+ self.assertRaises(flags.IllegalFlagValueError, FLAGS, argv)
+ argv = ('./program', '--kwery=why')
+ self.assertRaises(flags.IllegalFlagValueError, FLAGS, argv)
+
+ # Duplicate flag detection
+ with self.assertRaises(flags.DuplicateFlagError):
+ flags.DEFINE_boolean('run', 0, 'runhelp', short_name='q')
+
+ # Duplicate short flag detection
+ with self.assertRaisesRegex(
+ flags.DuplicateFlagError,
+ r"The flag 'z' is defined twice\. .*First from.*, Second from"):
+ flags.DEFINE_boolean('zoom1', 0, 'runhelp z1', short_name='z')
+ flags.DEFINE_boolean('zoom2', 0, 'runhelp z2', short_name='z')
+ raise AssertionError('duplicate short flag detection failed')
+
+ # Duplicate mixed flag detection
+ with self.assertRaisesRegex(
+ flags.DuplicateFlagError,
+ r"The flag 's' is defined twice\. .*First from.*, Second from"):
+ flags.DEFINE_boolean('short1', 0, 'runhelp s1', short_name='s')
+ flags.DEFINE_boolean('s', 0, 'runhelp s2')
+
+ # Check that duplicate flag detection detects definition sites
+ # correctly.
+ flagnames = ['repeated']
+ original_flags = flags.FlagValues()
+ flags.DEFINE_boolean(
+ flagnames[0],
+ False,
+ 'Flag about to be repeated.',
+ flag_values=original_flags)
+ duplicate_flags = module_foo.duplicate_flags(flagnames)
+ with self.assertRaisesRegex(flags.DuplicateFlagError,
+ 'flags_test.*module_foo'):
+ original_flags.append_flag_values(duplicate_flags)
+
+ # Make sure allow_override works
+ try:
+ flags.DEFINE_boolean(
+ 'dup1', 0, 'runhelp d11', short_name='u', allow_override=0)
+ flag = FLAGS._flags()['dup1']
+ self.assertEqual(flag.default, 0)
+
+ flags.DEFINE_boolean(
+ 'dup1', 1, 'runhelp d12', short_name='u', allow_override=1)
+ flag = FLAGS._flags()['dup1']
+ self.assertEqual(flag.default, 1)
+ except flags.DuplicateFlagError:
+ raise AssertionError('allow_override did not permit a flag duplication')
+
+ # Make sure allow_override works
+ try:
+ flags.DEFINE_boolean(
+ 'dup2', 0, 'runhelp d21', short_name='u', allow_override=1)
+ flag = FLAGS._flags()['dup2']
+ self.assertEqual(flag.default, 0)
+
+ flags.DEFINE_boolean(
+ 'dup2', 1, 'runhelp d22', short_name='u', allow_override=0)
+ flag = FLAGS._flags()['dup2']
+ self.assertEqual(flag.default, 1)
+ except flags.DuplicateFlagError:
+ raise AssertionError('allow_override did not permit a flag duplication')
+
+ # Make sure that re-importing a module does not cause a DuplicateFlagError
+ # to be raised.
+ try:
+ sys.modules.pop('absl.flags.tests.module_baz')
+ import absl.flags.tests.module_baz
+ del absl
+ except flags.DuplicateFlagError:
+ raise AssertionError('Module reimport caused flag duplication error')
+
+ # Make sure that when we override, the help string gets updated correctly
+ flags.DEFINE_boolean(
+ 'dup3', 0, 'runhelp d31', short_name='u', allow_override=1)
+ flags.DEFINE_boolean(
+ 'dup3', 1, 'runhelp d32', short_name='u', allow_override=1)
+ self.assertEqual(str(FLAGS).find('runhelp d31'), -1)
+ self.assertNotEqual(str(FLAGS).find('runhelp d32'), -1)
+
+ # Make sure append_flag_values works
+ new_flags = flags.FlagValues()
+ flags.DEFINE_boolean('new1', 0, 'runhelp n1', flag_values=new_flags)
+ flags.DEFINE_boolean('new2', 0, 'runhelp n2', flag_values=new_flags)
+ self.assertEqual(len(new_flags._flags()), 2)
+ old_len = len(FLAGS._flags())
+ FLAGS.append_flag_values(new_flags)
+ self.assertEqual(len(FLAGS._flags()) - old_len, 2)
+ self.assertEqual('new1' in FLAGS._flags(), True)
+ self.assertEqual('new2' in FLAGS._flags(), True)
+
+ # Then test that removing those flags works
+ FLAGS.remove_flag_values(new_flags)
+ self.assertEqual(len(FLAGS._flags()), old_len)
+ self.assertFalse('new1' in FLAGS._flags())
+ self.assertFalse('new2' in FLAGS._flags())
+
+ # Make sure append_flag_values works with flags with shortnames.
+ new_flags = flags.FlagValues()
+ flags.DEFINE_boolean('new3', 0, 'runhelp n3', flag_values=new_flags)
+ flags.DEFINE_boolean(
+ 'new4', 0, 'runhelp n4', flag_values=new_flags, short_name='n4')
+ self.assertEqual(len(new_flags._flags()), 3)
+ old_len = len(FLAGS._flags())
+ FLAGS.append_flag_values(new_flags)
+ self.assertEqual(len(FLAGS._flags()) - old_len, 3)
+ self.assertIn('new3', FLAGS._flags())
+ self.assertIn('new4', FLAGS._flags())
+ self.assertIn('n4', FLAGS._flags())
+ self.assertEqual(FLAGS._flags()['n4'], FLAGS._flags()['new4'])
+
+ # Then test removing them
+ FLAGS.remove_flag_values(new_flags)
+ self.assertEqual(len(FLAGS._flags()), old_len)
+ self.assertFalse('new3' in FLAGS._flags())
+ self.assertFalse('new4' in FLAGS._flags())
+ self.assertFalse('n4' in FLAGS._flags())
+
+ # Make sure append_flag_values fails on duplicates
+ flags.DEFINE_boolean('dup4', 0, 'runhelp d41')
+ new_flags = flags.FlagValues()
+ flags.DEFINE_boolean('dup4', 0, 'runhelp d42', flag_values=new_flags)
+ with self.assertRaises(flags.DuplicateFlagError):
+ FLAGS.append_flag_values(new_flags)
+
+ # Integer out of bounds
+ with self.assertRaises(flags.IllegalFlagValueError):
+ argv = ('./program', '--repeat=-4')
+ FLAGS(argv)
+
+ # Non-integer
+ with self.assertRaises(flags.IllegalFlagValueError):
+ argv = ('./program', '--repeat=2.5')
+ FLAGS(argv)
+
+ # Missing required argument
+ with self.assertRaises(flags.Error):
+ argv = ('./program', '--name')
+ FLAGS(argv)
+
+ # Non-boolean arguments for boolean
+ with self.assertRaises(flags.IllegalFlagValueError):
+ argv = ('./program', '--debug=goofup')
+ FLAGS(argv)
+
+ with self.assertRaises(flags.IllegalFlagValueError):
+ argv = ('./program', '--debug=42')
+ FLAGS(argv)
+
+ # Non-numeric argument for integer flag --repeat
+ with self.assertRaises(flags.IllegalFlagValueError):
+ argv = ('./program', '--repeat', 'Bob', 'extra')
+ FLAGS(argv)
+
+ # Aliases of existing flags
+ with self.assertRaises(flags.UnrecognizedFlagError):
+ flags.DEFINE_alias('alias_not_a_flag', 'not_a_flag')
+
+ # Programmtically modify alias and aliased flag
+ flags.DEFINE_alias('alias_octal', 'octal')
+ FLAGS.octal = 0o2222
+ self.assertEqual(0o2222, FLAGS.octal)
+ self.assertEqual(0o2222, FLAGS.alias_octal)
+ FLAGS.alias_octal = 0o4444
+ self.assertEqual(0o4444, FLAGS.octal)
+ self.assertEqual(0o4444, FLAGS.alias_octal)
+
+ # Setting alias preserves the default of the original
+ flags.DEFINE_alias('alias_name', 'name')
+ flags.DEFINE_alias('alias_debug', 'debug')
+ flags.DEFINE_alias('alias_decimal', 'decimal')
+ flags.DEFINE_alias('alias_float', 'float')
+ flags.DEFINE_alias('alias_letters', 'letters')
+ self.assertEqual(FLAGS['name'].default, FLAGS.alias_name)
+ self.assertEqual(FLAGS['debug'].default, FLAGS.alias_debug)
+ self.assertEqual(int(FLAGS['decimal'].default), FLAGS.alias_decimal)
+ self.assertEqual(float(FLAGS['float'].default), FLAGS.alias_float)
+ self.assertSameElements(FLAGS['letters'].default, FLAGS.alias_letters)
+
+ # Original flags set on command line
+ argv = ('./program', '--name=Martin', '--debug=True', '--decimal=777',
+ '--letters=x,y,z')
+ FLAGS(argv)
+ self.assertEqual('Martin', FLAGS.name)
+ self.assertEqual('Martin', FLAGS.alias_name)
+ self.assertTrue(FLAGS.debug)
+ self.assertTrue(FLAGS.alias_debug)
+ self.assertEqual(777, FLAGS.decimal)
+ self.assertEqual(777, FLAGS.alias_decimal)
+ self.assertSameElements(['x', 'y', 'z'], FLAGS.letters)
+ self.assertSameElements(['x', 'y', 'z'], FLAGS.alias_letters)
+
+ # Alias flags set on command line
+ argv = ('./program', '--alias_name=Auston', '--alias_debug=False',
+ '--alias_decimal=888', '--alias_letters=l,m,n')
+ FLAGS(argv)
+ self.assertEqual('Auston', FLAGS.name)
+ self.assertEqual('Auston', FLAGS.alias_name)
+ self.assertFalse(FLAGS.debug)
+ self.assertFalse(FLAGS.alias_debug)
+ self.assertEqual(888, FLAGS.decimal)
+ self.assertEqual(888, FLAGS.alias_decimal)
+ self.assertSameElements(['l', 'm', 'n'], FLAGS.letters)
+ self.assertSameElements(['l', 'm', 'n'], FLAGS.alias_letters)
+
+ # Make sure importing a module does not change flag value parsed
+ # from commandline.
+ flags.DEFINE_integer(
+ 'dup5', 1, 'runhelp d51', short_name='u5', allow_override=0)
+ self.assertEqual(FLAGS.dup5, 1)
+ self.assertEqual(FLAGS.dup5, 1)
+ argv = ('./program', '--dup5=3')
+ FLAGS(argv)
+ self.assertEqual(FLAGS.dup5, 3)
+ flags.DEFINE_integer(
+ 'dup5', 2, 'runhelp d52', short_name='u5', allow_override=1)
+ self.assertEqual(FLAGS.dup5, 3)
+
+ # Make sure importing a module does not change user defined flag value.
+ flags.DEFINE_integer(
+ 'dup6', 1, 'runhelp d61', short_name='u6', allow_override=0)
+ self.assertEqual(FLAGS.dup6, 1)
+ FLAGS.dup6 = 3
+ self.assertEqual(FLAGS.dup6, 3)
+ flags.DEFINE_integer(
+ 'dup6', 2, 'runhelp d62', short_name='u6', allow_override=1)
+ self.assertEqual(FLAGS.dup6, 3)
+
+ # Make sure importing a module does not change user defined flag value
+ # even if it is the 'default' value.
+ flags.DEFINE_integer(
+ 'dup7', 1, 'runhelp d71', short_name='u7', allow_override=0)
+ self.assertEqual(FLAGS.dup7, 1)
+ FLAGS.dup7 = 1
+ self.assertEqual(FLAGS.dup7, 1)
+ flags.DEFINE_integer(
+ 'dup7', 2, 'runhelp d72', short_name='u7', allow_override=1)
+ self.assertEqual(FLAGS.dup7, 1)
+
+ # Test module_help().
+ helpstr = FLAGS.module_help(module_baz)
+
+ expected_help = '\n' + module_baz.__name__ + ':' + """
+ --[no]tmod_baz_x: Boolean flag.
+ (default: 'true')"""
+
+ self.assertMultiLineEqual(expected_help, helpstr)
+
+ # Test main_module_help(). This must be part of test_flags because
+ # it depends on dup1/2/3/etc being introduced first.
+ helpstr = FLAGS.main_module_help()
+
+ expected_help = '\n' + sys.argv[0] + ':' + """
+ --[no]alias_debug: Alias for --debug.
+ (default: 'false')
+ --alias_decimal: Alias for --decimal.
+ (default: '666')
+ (an integer)
+ --alias_float: Alias for --float.
+ (default: '3.14')
+ (a number)
+ --alias_letters: Alias for --letters.
+ (default: 'a,b,c')
+ (a comma separated list)
+ --alias_name: Alias for --name.
+ (default: 'Bob')
+ --alias_octal: Alias for --octal.
+ (default: '438')
+ (an integer)
+ --args: a list of arguments
+ (default: 'v=1,"vmodule=a=0,b=2"')
+ (a comma separated list)
+ --blah: <bla|Blah|BLAH|blah>: ?
+ --cases: <UPPER|lower|Initial|Ot_HeR>: ?
+ --[no]debug: debughelp
+ (default: 'false')
+ --decimal: using decimals
+ (default: '666')
+ (an integer)
+ -u,--[no]dup1: runhelp d12
+ (default: 'true')
+ -u,--[no]dup2: runhelp d22
+ (default: 'true')
+ -u,--[no]dup3: runhelp d32
+ (default: 'true')
+ --[no]dup4: runhelp d41
+ (default: 'false')
+ -u5,--dup5: runhelp d51
+ (default: '1')
+ (an integer)
+ -u6,--dup6: runhelp d61
+ (default: '1')
+ (an integer)
+ -u7,--dup7: runhelp d71
+ (default: '1')
+ (an integer)
+ --float: using floats
+ (default: '3.14')
+ (a number)
+ --funny: <Joke|ha|ha|ha|ha>: ?
+ --hexadecimal: using hexadecimals
+ (default: '1638')
+ (an integer)
+ --kwery: <who|what|Why|where|when>: ?
+ --l: how long to be
+ (default: '9223372032559808512')
+ (an integer)
+ --letters: a list of letters
+ (default: 'a,b,c')
+ (a comma separated list)
+ -m,--m_str: string option that can occur multiple times;
+ repeat this option to specify a list of values
+ (default: "['def1', 'def2']")
+ --name: namehelp
+ (default: 'Bob')
+ --[no]noexec: boolean flag with no as prefix
+ (default: 'true')
+ --numbers: a list of numbers
+ (default: '1,2,3')
+ (a comma separated list)
+ --octal: using octals
+ (default: '438')
+ (an integer)
+ --only_once: test only sets this once
+ --[no]q: quiet mode
+ (default: 'true')
+ --[no]quack: superstring of 'q'
+ (default: 'false')
+ -r,--repeat: how many times to repeat (0-5)
+ (default: '4')
+ (a non-negative integer)
+ -s,--s_str: string option that can occur multiple times;
+ repeat this option to specify a list of values
+ (default: "['sing1']")
+ --sense: <Case|case|CASE>: ?
+ --[no]test0: test boolean parsing
+ --[no]test1: test boolean parsing
+ --testcomma_list: test comma list parsing
+ (default: '')
+ (a comma separated list)
+ --[no]testget1: test parsing with defaults
+ --[no]testget2: test parsing with defaults
+ --[no]testget3: test parsing with defaults
+ --testget4: test parsing with defaults
+ (an integer)
+ --[no]testnone: test boolean parsing
+ --testspace_list: tests space list parsing
+ (default: '')
+ (a whitespace separated list)
+ --testspace_or_comma_list: tests space list parsing with comma compatibility
+ (default: '')
+ (a whitespace or comma separated list)
+ --universe: test tries to set this three times
+ --x: how eXtreme to be
+ (default: '3')
+ (an integer)
+ -z,--[no]zoom1: runhelp z1
+ (default: 'false')"""
+
+ self.assertMultiLineEqual(expected_help, helpstr)
+
+ def test_string_flag_with_wrong_type(self):
+ fv = flags.FlagValues()
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flags.DEFINE_string('name', False, 'help', flag_values=fv)
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flags.DEFINE_string('name2', 0, 'help', flag_values=fv)
+
+ def test_integer_flag_with_wrong_type(self):
+ fv = flags.FlagValues()
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flags.DEFINE_integer('name', 1e2, 'help', flag_values=fv)
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flags.DEFINE_integer('name', [], 'help', flag_values=fv)
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flags.DEFINE_integer('name', False, 'help', flag_values=fv)
+
+ def test_float_flag_with_wrong_type(self):
+ fv = flags.FlagValues()
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flags.DEFINE_float('name', False, 'help', flag_values=fv)
+
+ def test_enum_flag_with_empty_values(self):
+ fv = flags.FlagValues()
+ with self.assertRaises(ValueError):
+ flags.DEFINE_enum('fruit', None, [], 'help', flag_values=fv)
+
+ def test_define_enum_class_flag(self):
+ fv = flags.FlagValues()
+ flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=fv)
+ fv.mark_as_parsed()
+
+ self.assertIsNone(fv.fruit)
+
+ def test_parse_enum_class_flag(self):
+ fv = flags.FlagValues()
+ flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=fv)
+
+ argv = ('./program', '--fruit=orange')
+ argv = fv(argv)
+ self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(fv['fruit'].present, 1)
+ self.assertEqual(fv['fruit'].value, Fruit.ORANGE)
+ fv.unparse_flags()
+ argv = ('./program', '--fruit=APPLE')
+ argv = fv(argv)
+ self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(fv['fruit'].present, 1)
+ self.assertEqual(fv['fruit'].value, Fruit.APPLE)
+ fv.unparse_flags()
+
+ def test_enum_class_flag_help_message(self):
+ fv = flags.FlagValues()
+ flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=fv)
+
+ helpstr = fv.main_module_help()
+ expected_help = '\n%s:\n --fruit: <apple|orange>: ?' % sys.argv[0]
+
+ self.assertEqual(helpstr, expected_help)
+
+ def test_enum_class_flag_with_wrong_default_value_type(self):
+ fv = flags.FlagValues()
+ with self.assertRaises(_exceptions.IllegalFlagValueError):
+ flags.DEFINE_enum_class('fruit', 1, Fruit, 'help', flag_values=fv)
+
+ def test_enum_class_flag_requires_enum_class(self):
+ fv = flags.FlagValues()
+ with self.assertRaises(TypeError):
+ flags.DEFINE_enum_class(
+ 'fruit', None, ['apple', 'orange'], 'help', flag_values=fv)
+
+ def test_enum_class_flag_requires_non_empty_enum_class(self):
+ fv = flags.FlagValues()
+ with self.assertRaises(ValueError):
+ flags.DEFINE_enum_class('empty', None, EmptyEnum, 'help', flag_values=fv)
+
+ def test_required_flag(self):
+ fv = flags.FlagValues()
+ fl = flags.DEFINE_integer(
+ name='int_flag',
+ default=None,
+ help='help',
+ required=True,
+ flag_values=fv)
+ # Since the flag is required, the FlagHolder should ensure value returned
+ # is not None.
+ self.assertTrue(fl._ensure_non_none_value)
+
+ def test_illegal_required_flag(self):
+ fv = flags.FlagValues()
+ with self.assertRaises(ValueError):
+ flags.DEFINE_integer(
+ name='int_flag',
+ default=3,
+ help='help',
+ required=True,
+ flag_values=fv)
+
+
+class MultiNumericalFlagsTest(absltest.TestCase):
+
+ def test_multi_numerical_flags(self):
+ """Test multi_int and multi_float flags."""
+ fv = flags.FlagValues()
+ int_defaults = [77, 88]
+ flags.DEFINE_multi_integer(
+ 'm_int',
+ int_defaults,
+ 'integer option that can occur multiple times',
+ short_name='mi',
+ flag_values=fv)
+ self.assertListEqual(fv['m_int'].default, int_defaults)
+ argv = ('./program', '--m_int=-99', '--mi=101')
+ fv(argv)
+ self.assertListEqual(fv.get_flag_value('m_int', None), [-99, 101])
+
+ float_defaults = [2.2, 3]
+ flags.DEFINE_multi_float(
+ 'm_float',
+ float_defaults,
+ 'float option that can occur multiple times',
+ short_name='mf',
+ flag_values=fv)
+ for (expected, actual) in zip(float_defaults,
+ fv.get_flag_value('m_float', None)):
+ self.assertAlmostEqual(expected, actual)
+ argv = ('./program', '--m_float=-17', '--mf=2.78e9')
+ fv(argv)
+ expected_floats = [-17.0, 2.78e9]
+ for (expected, actual) in zip(expected_floats,
+ fv.get_flag_value('m_float', None)):
+ self.assertAlmostEqual(expected, actual)
+
+ def test_multi_numerical_with_tuples(self):
+ """Verify multi_int/float accept tuples as default values."""
+ flags.DEFINE_multi_integer(
+ 'm_int_tuple', (77, 88),
+ 'integer option that can occur multiple times',
+ short_name='mi_tuple')
+ self.assertListEqual(FLAGS.get_flag_value('m_int_tuple', None), [77, 88])
+
+ dict_with_float_keys = {2.2: 'hello', 3: 'happy'}
+ float_defaults = dict_with_float_keys.keys()
+ flags.DEFINE_multi_float(
+ 'm_float_tuple',
+ float_defaults,
+ 'float option that can occur multiple times',
+ short_name='mf_tuple')
+ for (expected, actual) in zip(float_defaults,
+ FLAGS.get_flag_value('m_float_tuple', None)):
+ self.assertAlmostEqual(expected, actual)
+
+ def test_single_value_default(self):
+ """Test multi_int and multi_float flags with a single default value."""
+ int_default = 77
+ flags.DEFINE_multi_integer('m_int1', int_default,
+ 'integer option that can occur multiple times')
+ self.assertListEqual(FLAGS.get_flag_value('m_int1', None), [int_default])
+
+ float_default = 2.2
+ flags.DEFINE_multi_float('m_float1', float_default,
+ 'float option that can occur multiple times')
+ actual = FLAGS.get_flag_value('m_float1', None)
+ self.assertEqual(1, len(actual))
+ self.assertAlmostEqual(actual[0], float_default)
+
+ def test_bad_multi_numerical_flags(self):
+ """Test multi_int and multi_float flags with non-parseable values."""
+
+ # Test non-parseable defaults.
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError,
+ r"flag --m_int2=abc: invalid literal for int\(\) with base 10: 'abc'",
+ flags.DEFINE_multi_integer, 'm_int2', ['abc'], 'desc')
+
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError, r'flag --m_float2=abc: '
+ r'(invalid literal for float\(\)|could not convert string to float): '
+ r"'?abc'?", flags.DEFINE_multi_float, 'm_float2', ['abc'], 'desc')
+
+ # Test non-parseable command line values.
+ fv = flags.FlagValues()
+ flags.DEFINE_multi_integer(
+ 'm_int2',
+ '77',
+ 'integer option that can occur multiple times',
+ flag_values=fv)
+ argv = ('./program', '--m_int2=def')
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError,
+ r"flag --m_int2=def: invalid literal for int\(\) with base 10: 'def'",
+ fv, argv)
+
+ flags.DEFINE_multi_float(
+ 'm_float2',
+ 2.2,
+ 'float option that can occur multiple times',
+ flag_values=fv)
+ argv = ('./program', '--m_float2=def')
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError, r'flag --m_float2=def: '
+ r'(invalid literal for float\(\)|could not convert string to float): '
+ r"'?def'?", fv, argv)
+
+
+class MultiEnumFlagsTest(absltest.TestCase):
+
+ def test_multi_enum_flags(self):
+ """Test multi_enum flags."""
+ fv = flags.FlagValues()
+
+ enum_defaults = ['FOO', 'BAZ']
+ flags.DEFINE_multi_enum(
+ 'm_enum',
+ enum_defaults, ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
+ 'Enum option that can occur multiple times',
+ short_name='me',
+ flag_values=fv)
+ self.assertListEqual(fv['m_enum'].default, enum_defaults)
+ argv = ('./program', '--m_enum=WHOOSH', '--me=FOO')
+ fv(argv)
+ self.assertListEqual(fv.get_flag_value('m_enum', None), ['WHOOSH', 'FOO'])
+
+ def test_help_text(self):
+ """Test multi_enum flag's help text."""
+ fv = flags.FlagValues()
+
+ flags.DEFINE_multi_enum(
+ 'm_enum',
+ None, ['FOO', 'BAR'],
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
+ self.assertRegex(
+ fv['m_enum'].help,
+ r'<FOO\|BAR>: Enum option that can occur multiple times;\s+'
+ 'repeat this option to specify a list of values')
+
+ def test_single_value_default(self):
+ """Test multi_enum flags with a single default value."""
+ fv = flags.FlagValues()
+ enum_default = 'FOO'
+ flags.DEFINE_multi_enum(
+ 'm_enum1',
+ enum_default, ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
+ 'enum option that can occur multiple times',
+ flag_values=fv)
+ self.assertListEqual(fv['m_enum1'].default, [enum_default])
+
+ def test_case_sensitivity(self):
+ """Test case sensitivity of multi_enum flag."""
+ fv = flags.FlagValues()
+ # Test case insensitive enum.
+ flags.DEFINE_multi_enum(
+ 'm_enum2', ['whoosh'], ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
+ 'Enum option that can occur multiple times',
+ short_name='me2',
+ case_sensitive=False,
+ flag_values=fv)
+ argv = ('./program', '--m_enum2=bar', '--me2=fOo')
+ fv(argv)
+ self.assertListEqual(fv.get_flag_value('m_enum2', None), ['BAR', 'FOO'])
+
+ # Test case sensitive enum.
+ flags.DEFINE_multi_enum(
+ 'm_enum3', ['BAR'], ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
+ 'Enum option that can occur multiple times',
+ short_name='me3',
+ case_sensitive=True,
+ flag_values=fv)
+ argv = ('./program', '--m_enum3=bar', '--me3=fOo')
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError,
+ r'flag --m_enum3=invalid: value should be one of <FOO|BAR|BAZ|WHOOSH>',
+ fv, argv)
+
+ def test_bad_multi_enum_flags(self):
+ """Test multi_enum with invalid values."""
+
+ # Test defaults that are not in the permitted list of enums.
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError,
+ r'flag --m_enum=INVALID: value should be one of <FOO|BAR|BAZ>',
+ flags.DEFINE_multi_enum, 'm_enum', ['INVALID'], ['FOO', 'BAR', 'BAZ'],
+ 'desc')
+
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError,
+ r'flag --m_enum=1234: value should be one of <FOO|BAR|BAZ>',
+ flags.DEFINE_multi_enum, 'm_enum2', [1234], ['FOO', 'BAR', 'BAZ'],
+ 'desc')
+
+ # Test command-line values that are not in the permitted list of enums.
+ flags.DEFINE_multi_enum('m_enum4', 'FOO', ['FOO', 'BAR', 'BAZ'],
+ 'enum option that can occur multiple times')
+ argv = ('./program', '--m_enum4=INVALID')
+ self.assertRaisesRegex(
+ flags.IllegalFlagValueError,
+ r'flag --m_enum4=invalid: value should be one of <FOO|BAR|BAZ>', FLAGS,
+ argv)
+
+
+class MultiEnumClassFlagsTest(absltest.TestCase):
+
+ def test_define_results_in_registered_flag_with_none(self):
+ fv = flags.FlagValues()
+ enum_defaults = None
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
+ fv.mark_as_parsed()
+
+ self.assertIsNone(fv.fruit)
+
+ def test_help_text(self):
+ fv = flags.FlagValues()
+ enum_defaults = None
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
+
+ self.assertRegex(
+ fv['fruit'].help,
+ r'<apple\|orange>: Enum option that can occur multiple times;\s+'
+ 'repeat this option to specify a list of values')
+
+ def test_define_results_in_registered_flag_with_string(self):
+ fv = flags.FlagValues()
+ enum_defaults = 'apple'
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
+ fv.mark_as_parsed()
+
+ self.assertListEqual(fv.fruit, [Fruit.APPLE])
+
+ def test_define_results_in_registered_flag_with_enum(self):
+ fv = flags.FlagValues()
+ enum_defaults = Fruit.APPLE
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
+ fv.mark_as_parsed()
+
+ self.assertListEqual(fv.fruit, [Fruit.APPLE])
+
+ def test_define_results_in_registered_flag_with_string_list(self):
+ fv = flags.FlagValues()
+ enum_defaults = ['apple', 'APPLE']
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ CaseSensitiveFruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv,
+ case_sensitive=True)
+ fv.mark_as_parsed()
+
+ self.assertListEqual(fv.fruit,
+ [CaseSensitiveFruit.apple, CaseSensitiveFruit.APPLE])
+
+ def test_define_results_in_registered_flag_with_enum_list(self):
+ fv = flags.FlagValues()
+ enum_defaults = [Fruit.APPLE, Fruit.ORANGE]
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
+ fv.mark_as_parsed()
+
+ self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.ORANGE])
+
+ def test_from_command_line_returns_multiple(self):
+ fv = flags.FlagValues()
+ enum_defaults = [Fruit.APPLE]
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ enum_defaults,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv)
+ argv = ('./program', '--fruit=Apple', '--fruit=orange')
+ fv(argv)
+ self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.ORANGE])
+
+ def test_bad_multi_enum_class_flags_from_definition(self):
+ with self.assertRaisesRegex(
+ flags.IllegalFlagValueError,
+ 'flag --fruit=INVALID: value should be one of <apple|orange|APPLE>'):
+ flags.DEFINE_multi_enum_class('fruit', ['INVALID'], Fruit, 'desc')
+
+ def test_bad_multi_enum_class_flags_from_commandline(self):
+ fv = flags.FlagValues()
+ enum_defaults = [Fruit.APPLE]
+ flags.DEFINE_multi_enum_class(
+ 'fruit', enum_defaults, Fruit, 'desc', flag_values=fv)
+ argv = ('./program', '--fruit=INVALID')
+ with self.assertRaisesRegex(
+ flags.IllegalFlagValueError,
+ 'flag --fruit=INVALID: value should be one of <apple|orange|APPLE>'):
+ fv(argv)
+
+
+class UnicodeFlagsTest(absltest.TestCase):
+ """Testing proper unicode support for flags."""
+
+ def test_unicode_default_and_helpstring(self):
+ fv = flags.FlagValues()
+ flags.DEFINE_string(
+ 'unicode_str',
+ b'\xC3\x80\xC3\xBD'.decode('utf-8'),
+ b'help:\xC3\xAA'.decode('utf-8'),
+ flag_values=fv)
+ argv = ('./program',)
+ fv(argv) # should not raise any exceptions
+
+ argv = ('./program', '--unicode_str=foo')
+ fv(argv) # should not raise any exceptions
+
+ def test_unicode_in_list(self):
+ fv = flags.FlagValues()
+ flags.DEFINE_list(
+ 'unicode_list',
+ ['abc', b'\xC3\x80'.decode('utf-8'), b'\xC3\xBD'.decode('utf-8')],
+ b'help:\xC3\xAB'.decode('utf-8'),
+ flag_values=fv)
+ argv = ('./program',)
+ fv(argv) # should not raise any exceptions
+
+ argv = ('./program', '--unicode_list=hello,there')
+ fv(argv) # should not raise any exceptions
+
+ def test_xmloutput(self):
+ fv = flags.FlagValues()
+ flags.DEFINE_string(
+ 'unicode1',
+ b'\xC3\x80\xC3\xBD'.decode('utf-8'),
+ b'help:\xC3\xAC'.decode('utf-8'),
+ flag_values=fv)
+ flags.DEFINE_list(
+ 'unicode2',
+ ['abc', b'\xC3\x80'.decode('utf-8'), b'\xC3\xBD'.decode('utf-8')],
+ b'help:\xC3\xAD'.decode('utf-8'),
+ flag_values=fv)
+ flags.DEFINE_list(
+ 'non_unicode', ['abc', 'def', 'ghi'],
+ b'help:\xC3\xAD'.decode('utf-8'),
+ flag_values=fv)
+
+ outfile = io.StringIO()
+ fv.write_help_in_xml_format(outfile)
+ actual_output = outfile.getvalue()
+
+ # The xml output is large, so we just check parts of it.
+ self.assertIn(
+ b'<name>unicode1</name>\n'
+ b' <meaning>help:\xc3\xac</meaning>\n'
+ b' <default>\xc3\x80\xc3\xbd</default>\n'
+ b' <current>\xc3\x80\xc3\xbd</current>'.decode('utf-8'),
+ actual_output)
+ self.assertIn(
+ b'<name>unicode2</name>\n'
+ b' <meaning>help:\xc3\xad</meaning>\n'
+ b' <default>abc,\xc3\x80,\xc3\xbd</default>\n'
+ b" <current>['abc', '\xc3\x80', '\xc3\xbd']"
+ b'</current>'.decode('utf-8'), actual_output)
+ self.assertIn(
+ b'<name>non_unicode</name>\n'
+ b' <meaning>help:\xc3\xad</meaning>\n'
+ b' <default>abc,def,ghi</default>\n'
+ b" <current>['abc', 'def', 'ghi']"
+ b'</current>'.decode('utf-8'), actual_output)
+
+
+class LoadFromFlagFileTest(absltest.TestCase):
+ """Testing loading flags from a file and parsing them."""
+
+ def setUp(self):
+ self.flag_values = flags.FlagValues()
+ flags.DEFINE_string(
+ 'unittest_message1',
+ 'Foo!',
+ 'You Add Here.',
+ flag_values=self.flag_values)
+ flags.DEFINE_string(
+ 'unittest_message2',
+ 'Bar!',
+ 'Hello, Sailor!',
+ flag_values=self.flag_values)
+ flags.DEFINE_boolean(
+ 'unittest_boolflag',
+ 0,
+ 'Some Boolean thing',
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'unittest_number',
+ 12345,
+ 'Some integer',
+ lower_bound=0,
+ flag_values=self.flag_values)
+ flags.DEFINE_list(
+ 'UnitTestList', '1,2,3', 'Some list', flag_values=self.flag_values)
+ self.tmp_path = None
+ self.flag_values.mark_as_parsed()
+
+ def tearDown(self):
+ self._remove_test_files()
+
+ def _setup_test_files(self):
+ """Creates and sets up some dummy flagfile files with bogus flags."""
+
+ # Figure out where to create temporary files
+ self.assertFalse(self.tmp_path)
+ self.tmp_path = tempfile.mkdtemp()
+
+ tmp_flag_file_1 = open(self.tmp_path + '/UnitTestFile1.tst', 'w')
+ tmp_flag_file_2 = open(self.tmp_path + '/UnitTestFile2.tst', 'w')
+ tmp_flag_file_3 = open(self.tmp_path + '/UnitTestFile3.tst', 'w')
+ tmp_flag_file_4 = open(self.tmp_path + '/UnitTestFile4.tst', 'w')
+
+ # put some dummy flags in our test files
+ tmp_flag_file_1.write('#A Fake Comment\n')
+ tmp_flag_file_1.write('--unittest_message1=tempFile1!\n')
+ tmp_flag_file_1.write('\n')
+ tmp_flag_file_1.write('--unittest_number=54321\n')
+ tmp_flag_file_1.write('--nounittest_boolflag\n')
+ file_list = [tmp_flag_file_1.name]
+ # this one includes test file 1
+ tmp_flag_file_2.write('//A Different Fake Comment\n')
+ tmp_flag_file_2.write('--flagfile=%s\n' % tmp_flag_file_1.name)
+ tmp_flag_file_2.write('--unittest_message2=setFromTempFile2\n')
+ tmp_flag_file_2.write('\t\t\n')
+ tmp_flag_file_2.write('--unittest_number=6789a\n')
+ file_list.append(tmp_flag_file_2.name)
+ # this file points to itself
+ tmp_flag_file_3.write('--flagfile=%s\n' % tmp_flag_file_3.name)
+ tmp_flag_file_3.write('--unittest_message1=setFromTempFile3\n')
+ tmp_flag_file_3.write('#YAFC\n')
+ tmp_flag_file_3.write('--unittest_boolflag\n')
+ file_list.append(tmp_flag_file_3.name)
+ # this file is unreadable
+ tmp_flag_file_4.write('--flagfile=%s\n' % tmp_flag_file_3.name)
+ tmp_flag_file_4.write('--unittest_message1=setFromTempFile4\n')
+ tmp_flag_file_4.write('--unittest_message1=setFromTempFile4\n')
+ os.chmod(self.tmp_path + '/UnitTestFile4.tst', 0)
+ file_list.append(tmp_flag_file_4.name)
+
+ tmp_flag_file_1.close()
+ tmp_flag_file_2.close()
+ tmp_flag_file_3.close()
+ tmp_flag_file_4.close()
+
+ return file_list # these are just the file names
+
+ def _remove_test_files(self):
+ """Removes the files we just created."""
+ if self.tmp_path:
+ shutil.rmtree(self.tmp_path, ignore_errors=True)
+ self.tmp_path = None
+
+ def _read_flags_from_files(self, argv, force_gnu):
+ return argv[:1] + self.flag_values.read_flags_from_files(
+ argv[1:], force_gnu=force_gnu)
+
+ #### Flagfile Unit Tests ####
+ def test_method_flagfiles_1(self):
+ """Test trivial case with no flagfile based options."""
+ fake_cmd_line = 'fooScript --unittest_boolflag'
+ fake_argv = fake_cmd_line.split(' ')
+ self.flag_values(fake_argv)
+ self.assertEqual(self.flag_values.unittest_boolflag, 1)
+ self.assertListEqual(fake_argv,
+ self._read_flags_from_files(fake_argv, False))
+
+ def test_method_flagfiles_2(self):
+ """Tests parsing one file + arguments off simulated argv."""
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = 'fooScript --q --flagfile=%s' % tmp_files[0]
+ fake_argv = fake_cmd_line.split(' ')
+
+ # We should see the original cmd line with the file's contents spliced in.
+ # Flags from the file will appear in the order order they are specified
+ # in the file, in the same position as the flagfile argument.
+ expected_results = [
+ 'fooScript', '--q', '--unittest_message1=tempFile1!',
+ '--unittest_number=54321', '--nounittest_boolflag'
+ ]
+ test_results = self._read_flags_from_files(fake_argv, False)
+ self.assertListEqual(expected_results, test_results)
+
+ # end testTwo def
+
+ def test_method_flagfiles_3(self):
+ """Tests parsing nested files + arguments of simulated argv."""
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = ('fooScript --unittest_number=77 --flagfile=%s' %
+ tmp_files[1])
+ fake_argv = fake_cmd_line.split(' ')
+
+ expected_results = [
+ 'fooScript', '--unittest_number=77', '--unittest_message1=tempFile1!',
+ '--unittest_number=54321', '--nounittest_boolflag',
+ '--unittest_message2=setFromTempFile2', '--unittest_number=6789a'
+ ]
+ test_results = self._read_flags_from_files(fake_argv, False)
+ self.assertListEqual(expected_results, test_results)
+
+ # end testThree def
+
+ def test_method_flagfiles_3_spaces(self):
+ """Tests parsing nested files + arguments of simulated argv.
+
+ The arguments include a pair that is actually an arg with a value, so it
+ doesn't stop processing.
+ """
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = ('fooScript --unittest_number 77 --flagfile=%s' %
+ tmp_files[1])
+ fake_argv = fake_cmd_line.split(' ')
+
+ expected_results = [
+ 'fooScript', '--unittest_number', '77',
+ '--unittest_message1=tempFile1!', '--unittest_number=54321',
+ '--nounittest_boolflag', '--unittest_message2=setFromTempFile2',
+ '--unittest_number=6789a'
+ ]
+ test_results = self._read_flags_from_files(fake_argv, False)
+ self.assertListEqual(expected_results, test_results)
+
+ def test_method_flagfiles_3_spaces_boolean(self):
+ """Tests parsing nested files + arguments of simulated argv.
+
+ The arguments include a pair that looks like a --x y arg with value, but
+ since the flag is a boolean it's actually not.
+ """
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = ('fooScript --unittest_boolflag 77 --flagfile=%s' %
+ tmp_files[1])
+ fake_argv = fake_cmd_line.split(' ')
+
+ expected_results = [
+ 'fooScript', '--unittest_boolflag', '77',
+ '--flagfile=%s' % tmp_files[1]
+ ]
+ with _use_gnu_getopt(self.flag_values, False):
+ test_results = self._read_flags_from_files(fake_argv, False)
+ self.assertListEqual(expected_results, test_results)
+
+ def test_method_flagfiles_4(self):
+ """Tests parsing self-referential files + arguments of simulated argv.
+
+ This test should print a warning to stderr of some sort.
+ """
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = ('fooScript --flagfile=%s --nounittest_boolflag' %
+ tmp_files[2])
+ fake_argv = fake_cmd_line.split(' ')
+ expected_results = [
+ 'fooScript', '--unittest_message1=setFromTempFile3',
+ '--unittest_boolflag', '--nounittest_boolflag'
+ ]
+
+ test_results = self._read_flags_from_files(fake_argv, False)
+ self.assertListEqual(expected_results, test_results)
+
+ def test_method_flagfiles_5(self):
+ """Test that --flagfile parsing respects the '--' end-of-options marker."""
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = 'fooScript --some_flag -- --flagfile=%s' % tmp_files[0]
+ fake_argv = fake_cmd_line.split(' ')
+ expected_results = [
+ 'fooScript', '--some_flag', '--',
+ '--flagfile=%s' % tmp_files[0]
+ ]
+
+ test_results = self._read_flags_from_files(fake_argv, False)
+ self.assertListEqual(expected_results, test_results)
+
+ def test_method_flagfiles_6(self):
+ """Test that --flagfile parsing stops at non-options (non-GNU behavior)."""
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' %
+ tmp_files[0])
+ fake_argv = fake_cmd_line.split(' ')
+ expected_results = [
+ 'fooScript', '--some_flag', 'some_arg',
+ '--flagfile=%s' % tmp_files[0]
+ ]
+
+ with _use_gnu_getopt(self.flag_values, False):
+ test_results = self._read_flags_from_files(fake_argv, False)
+ self.assertListEqual(expected_results, test_results)
+
+ def test_method_flagfiles_7(self):
+ """Test that --flagfile parsing skips over a non-option (GNU behavior)."""
+ self.flag_values.set_gnu_getopt()
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' %
+ tmp_files[0])
+ fake_argv = fake_cmd_line.split(' ')
+ expected_results = [
+ 'fooScript', '--some_flag', 'some_arg',
+ '--unittest_message1=tempFile1!', '--unittest_number=54321',
+ '--nounittest_boolflag'
+ ]
+
+ test_results = self._read_flags_from_files(fake_argv, False)
+ self.assertListEqual(expected_results, test_results)
+
+ def test_method_flagfiles_8(self):
+ """Test that --flagfile parsing respects force_gnu=True."""
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' %
+ tmp_files[0])
+ fake_argv = fake_cmd_line.split(' ')
+ expected_results = [
+ 'fooScript', '--some_flag', 'some_arg',
+ '--unittest_message1=tempFile1!', '--unittest_number=54321',
+ '--nounittest_boolflag'
+ ]
+
+ test_results = self._read_flags_from_files(fake_argv, True)
+ self.assertListEqual(expected_results, test_results)
+
+ def test_method_flagfiles_repeated_non_circular(self):
+ """Tests that parsing repeated non-circular flagfiles works."""
+ tmp_files = self._setup_test_files()
+ # specify our temp files on the fake cmd line
+ fake_cmd_line = ('fooScript --flagfile=%s --flagfile=%s' %
+ (tmp_files[1], tmp_files[0]))
+ fake_argv = fake_cmd_line.split(' ')
+ expected_results = [
+ 'fooScript', '--unittest_message1=tempFile1!',
+ '--unittest_number=54321', '--nounittest_boolflag',
+ '--unittest_message2=setFromTempFile2', '--unittest_number=6789a',
+ '--unittest_message1=tempFile1!', '--unittest_number=54321',
+ '--nounittest_boolflag'
+ ]
+
+ test_results = self._read_flags_from_files(fake_argv, False)
+ self.assertListEqual(expected_results, test_results)
+
+ @unittest.skipIf(
+ os.name == 'nt',
+ 'There is no good way to create an unreadable file on Windows.')
+ def test_method_flagfiles_no_permissions(self):
+ """Test that --flagfile raises except on file that is unreadable."""
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' %
+ tmp_files[3])
+ fake_argv = fake_cmd_line.split(' ')
+ self.assertRaises(flags.CantOpenFlagFileError, self._read_flags_from_files,
+ fake_argv, True)
+
+ def test_method_flagfiles_not_found(self):
+ """Test that --flagfile raises except on file that does not exist."""
+ tmp_files = self._setup_test_files()
+ # specify our temp file on the fake cmd line
+ fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%sNOTEXIST' %
+ tmp_files[3])
+ fake_argv = fake_cmd_line.split(' ')
+ self.assertRaises(flags.CantOpenFlagFileError, self._read_flags_from_files,
+ fake_argv, True)
+
+ def test_flagfiles_user_path_expansion(self):
+ """Test that user directory referenced paths are correctly expanded.
+
+ Test paths like ~/foo. This test depends on whatever account's running
+ the unit test to have read/write access to their own home directory,
+ otherwise it'll FAIL.
+ """
+ fake_flagfile_item_style_1 = '--flagfile=~/foo.file'
+ fake_flagfile_item_style_2 = '-flagfile=~/foo.file'
+
+ expected_results = os.path.expanduser('~/foo.file')
+
+ test_results = self.flag_values._extract_filename(
+ fake_flagfile_item_style_1)
+ self.assertEqual(expected_results, test_results)
+
+ test_results = self.flag_values._extract_filename(
+ fake_flagfile_item_style_2)
+ self.assertEqual(expected_results, test_results)
+
+ def test_no_touchy_non_flags(self):
+ """Test that the flags parser does not mutilate arguments.
+
+ The arguments are not supposed to be flags
+ """
+ fake_argv = [
+ 'fooScript', '--unittest_boolflag', 'command', '--command_arg1',
+ '--UnitTestBoom', '--UnitTestB'
+ ]
+ with _use_gnu_getopt(self.flag_values, False):
+ argv = self.flag_values(fake_argv)
+ self.assertListEqual(argv, fake_argv[:1] + fake_argv[2:])
+
+ def test_parse_flags_after_args_if_using_gnugetopt(self):
+ """Test that flags given after arguments are parsed if using gnu_getopt."""
+ self.flag_values.set_gnu_getopt()
+ fake_argv = [
+ 'fooScript', '--unittest_boolflag', 'command', '--unittest_number=54321'
+ ]
+ argv = self.flag_values(fake_argv)
+ self.assertListEqual(argv, ['fooScript', 'command'])
+
+ def test_set_default(self):
+ """Test changing flag defaults."""
+ # Test that set_default changes both the default and the value,
+ # and that the value is changed when one is given as an option.
+ self.flag_values.set_default('unittest_message1', 'New value')
+ self.assertEqual(self.flag_values.unittest_message1, 'New value')
+ self.assertEqual(self.flag_values['unittest_message1'].default_as_str,
+ "'New value'")
+ self.flag_values(['dummyscript', '--unittest_message1=Newer value'])
+ self.assertEqual(self.flag_values.unittest_message1, 'Newer value')
+
+ # Test that setting the default to None works correctly.
+ self.flag_values.set_default('unittest_number', None)
+ self.assertEqual(self.flag_values.unittest_number, None)
+ self.assertEqual(self.flag_values['unittest_number'].default_as_str, None)
+ self.flag_values(['dummyscript', '--unittest_number=56'])
+ self.assertEqual(self.flag_values.unittest_number, 56)
+
+ # Test that setting the default to zero works correctly.
+ self.flag_values.set_default('unittest_number', 0)
+ self.assertEqual(self.flag_values['unittest_number'].default, 0)
+ self.assertEqual(self.flag_values.unittest_number, 56)
+ self.assertEqual(self.flag_values['unittest_number'].default_as_str, "'0'")
+ self.flag_values(['dummyscript', '--unittest_number=56'])
+ self.assertEqual(self.flag_values.unittest_number, 56)
+
+ # Test that setting the default to '' works correctly.
+ self.flag_values.set_default('unittest_message1', '')
+ self.assertEqual(self.flag_values['unittest_message1'].default, '')
+ self.assertEqual(self.flag_values.unittest_message1, 'Newer value')
+ self.assertEqual(self.flag_values['unittest_message1'].default_as_str, "''")
+ self.flag_values(['dummyscript', '--unittest_message1=fifty-six'])
+ self.assertEqual(self.flag_values.unittest_message1, 'fifty-six')
+
+ # Test that setting the default to false works correctly.
+ self.flag_values.set_default('unittest_boolflag', False)
+ self.assertEqual(self.flag_values.unittest_boolflag, False)
+ self.assertEqual(self.flag_values['unittest_boolflag'].default_as_str,
+ "'false'")
+ self.flag_values(['dummyscript', '--unittest_boolflag=true'])
+ self.assertEqual(self.flag_values.unittest_boolflag, True)
+
+ # Test that setting a list default works correctly.
+ self.flag_values.set_default('UnitTestList', '4,5,6')
+ self.assertListEqual(self.flag_values.UnitTestList, ['4', '5', '6'])
+ self.assertEqual(self.flag_values['UnitTestList'].default_as_str, "'4,5,6'")
+ self.flag_values(['dummyscript', '--UnitTestList=7,8,9'])
+ self.assertListEqual(self.flag_values.UnitTestList, ['7', '8', '9'])
+
+ # Test that setting invalid defaults raises exceptions
+ with self.assertRaises(flags.IllegalFlagValueError):
+ self.flag_values.set_default('unittest_number', 'oops')
+ with self.assertRaises(flags.IllegalFlagValueError):
+ self.flag_values.set_default('unittest_number', -1)
+
+
+class FlagsParsingTest(absltest.TestCase):
+ """Testing different aspects of parsing: '-f' vs '--flag', etc."""
+
+ def setUp(self):
+ self.flag_values = flags.FlagValues()
+
+ def test_two_dash_arg_first(self):
+ flags.DEFINE_string(
+ 'twodash_name', 'Bob', 'namehelp', flag_values=self.flag_values)
+ flags.DEFINE_string(
+ 'twodash_blame', 'Rob', 'blamehelp', flag_values=self.flag_values)
+ argv = ('./program', '--', '--twodash_name=Harry')
+ argv = self.flag_values(argv)
+ self.assertEqual('Bob', self.flag_values.twodash_name)
+ self.assertEqual(argv[1], '--twodash_name=Harry')
+
+ def test_two_dash_arg_middle(self):
+ flags.DEFINE_string(
+ 'twodash2_name', 'Bob', 'namehelp', flag_values=self.flag_values)
+ flags.DEFINE_string(
+ 'twodash2_blame', 'Rob', 'blamehelp', flag_values=self.flag_values)
+ argv = ('./program', '--twodash2_blame=Larry', '--',
+ '--twodash2_name=Harry')
+ argv = self.flag_values(argv)
+ self.assertEqual('Bob', self.flag_values.twodash2_name)
+ self.assertEqual('Larry', self.flag_values.twodash2_blame)
+ self.assertEqual(argv[1], '--twodash2_name=Harry')
+
+ def test_one_dash_arg_first(self):
+ flags.DEFINE_string(
+ 'onedash_name', 'Bob', 'namehelp', flag_values=self.flag_values)
+ flags.DEFINE_string(
+ 'onedash_blame', 'Rob', 'blamehelp', flag_values=self.flag_values)
+ argv = ('./program', '-', '--onedash_name=Harry')
+ with _use_gnu_getopt(self.flag_values, False):
+ argv = self.flag_values(argv)
+ self.assertEqual(len(argv), 3)
+ self.assertEqual(argv[1], '-')
+ self.assertEqual(argv[2], '--onedash_name=Harry')
+
+ def test_required_flag_not_specified(self):
+ flags.DEFINE_string(
+ 'str_flag',
+ default=None,
+ help='help',
+ required=True,
+ flag_values=self.flag_values)
+ argv = ('./program',)
+ with _use_gnu_getopt(self.flag_values, False):
+ with self.assertRaises(flags.IllegalFlagValueError):
+ self.flag_values(argv)
+
+ def test_required_arg_works_with_other_validators(self):
+ flags.DEFINE_integer(
+ 'int_flag',
+ default=None,
+ help='help',
+ required=True,
+ lower_bound=4,
+ flag_values=self.flag_values)
+ argv = ('./program', '--int_flag=2')
+ with _use_gnu_getopt(self.flag_values, False):
+ with self.assertRaises(flags.IllegalFlagValueError):
+ self.flag_values(argv)
+
+ def test_unrecognized_flags(self):
+ flags.DEFINE_string('name', 'Bob', 'namehelp', flag_values=self.flag_values)
+ # Unknown flag --nosuchflag
+ try:
+ argv = ('./program', '--nosuchflag', '--name=Bob', 'extra')
+ self.flag_values(argv)
+ raise AssertionError('Unknown flag exception not raised')
+ except flags.UnrecognizedFlagError as e:
+ self.assertEqual(e.flagname, 'nosuchflag')
+ self.assertEqual(e.flagvalue, '--nosuchflag')
+
+ # Unknown flag -w (short option)
+ try:
+ argv = ('./program', '-w', '--name=Bob', 'extra')
+ self.flag_values(argv)
+ raise AssertionError('Unknown flag exception not raised')
+ except flags.UnrecognizedFlagError as e:
+ self.assertEqual(e.flagname, 'w')
+ self.assertEqual(e.flagvalue, '-w')
+
+ # Unknown flag --nosuchflagwithparam=foo
+ try:
+ argv = ('./program', '--nosuchflagwithparam=foo', '--name=Bob', 'extra')
+ self.flag_values(argv)
+ raise AssertionError('Unknown flag exception not raised')
+ except flags.UnrecognizedFlagError as e:
+ self.assertEqual(e.flagname, 'nosuchflagwithparam')
+ self.assertEqual(e.flagvalue, '--nosuchflagwithparam=foo')
+
+ # Allow unknown flag --nosuchflag if specified with undefok
+ argv = ('./program', '--nosuchflag', '--name=Bob', '--undefok=nosuchflag',
+ 'extra')
+ argv = self.flag_values(argv)
+ self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
+
+ # Allow unknown flag --noboolflag if undefok=boolflag is specified
+ argv = ('./program', '--noboolflag', '--name=Bob', '--undefok=boolflag',
+ 'extra')
+ argv = self.flag_values(argv)
+ self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
+
+ # But not if the flagname is misspelled:
+ try:
+ argv = ('./program', '--nosuchflag', '--name=Bob', '--undefok=nosuchfla',
+ 'extra')
+ self.flag_values(argv)
+ raise AssertionError('Unknown flag exception not raised')
+ except flags.UnrecognizedFlagError as e:
+ self.assertEqual(e.flagname, 'nosuchflag')
+
+ try:
+ argv = ('./program', '--nosuchflag', '--name=Bob',
+ '--undefok=nosuchflagg', 'extra')
+ self.flag_values(argv)
+ raise AssertionError('Unknown flag exception not raised')
+ except flags.UnrecognizedFlagError as e:
+ self.assertEqual(e.flagname, 'nosuchflag')
+
+ # Allow unknown short flag -w if specified with undefok
+ argv = ('./program', '-w', '--name=Bob', '--undefok=w', 'extra')
+ argv = self.flag_values(argv)
+ self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
+
+ # Allow unknown flag --nosuchflagwithparam=foo if specified
+ # with undefok
+ argv = ('./program', '--nosuchflagwithparam=foo', '--name=Bob',
+ '--undefok=nosuchflagwithparam', 'extra')
+ argv = self.flag_values(argv)
+ self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
+
+ # Even if undefok specifies multiple flags
+ argv = ('./program', '--nosuchflag', '-w', '--nosuchflagwithparam=foo',
+ '--name=Bob', '--undefok=nosuchflag,w,nosuchflagwithparam', 'extra')
+ argv = self.flag_values(argv)
+ self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
+
+ # However, not if undefok doesn't specify the flag
+ try:
+ argv = ('./program', '--nosuchflag', '--name=Bob',
+ '--undefok=another_such', 'extra')
+ self.flag_values(argv)
+ raise AssertionError('Unknown flag exception not raised')
+ except flags.UnrecognizedFlagError as e:
+ self.assertEqual(e.flagname, 'nosuchflag')
+
+ # Make sure --undefok doesn't mask other option errors.
+ try:
+ # Provide an option requiring a parameter but not giving it one.
+ argv = ('./program', '--undefok=name', '--name')
+ self.flag_values(argv)
+ raise AssertionError('Missing option parameter exception not raised')
+ except flags.UnrecognizedFlagError:
+ raise AssertionError('Wrong kind of error exception raised')
+ except flags.Error:
+ pass
+
+ # Test --undefok <list>
+ argv = ('./program', '--nosuchflag', '-w', '--nosuchflagwithparam=foo',
+ '--name=Bob', '--undefok', 'nosuchflag,w,nosuchflagwithparam',
+ 'extra')
+ argv = self.flag_values(argv)
+ self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
+
+ # Test incorrect --undefok with no value.
+ argv = ('./program', '--name=Bob', '--undefok')
+ with self.assertRaises(flags.Error):
+ self.flag_values(argv)
+
+
+class NonGlobalFlagsTest(absltest.TestCase):
+
+ def test_nonglobal_flags(self):
+ """Test use of non-global FlagValues."""
+ nonglobal_flags = flags.FlagValues()
+ flags.DEFINE_string('nonglobal_flag', 'Bob', 'flaghelp', nonglobal_flags)
+ argv = ('./program', '--nonglobal_flag=Mary', 'extra')
+ argv = nonglobal_flags(argv)
+ self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+ self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
+ self.assertEqual(nonglobal_flags['nonglobal_flag'].value, 'Mary')
+
+ def test_unrecognized_nonglobal_flags(self):
+ """Test unrecognized non-global flags."""
+ nonglobal_flags = flags.FlagValues()
+ argv = ('./program', '--nosuchflag')
+ try:
+ argv = nonglobal_flags(argv)
+ raise AssertionError('Unknown flag exception not raised')
+ except flags.UnrecognizedFlagError as e:
+ self.assertEqual(e.flagname, 'nosuchflag')
+
+ argv = ('./program', '--nosuchflag', '--undefok=nosuchflag')
+
+ argv = nonglobal_flags(argv)
+ self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
+ self.assertEqual(argv[0], './program', 'program name not preserved')
+
+ def test_create_flag_errors(self):
+ # Since the exception classes are exposed, nothing stops users
+ # from creating their own instances. This test makes sure that
+ # people modifying the flags module understand that the external
+ # mechanisms for creating the exceptions should continue to work.
+ _ = flags.Error()
+ _ = flags.Error('message')
+ _ = flags.DuplicateFlagError()
+ _ = flags.DuplicateFlagError('message')
+ _ = flags.IllegalFlagValueError()
+ _ = flags.IllegalFlagValueError('message')
+
+ def test_flag_values_del_attr(self):
+ """Checks that del self.flag_values.flag_id works."""
+ default_value = 'default value for test_flag_values_del_attr'
+ # 1. Declare and delete a flag with no short name.
+ flag_values = flags.FlagValues()
+ flags.DEFINE_string(
+ 'delattr_foo', default_value, 'A simple flag.', flag_values=flag_values)
+
+ flag_values.mark_as_parsed()
+ self.assertEqual(flag_values.delattr_foo, default_value)
+ flag_obj = flag_values['delattr_foo']
+ # We also check that _FlagIsRegistered works as expected :)
+ self.assertTrue(flag_values._flag_is_registered(flag_obj))
+ del flag_values.delattr_foo
+ self.assertFalse('delattr_foo' in flag_values._flags())
+ self.assertFalse(flag_values._flag_is_registered(flag_obj))
+ # If the previous del FLAGS.delattr_foo did not work properly, the
+ # next definition will trigger a redefinition error.
+ flags.DEFINE_integer(
+ 'delattr_foo', 3, 'A simple flag.', flag_values=flag_values)
+ del flag_values.delattr_foo
+
+ self.assertFalse('delattr_foo' in flag_values)
+
+ # 2. Declare and delete a flag with a short name.
+ flags.DEFINE_string(
+ 'delattr_bar',
+ default_value,
+ 'flag with short name',
+ short_name='x5',
+ flag_values=flag_values)
+ flag_obj = flag_values['delattr_bar']
+ self.assertTrue(flag_values._flag_is_registered(flag_obj))
+ del flag_values.x5
+ self.assertTrue(flag_values._flag_is_registered(flag_obj))
+ del flag_values.delattr_bar
+ self.assertFalse(flag_values._flag_is_registered(flag_obj))
+
+ # 3. Just like 2, but del flag_values.name last
+ flags.DEFINE_string(
+ 'delattr_bar',
+ default_value,
+ 'flag with short name',
+ short_name='x5',
+ flag_values=flag_values)
+ flag_obj = flag_values['delattr_bar']
+ self.assertTrue(flag_values._flag_is_registered(flag_obj))
+ del flag_values.delattr_bar
+ self.assertTrue(flag_values._flag_is_registered(flag_obj))
+ del flag_values.x5
+ self.assertFalse(flag_values._flag_is_registered(flag_obj))
+
+ self.assertFalse('delattr_bar' in flag_values)
+ self.assertFalse('x5' in flag_values)
+
+ def test_list_flag_format(self):
+ """Tests for correctly-formatted list flags."""
+ fv = flags.FlagValues()
+ flags.DEFINE_list('listflag', '', 'A list of arguments', flag_values=fv)
+
+ def _check_parsing(listval):
+ """Parse a particular value for our test flag, --listflag."""
+ argv = fv(['./program', '--listflag=' + listval, 'plain-arg'])
+ self.assertEqual(['./program', 'plain-arg'], argv)
+ return fv.listflag
+
+ # Basic success case
+ self.assertEqual(_check_parsing('foo,bar'), ['foo', 'bar'])
+ # Success case: newline in argument is quoted.
+ self.assertEqual(_check_parsing('"foo","bar\nbar"'), ['foo', 'bar\nbar'])
+ # Failure case: newline in argument is unquoted.
+ self.assertRaises(flags.IllegalFlagValueError, _check_parsing,
+ '"foo",bar\nbar')
+ # Failure case: unmatched ".
+ self.assertRaises(flags.IllegalFlagValueError, _check_parsing,
+ '"foo,barbar')
+
+ def test_flag_definition_via_setitem(self):
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flag_values = flags.FlagValues()
+ flag_values['flag_name'] = 'flag_value'
+
+
+class KeyFlagsTest(absltest.TestCase):
+
+ def setUp(self):
+ self.flag_values = flags.FlagValues()
+
+ def _get_names_of_defined_flags(self, module, flag_values):
+ """Returns the list of names of flags defined by a module.
+
+ Auxiliary for the test_key_flags* methods.
+
+ Args:
+ module: A module object or a string module name.
+ flag_values: A FlagValues object.
+
+ Returns:
+ A list of strings.
+ """
+ return [f.name for f in flag_values.get_flags_for_module(module)]
+
+ def _get_names_of_key_flags(self, module, flag_values):
+ """Returns the list of names of key flags for a module.
+
+ Auxiliary for the test_key_flags* methods.
+
+ Args:
+ module: A module object or a string module name.
+ flag_values: A FlagValues object.
+
+ Returns:
+ A list of strings.
+ """
+ return [f.name for f in flag_values.get_key_flags_for_module(module)]
+
+ def _assert_lists_have_same_elements(self, list_1, list_2):
+ # Checks that two lists have the same elements with the same
+ # multiplicity, in possibly different order.
+ list_1 = list(list_1)
+ list_1.sort()
+ list_2 = list(list_2)
+ list_2.sort()
+ self.assertListEqual(list_1, list_2)
+
+ def test_key_flags(self):
+ flag_values = flags.FlagValues()
+ # Before starting any testing, make sure no flags are already
+ # defined for module_foo and module_bar.
+ self.assertListEqual(
+ self._get_names_of_key_flags(module_foo, flag_values), [])
+ self.assertListEqual(
+ self._get_names_of_key_flags(module_bar, flag_values), [])
+ self.assertListEqual(
+ self._get_names_of_defined_flags(module_foo, flag_values), [])
+ self.assertListEqual(
+ self._get_names_of_defined_flags(module_bar, flag_values), [])
+
+ # Defines a few flags in module_foo and module_bar.
+ module_foo.define_flags(flag_values=flag_values)
+
+ try:
+ # Part 1. Check that all flags defined by module_foo are key for
+ # that module, and similarly for module_bar.
+ for module in [module_foo, module_bar]:
+ self._assert_lists_have_same_elements(
+ flag_values.get_flags_for_module(module),
+ flag_values.get_key_flags_for_module(module))
+ # Also check that each module defined the expected flags.
+ self._assert_lists_have_same_elements(
+ self._get_names_of_defined_flags(module, flag_values),
+ module.names_of_defined_flags())
+
+ # Part 2. Check that flags.declare_key_flag works fine.
+ # Declare that some flags from module_bar are key for
+ # module_foo.
+ module_foo.declare_key_flags(flag_values=flag_values)
+
+ # Check that module_foo has the expected list of defined flags.
+ self._assert_lists_have_same_elements(
+ self._get_names_of_defined_flags(module_foo, flag_values),
+ module_foo.names_of_defined_flags())
+
+ # Check that module_foo has the expected list of key flags.
+ self._assert_lists_have_same_elements(
+ self._get_names_of_key_flags(module_foo, flag_values),
+ module_foo.names_of_declared_key_flags())
+
+ # Part 3. Check that flags.adopt_module_key_flags works fine.
+ # Trigger a call to flags.adopt_module_key_flags(module_bar)
+ # inside module_foo. This should declare a few more key
+ # flags in module_foo.
+ module_foo.declare_extra_key_flags(flag_values=flag_values)
+
+ # Check that module_foo has the expected list of key flags.
+ self._assert_lists_have_same_elements(
+ self._get_names_of_key_flags(module_foo, flag_values),
+ module_foo.names_of_declared_key_flags() +
+ module_foo.names_of_declared_extra_key_flags())
+ finally:
+ module_foo.remove_flags(flag_values=flag_values)
+
+ def test_key_flags_with_non_default_flag_values_object(self):
+ # Check that key flags work even when we use a FlagValues object
+ # that is not the default flags.self.flag_values object. Otherwise, this
+ # test is similar to test_key_flags, but it uses only module_bar.
+ # The other test module (module_foo) uses only the default values
+ # for the flag_values keyword arguments. This way, test_key_flags
+ # and this method test both the default FlagValues, the explicitly
+ # specified one, and a mixed usage of the two.
+
+ # A brand-new FlagValues object, to use instead of flags.self.flag_values.
+ fv = flags.FlagValues()
+
+ # Before starting any testing, make sure no flags are already
+ # defined for module_foo and module_bar.
+ self.assertListEqual(self._get_names_of_key_flags(module_bar, fv), [])
+ self.assertListEqual(self._get_names_of_defined_flags(module_bar, fv), [])
+
+ module_bar.define_flags(flag_values=fv)
+
+ # Check that all flags defined by module_bar are key for that
+ # module, and that module_bar defined the expected flags.
+ self._assert_lists_have_same_elements(
+ fv.get_flags_for_module(module_bar),
+ fv.get_key_flags_for_module(module_bar))
+ self._assert_lists_have_same_elements(
+ self._get_names_of_defined_flags(module_bar, fv),
+ module_bar.names_of_defined_flags())
+
+ # Pick two flags from module_bar, declare them as key for the
+ # current (i.e., main) module (via flags.declare_key_flag), and
+ # check that we get the expected effect. The important thing is
+ # that we always use flags_values=fv (instead of the default
+ # self.flag_values).
+ main_module = sys.argv[0]
+ names_of_flags_defined_by_bar = module_bar.names_of_defined_flags()
+ flag_name_0 = names_of_flags_defined_by_bar[0]
+ flag_name_2 = names_of_flags_defined_by_bar[2]
+
+ flags.declare_key_flag(flag_name_0, flag_values=fv)
+ self._assert_lists_have_same_elements(
+ self._get_names_of_key_flags(main_module, fv), [flag_name_0])
+
+ flags.declare_key_flag(flag_name_2, flag_values=fv)
+ self._assert_lists_have_same_elements(
+ self._get_names_of_key_flags(main_module, fv),
+ [flag_name_0, flag_name_2])
+
+ # Try with a special (not user-defined) flag too:
+ flags.declare_key_flag('undefok', flag_values=fv)
+ self._assert_lists_have_same_elements(
+ self._get_names_of_key_flags(main_module, fv),
+ [flag_name_0, flag_name_2, 'undefok'])
+
+ flags.adopt_module_key_flags(module_bar, fv)
+ self._assert_lists_have_same_elements(
+ self._get_names_of_key_flags(main_module, fv),
+ names_of_flags_defined_by_bar + ['undefok'])
+
+ # Adopt key flags from the flags module itself.
+ flags.adopt_module_key_flags(flags, flag_values=fv)
+ self._assert_lists_have_same_elements(
+ self._get_names_of_key_flags(main_module, fv),
+ names_of_flags_defined_by_bar + ['flagfile', 'undefok'])
+
+ def test_main_module_help_with_key_flags(self):
+ # Similar to test_main_module_help, but this time we make sure to
+ # declare some key flags.
+
+ # Safety check that the main module does not declare any flags
+ # at the beginning of this test.
+ expected_help = ''
+ self.assertMultiLineEqual(expected_help,
+ self.flag_values.main_module_help())
+
+ # Define one flag in this main module and some flags in modules
+ # a and b. Also declare one flag from module a and one flag
+ # from module b as key flags for the main module.
+ flags.DEFINE_integer(
+ 'main_module_int_fg',
+ 1,
+ 'Integer flag in the main module.',
+ flag_values=self.flag_values)
+
+ try:
+ main_module_int_fg_help = (
+ ' --main_module_int_fg: Integer flag in the main module.\n'
+ " (default: '1')\n"
+ ' (an integer)')
+
+ expected_help += '\n%s:\n%s' % (sys.argv[0], main_module_int_fg_help)
+ self.assertMultiLineEqual(expected_help,
+ self.flag_values.main_module_help())
+
+ # The following call should be a no-op: any flag declared by a
+ # module is automatically key for that module.
+ flags.declare_key_flag('main_module_int_fg', flag_values=self.flag_values)
+ self.assertMultiLineEqual(expected_help,
+ self.flag_values.main_module_help())
+
+ # The definition of a few flags in an imported module should not
+ # change the main module help.
+ module_foo.define_flags(flag_values=self.flag_values)
+ self.assertMultiLineEqual(expected_help,
+ self.flag_values.main_module_help())
+
+ flags.declare_key_flag('tmod_foo_bool', flag_values=self.flag_values)
+ tmod_foo_bool_help = (
+ ' --[no]tmod_foo_bool: Boolean flag from module foo.\n'
+ " (default: 'true')")
+ expected_help += '\n' + tmod_foo_bool_help
+ self.assertMultiLineEqual(expected_help,
+ self.flag_values.main_module_help())
+
+ flags.declare_key_flag('tmod_bar_z', flag_values=self.flag_values)
+ tmod_bar_z_help = (
+ ' --[no]tmod_bar_z: Another boolean flag from module bar.\n'
+ " (default: 'false')")
+ # Unfortunately, there is some flag sorting inside
+ # main_module_help, so we can't keep incrementally extending
+ # the expected_help string ...
+ expected_help = ('\n%s:\n%s\n%s\n%s' %
+ (sys.argv[0], main_module_int_fg_help, tmod_bar_z_help,
+ tmod_foo_bool_help))
+ self.assertMultiLineEqual(self.flag_values.main_module_help(),
+ expected_help)
+
+ finally:
+ # At the end, delete all the flag information we created.
+ self.flag_values.__delattr__('main_module_int_fg')
+ module_foo.remove_flags(flag_values=self.flag_values)
+
+ def test_adoptmodule_key_flags(self):
+ # Check that adopt_module_key_flags raises an exception when
+ # called with a module name (as opposed to a module object).
+ self.assertRaises(flags.Error, flags.adopt_module_key_flags, 'pyglib.app')
+
+ def test_disclaimkey_flags(self):
+ original_disclaim_module_ids = _helpers.disclaim_module_ids
+ _helpers.disclaim_module_ids = set(_helpers.disclaim_module_ids)
+ try:
+ module_bar.disclaim_key_flags()
+ module_foo.define_bar_flags(flag_values=self.flag_values)
+ module_name = self.flag_values.find_module_defining_flag('tmod_bar_x')
+ self.assertEqual(module_foo.__name__, module_name)
+ finally:
+ _helpers.disclaim_module_ids = original_disclaim_module_ids
+
+
+class FindModuleTest(absltest.TestCase):
+ """Testing methods that find a module that defines a given flag."""
+
+ def test_find_module_defining_flag(self):
+ self.assertEqual(
+ 'default',
+ FLAGS.find_module_defining_flag('__NON_EXISTENT_FLAG__', 'default'))
+ self.assertEqual(module_baz.__name__,
+ FLAGS.find_module_defining_flag('tmod_baz_x'))
+
+ def test_find_module_id_defining_flag(self):
+ self.assertEqual(
+ 'default',
+ FLAGS.find_module_id_defining_flag('__NON_EXISTENT_FLAG__', 'default'))
+ self.assertEqual(
+ id(module_baz), FLAGS.find_module_id_defining_flag('tmod_baz_x'))
+
+ def test_find_module_defining_flag_passing_module_name(self):
+ my_flags = flags.FlagValues()
+ module_name = sys.__name__ # Must use an existing module.
+ flags.DEFINE_boolean(
+ 'flag_name',
+ True,
+ 'Flag with a different module name.',
+ flag_values=my_flags,
+ module_name=module_name)
+ self.assertEqual(module_name,
+ my_flags.find_module_defining_flag('flag_name'))
+
+ def test_find_module_id_defining_flag_passing_module_name(self):
+ my_flags = flags.FlagValues()
+ module_name = sys.__name__ # Must use an existing module.
+ flags.DEFINE_boolean(
+ 'flag_name',
+ True,
+ 'Flag with a different module name.',
+ flag_values=my_flags,
+ module_name=module_name)
+ self.assertEqual(
+ id(sys), my_flags.find_module_id_defining_flag('flag_name'))
+
+
+class FlagsErrorMessagesTest(absltest.TestCase):
+ """Testing special cases for integer and float flags error messages."""
+
+ def setUp(self):
+ self.flag_values = flags.FlagValues()
+
+ def test_integer_error_text(self):
+ # Make sure we get proper error text
+ flags.DEFINE_integer(
+ 'positive',
+ 4,
+ 'non-negative flag',
+ lower_bound=1,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'non_negative',
+ 4,
+ 'positive flag',
+ lower_bound=0,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'negative',
+ -4,
+ 'negative flag',
+ upper_bound=-1,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'non_positive',
+ -4,
+ 'non-positive flag',
+ upper_bound=0,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'greater',
+ 19,
+ 'greater-than flag',
+ lower_bound=4,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'smaller',
+ -19,
+ 'smaller-than flag',
+ upper_bound=4,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'usual',
+ 4,
+ 'usual flag',
+ lower_bound=0,
+ upper_bound=10000,
+ flag_values=self.flag_values)
+ flags.DEFINE_integer(
+ 'another_usual',
+ 0,
+ 'usual flag',
+ lower_bound=-1,
+ upper_bound=1,
+ flag_values=self.flag_values)
+
+ self._check_error_message('positive', -4, 'a positive integer')
+ self._check_error_message('non_negative', -4, 'a non-negative integer')
+ self._check_error_message('negative', 0, 'a negative integer')
+ self._check_error_message('non_positive', 4, 'a non-positive integer')
+ self._check_error_message('usual', -4, 'an integer in the range [0, 10000]')
+ self._check_error_message('another_usual', 4,
+ 'an integer in the range [-1, 1]')
+ self._check_error_message('greater', -5, 'integer >= 4')
+ self._check_error_message('smaller', 5, 'integer <= 4')
+
+ def test_float_error_text(self):
+ flags.DEFINE_float(
+ 'positive',
+ 4,
+ 'non-negative flag',
+ lower_bound=1,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'non_negative',
+ 4,
+ 'positive flag',
+ lower_bound=0,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'negative',
+ -4,
+ 'negative flag',
+ upper_bound=-1,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'non_positive',
+ -4,
+ 'non-positive flag',
+ upper_bound=0,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'greater',
+ 19,
+ 'greater-than flag',
+ lower_bound=4,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'smaller',
+ -19,
+ 'smaller-than flag',
+ upper_bound=4,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'usual',
+ 4,
+ 'usual flag',
+ lower_bound=0,
+ upper_bound=10000,
+ flag_values=self.flag_values)
+ flags.DEFINE_float(
+ 'another_usual',
+ 0,
+ 'usual flag',
+ lower_bound=-1,
+ upper_bound=1,
+ flag_values=self.flag_values)
+
+ self._check_error_message('positive', 0.5, 'number >= 1')
+ self._check_error_message('non_negative', -4.0, 'a non-negative number')
+ self._check_error_message('negative', 0.5, 'number <= -1')
+ self._check_error_message('non_positive', 4.0, 'a non-positive number')
+ self._check_error_message('usual', -4.0, 'a number in the range [0, 10000]')
+ self._check_error_message('another_usual', 4.0,
+ 'a number in the range [-1, 1]')
+ self._check_error_message('smaller', 5.0, 'number <= 4')
+
+ def _check_error_message(self, flag_name, flag_value,
+ expected_message_suffix):
+ """Set a flag to a given value and make sure we get expected message."""
+
+ try:
+ self.flag_values.__setattr__(flag_name, flag_value)
+ raise AssertionError('Bounds exception not raised!')
+ except flags.IllegalFlagValueError as e:
+ expected = ('flag --%(name)s=%(value)s: %(value)s is not %(suffix)s' % {
+ 'name': flag_name,
+ 'value': flag_value,
+ 'suffix': expected_message_suffix
+ })
+ self.assertEqual(str(e), expected)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/flags_unicode_literals_test.py b/absl/flags/tests/flags_unicode_literals_test.py
new file mode 100644
index 0000000..e8ed5bf
--- /dev/null
+++ b/absl/flags/tests/flags_unicode_literals_test.py
@@ -0,0 +1,42 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test the use of flags when from __future__ import unicode_literals is on."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from absl import flags
+from absl.testing import absltest
+
+
+flags.DEFINE_string('seen_in_crittenden', 'alleged mountain lion',
+ 'This tests if unicode input to these functions works.')
+
+
+class FlagsUnicodeLiteralsTest(absltest.TestCase):
+
+ def testUnicodeFlagNameAndValueAreGood(self):
+ alleged_mountain_lion = flags.FLAGS.seen_in_crittenden
+ self.assertTrue(
+ isinstance(alleged_mountain_lion, type(u'')),
+ msg='expected flag value to be a {} not {}'.format(
+ type(u''), type(alleged_mountain_lion)))
+ self.assertEqual(alleged_mountain_lion, u'alleged mountain lion')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/flags/tests/module_bar.py b/absl/flags/tests/module_bar.py
new file mode 100644
index 0000000..8714d2e
--- /dev/null
+++ b/absl/flags/tests/module_bar.py
@@ -0,0 +1,121 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Auxiliary module for testing flags.py.
+
+The purpose of this module is to define a few flags. We want to make
+sure the unit tests for flags.py involve more than one module.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+from absl.flags import _helpers
+
+FLAGS = flags.FLAGS
+
+
+def define_flags(flag_values=FLAGS):
+ """Defines some flags.
+
+ Args:
+ flag_values: The FlagValues object we want to register the flags
+ with.
+ """
+ # The 'tmod_bar_' prefix (short for 'test_module_bar') ensures there
+ # is no name clash with the existing flags.
+ flags.DEFINE_boolean('tmod_bar_x', True, 'Boolean flag.',
+ flag_values=flag_values)
+ flags.DEFINE_string('tmod_bar_y', 'default', 'String flag.',
+ flag_values=flag_values)
+ flags.DEFINE_boolean('tmod_bar_z', False,
+ 'Another boolean flag from module bar.',
+ flag_values=flag_values)
+ flags.DEFINE_integer('tmod_bar_t', 4, 'Sample int flag.',
+ flag_values=flag_values)
+ flags.DEFINE_integer('tmod_bar_u', 5, 'Sample int flag.',
+ flag_values=flag_values)
+ flags.DEFINE_integer('tmod_bar_v', 6, 'Sample int flag.',
+ flag_values=flag_values)
+
+
+def remove_one_flag(flag_name, flag_values=FLAGS):
+ """Removes the definition of one flag from flags.FLAGS.
+
+ Note: if the flag is not defined in flags.FLAGS, this function does
+ not do anything (in particular, it does not raise any exception).
+
+ Motivation: We use this function for cleanup *after* a test: if
+ there was a failure during a test and not all flags were declared,
+ we do not want the cleanup code to crash.
+
+ Args:
+ flag_name: A string, the name of the flag to delete.
+ flag_values: The FlagValues object we remove the flag from.
+ """
+ if flag_name in flag_values:
+ flag_values.__delattr__(flag_name)
+
+
+def names_of_defined_flags():
+ """Returns: List of names of the flags declared in this module."""
+ return ['tmod_bar_x',
+ 'tmod_bar_y',
+ 'tmod_bar_z',
+ 'tmod_bar_t',
+ 'tmod_bar_u',
+ 'tmod_bar_v']
+
+
+def remove_flags(flag_values=FLAGS):
+ """Deletes the flag definitions done by the above define_flags().
+
+ Args:
+ flag_values: The FlagValues object we remove the flags from.
+ """
+ for flag_name in names_of_defined_flags():
+ remove_one_flag(flag_name, flag_values=flag_values)
+
+
+def get_module_name():
+ """Uses get_calling_module() to return the name of this module.
+
+ For checking that get_calling_module works as expected.
+
+ Returns:
+ A string, the name of this module.
+ """
+ return _helpers.get_calling_module()
+
+
+def execute_code(code, global_dict):
+ """Executes some code in a given global environment.
+
+ For testing of get_calling_module.
+
+ Args:
+ code: A string, the code to be executed.
+ global_dict: A dictionary, the global environment that code should
+ be executed in.
+ """
+ # Indeed, using exec generates a lint warning. But some user code
+ # actually uses exec, and we have to test for it ...
+ exec(code, global_dict) # pylint: disable=exec-used
+
+
+def disclaim_key_flags():
+ """Disclaims flags declared in this module."""
+ flags.disclaim_key_flags()
diff --git a/absl/flags/tests/module_baz.py b/absl/flags/tests/module_baz.py
new file mode 100644
index 0000000..7199516
--- /dev/null
+++ b/absl/flags/tests/module_baz.py
@@ -0,0 +1,29 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Auxiliary module for testing flags.py.
+
+The purpose of this module is to test the behavior of flags that are defined
+before main() executes.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_boolean('tmod_baz_x', True, 'Boolean flag.')
diff --git a/absl/flags/tests/module_foo.py b/absl/flags/tests/module_foo.py
new file mode 100644
index 0000000..a1a2573
--- /dev/null
+++ b/absl/flags/tests/module_foo.py
@@ -0,0 +1,128 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Auxiliary module for testing flags.py.
+
+The purpose of this module is to define a few flags, and declare some
+other flags as being important. We want to make sure the unit tests
+for flags.py involve more than one module.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+from absl.flags import _helpers
+from absl.flags.tests import module_bar
+
+FLAGS = flags.FLAGS
+
+
+DECLARED_KEY_FLAGS = ['tmod_bar_x', 'tmod_bar_z', 'tmod_bar_t',
+ # Special (not user-defined) flag:
+ 'flagfile']
+
+
+def define_flags(flag_values=FLAGS):
+ """Defines a few flags."""
+ module_bar.define_flags(flag_values=flag_values)
+ # The 'tmod_foo_' prefix (short for 'test_module_foo') ensures that we
+ # have no name clash with existing flags.
+ flags.DEFINE_boolean('tmod_foo_bool', True, 'Boolean flag from module foo.',
+ flag_values=flag_values)
+ flags.DEFINE_string('tmod_foo_str', 'default', 'String flag.',
+ flag_values=flag_values)
+ flags.DEFINE_integer('tmod_foo_int', 3, 'Sample int flag.',
+ flag_values=flag_values)
+
+
+def declare_key_flags(flag_values=FLAGS):
+ """Declares a few key flags."""
+ for flag_name in DECLARED_KEY_FLAGS:
+ flags.declare_key_flag(flag_name, flag_values=flag_values)
+
+
+def declare_extra_key_flags(flag_values=FLAGS):
+ """Declares some extra key flags."""
+ flags.adopt_module_key_flags(module_bar, flag_values=flag_values)
+
+
+def names_of_defined_flags():
+ """Returns: list of names of flags defined by this module."""
+ return ['tmod_foo_bool', 'tmod_foo_str', 'tmod_foo_int']
+
+
+def names_of_declared_key_flags():
+ """Returns: list of names of key flags for this module."""
+ return names_of_defined_flags() + DECLARED_KEY_FLAGS
+
+
+def names_of_declared_extra_key_flags():
+ """Returns the list of names of additional key flags for this module.
+
+ These are the flags that became key for this module only as a result
+ of a call to declare_extra_key_flags() above. I.e., the flags declared
+ by module_bar, that were not already declared as key for this
+ module.
+
+ Returns:
+ The list of names of additional key flags for this module.
+ """
+ names_of_extra_key_flags = list(module_bar.names_of_defined_flags())
+ for flag_name in names_of_declared_key_flags():
+ while flag_name in names_of_extra_key_flags:
+ names_of_extra_key_flags.remove(flag_name)
+ return names_of_extra_key_flags
+
+
+def remove_flags(flag_values=FLAGS):
+ """Deletes the flag definitions done by the above define_flags()."""
+ for flag_name in names_of_defined_flags():
+ module_bar.remove_one_flag(flag_name, flag_values=flag_values)
+ module_bar.remove_flags(flag_values=flag_values)
+
+
+def get_module_name():
+ """Uses get_calling_module() to return the name of this module.
+
+ For checking that _get_calling_module works as expected.
+
+ Returns:
+ A string, the name of this module.
+ """
+ return _helpers.get_calling_module()
+
+
+def duplicate_flags(flagnames=None):
+ """Returns a new FlagValues object with the requested flagnames.
+
+ Used to test DuplicateFlagError detection.
+
+ Args:
+ flagnames: str, A list of flag names to create.
+
+ Returns:
+ A FlagValues object with one boolean flag for each name in flagnames.
+ """
+ flag_values = flags.FlagValues()
+ for name in flagnames:
+ flags.DEFINE_boolean(name, False, 'Flag named %s' % (name,),
+ flag_values=flag_values)
+ return flag_values
+
+
+def define_bar_flags(flag_values=FLAGS):
+ """Defines flags from module_bar."""
+ module_bar.define_flags(flag_values)
diff --git a/absl/logging/BUILD b/absl/logging/BUILD
new file mode 100644
index 0000000..6c0d1bc
--- /dev/null
+++ b/absl/logging/BUILD
@@ -0,0 +1,100 @@
+licenses(["notice"])
+
+py_library(
+ name = "logging",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":converter",
+ "//absl/flags",
+ ],
+)
+
+py_library(
+ name = "converter",
+ srcs = ["converter.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+)
+
+py_test(
+ name = "tests/converter_test",
+ size = "small",
+ srcs = ["tests/converter_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":converter",
+ ":logging",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "tests/logging_test",
+ size = "small",
+ srcs = ["tests/logging_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":logging",
+ "//absl/flags",
+ "//absl/testing:absltest",
+ "//absl/testing:flagsaver",
+ "//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "tests/log_before_import_test",
+ srcs = ["tests/log_before_import_test.py"],
+ main = "tests/log_before_import_test.py",
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":logging",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "tests/verbosity_flag_test",
+ srcs = ["tests/verbosity_flag_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":logging",
+ "//absl/flags",
+ "//absl/testing:absltest",
+ ],
+)
+
+py_binary(
+ name = "tests/logging_functional_test_helper",
+ testonly = 1,
+ srcs = ["tests/logging_functional_test_helper.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":logging",
+ "//absl:app",
+ "//absl/flags",
+ ],
+)
+
+py_test(
+ name = "tests/logging_functional_test",
+ size = "large",
+ srcs = ["tests/logging_functional_test.py"],
+ data = [":tests/logging_functional_test_helper"],
+ python_version = "PY3",
+ shard_count = 50,
+ srcs_version = "PY3",
+ deps = [
+ ":logging",
+ "//absl/testing:_bazelize_command",
+ "//absl/testing:absltest",
+ "//absl/testing:parameterized",
+ ],
+)
diff --git a/absl/logging/__init__.py b/absl/logging/__init__.py
new file mode 100644
index 0000000..8804490
--- /dev/null
+++ b/absl/logging/__init__.py
@@ -0,0 +1,1234 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Abseil Python logging module implemented on top of standard logging.
+
+Simple usage:
+
+ from absl import logging
+
+ logging.info('Interesting Stuff')
+ logging.info('Interesting Stuff with Arguments: %d', 42)
+
+ logging.set_verbosity(logging.INFO)
+ logging.log(logging.DEBUG, 'This will *not* be printed')
+ logging.set_verbosity(logging.DEBUG)
+ logging.log(logging.DEBUG, 'This will be printed')
+
+ logging.warning('Worrying Stuff')
+ logging.error('Alarming Stuff')
+ logging.fatal('AAAAHHHHH!!!!') # Process exits.
+
+Usage note: Do not pre-format the strings in your program code.
+Instead, let the logging module perform argument interpolation.
+This saves cycles because strings that don't need to be printed
+are never formatted. Note that this module does not attempt to
+interpolate arguments when no arguments are given. In other words
+
+ logging.info('Interesting Stuff: %s')
+
+does not raise an exception because logging.info() has only one
+argument, the message string.
+
+"Lazy" evaluation for debugging:
+
+If you do something like this:
+ logging.debug('Thing: %s', thing.ExpensiveOp())
+then the ExpensiveOp will be evaluated even if nothing
+is printed to the log. To avoid this, use the level_debug() function:
+ if logging.level_debug():
+ logging.debug('Thing: %s', thing.ExpensiveOp())
+
+Per file level logging is supported by logging.vlog() and
+logging.vlog_is_on(). For example:
+
+ if logging.vlog_is_on(2):
+ logging.vlog(2, very_expensive_debug_message())
+
+Notes on Unicode:
+
+The log output is encoded as UTF-8. Don't pass data in other encodings in
+bytes() instances -- instead pass unicode string instances when you need to
+(for both the format string and arguments).
+
+Note on critical and fatal:
+Standard logging module defines fatal as an alias to critical, but it's not
+documented, and it does NOT actually terminate the program.
+This module only defines fatal but not critical, and it DOES terminate the
+program.
+
+The differences in behavior are historical and unfortunate.
+"""
+
+import collections
+from collections import abc
+import getpass
+import io
+import itertools
+import logging
+import os
+import socket
+import struct
+import sys
+import threading
+import time
+import timeit
+import traceback
+import types
+import warnings
+
+from absl import flags
+from absl.logging import converter
+
+try:
+ from typing import NoReturn
+except ImportError:
+ pass
+
+
+FLAGS = flags.FLAGS
+
+
+# Logging levels.
+FATAL = converter.ABSL_FATAL
+ERROR = converter.ABSL_ERROR
+WARNING = converter.ABSL_WARNING
+WARN = converter.ABSL_WARNING # Deprecated name.
+INFO = converter.ABSL_INFO
+DEBUG = converter.ABSL_DEBUG
+
+# Regex to match/parse log line prefixes.
+ABSL_LOGGING_PREFIX_REGEX = (
+ r'^(?P<severity>[IWEF])'
+ r'(?P<month>\d\d)(?P<day>\d\d) '
+ r'(?P<hour>\d\d):(?P<minute>\d\d):(?P<second>\d\d)'
+ r'\.(?P<microsecond>\d\d\d\d\d\d) +'
+ r'(?P<thread_id>-?\d+) '
+ r'(?P<filename>[a-zA-Z<][\w._<>-]+):(?P<line>\d+)')
+
+
+# Mask to convert integer thread ids to unsigned quantities for logging purposes
+_THREAD_ID_MASK = 2 ** (struct.calcsize('L') * 8) - 1
+
+# Extra property set on the LogRecord created by ABSLLogger when its level is
+# CRITICAL/FATAL.
+_ABSL_LOG_FATAL = '_absl_log_fatal'
+# Extra prefix added to the log message when a non-absl logger logs a
+# CRITICAL/FATAL message.
+_CRITICAL_PREFIX = 'CRITICAL - '
+
+# Used by findCaller to skip callers from */logging/__init__.py.
+_LOGGING_FILE_PREFIX = os.path.join('logging', '__init__.')
+
+# The ABSL logger instance, initialized in _initialize().
+_absl_logger = None
+# The ABSL handler instance, initialized in _initialize().
+_absl_handler = None
+
+
+_CPP_NAME_TO_LEVELS = {
+ 'debug': '0', # Abseil C++ has no DEBUG level, mapping it to INFO here.
+ 'info': '0',
+ 'warning': '1',
+ 'warn': '1',
+ 'error': '2',
+ 'fatal': '3'
+}
+
+_CPP_LEVEL_TO_NAMES = {
+ '0': 'info',
+ '1': 'warning',
+ '2': 'error',
+ '3': 'fatal',
+}
+
+
+class _VerbosityFlag(flags.Flag):
+ """Flag class for -v/--verbosity."""
+
+ def __init__(self, *args, **kwargs):
+ super(_VerbosityFlag, self).__init__(
+ flags.IntegerParser(),
+ flags.ArgumentSerializer(),
+ *args, **kwargs)
+
+ @property
+ def value(self):
+ return self._value
+
+ @value.setter
+ def value(self, v):
+ self._value = v
+ self._update_logging_levels()
+
+ def _update_logging_levels(self):
+ """Updates absl logging levels to the current verbosity.
+
+ Visibility: module-private
+ """
+ if not _absl_logger:
+ return
+
+ if self._value <= converter.ABSL_DEBUG:
+ standard_verbosity = converter.absl_to_standard(self._value)
+ else:
+ # --verbosity is set to higher than 1 for vlog.
+ standard_verbosity = logging.DEBUG - (self._value - 1)
+
+ # Also update root level when absl_handler is used.
+ if _absl_handler in logging.root.handlers:
+ # Make absl logger inherit from the root logger. absl logger might have
+ # a non-NOTSET value if logging.set_verbosity() is called at import time.
+ _absl_logger.setLevel(logging.NOTSET)
+ logging.root.setLevel(standard_verbosity)
+ else:
+ _absl_logger.setLevel(standard_verbosity)
+
+
+class _LoggerLevelsFlag(flags.Flag):
+ """Flag class for --logger_levels."""
+
+ def __init__(self, *args, **kwargs):
+ super(_LoggerLevelsFlag, self).__init__(
+ _LoggerLevelsParser(),
+ _LoggerLevelsSerializer(),
+ *args, **kwargs)
+
+ @property
+ def value(self):
+ # For lack of an immutable type, be defensive and return a copy.
+ # Modifications to the dict aren't supported and won't have any affect.
+ # While Py3 could use MappingProxyType, that isn't deepcopy friendly, so
+ # just return a copy.
+ return self._value.copy()
+
+ @value.setter
+ def value(self, v):
+ self._value = {} if v is None else v
+ self._update_logger_levels()
+
+ def _update_logger_levels(self):
+ # Visibility: module-private.
+ # This is called by absl.app.run() during initialization.
+ for name, level in self._value.items():
+ logging.getLogger(name).setLevel(level)
+
+
+class _LoggerLevelsParser(flags.ArgumentParser):
+ """Parser for --logger_levels flag."""
+
+ def parse(self, value):
+ if isinstance(value, abc.Mapping):
+ return value
+
+ pairs = [pair.strip() for pair in value.split(',') if pair.strip()]
+
+ # Preserve the order so that serialization is deterministic.
+ levels = collections.OrderedDict()
+ for name_level in pairs:
+ name, level = name_level.split(':', 1)
+ name = name.strip()
+ level = level.strip()
+ levels[name] = level
+ return levels
+
+
+class _LoggerLevelsSerializer(object):
+ """Serializer for --logger_levels flag."""
+
+ def serialize(self, value):
+ if isinstance(value, str):
+ return value
+ return ','.join(
+ '{}:{}'.format(name, level) for name, level in value.items())
+
+
+class _StderrthresholdFlag(flags.Flag):
+ """Flag class for --stderrthreshold."""
+
+ def __init__(self, *args, **kwargs):
+ super(_StderrthresholdFlag, self).__init__(
+ flags.ArgumentParser(),
+ flags.ArgumentSerializer(),
+ *args, **kwargs)
+
+ @property
+ def value(self):
+ return self._value
+
+ @value.setter
+ def value(self, v):
+ if v in _CPP_LEVEL_TO_NAMES:
+ # --stderrthreshold also accepts numeric strings whose values are
+ # Abseil C++ log levels.
+ cpp_value = int(v)
+ v = _CPP_LEVEL_TO_NAMES[v] # Normalize to strings.
+ elif v.lower() in _CPP_NAME_TO_LEVELS:
+ v = v.lower()
+ if v == 'warn':
+ v = 'warning' # Use 'warning' as the canonical name.
+ cpp_value = int(_CPP_NAME_TO_LEVELS[v])
+ else:
+ raise ValueError(
+ '--stderrthreshold must be one of (case-insensitive) '
+ "'debug', 'info', 'warning', 'error', 'fatal', "
+ "or '0', '1', '2', '3', not '%s'" % v)
+
+ self._value = v
+
+
+flags.DEFINE_boolean('logtostderr',
+ False,
+ 'Should only log to stderr?', allow_override_cpp=True)
+flags.DEFINE_boolean('alsologtostderr',
+ False,
+ 'also log to stderr?', allow_override_cpp=True)
+flags.DEFINE_string('log_dir',
+ os.getenv('TEST_TMPDIR', ''),
+ 'directory to write logfiles into',
+ allow_override_cpp=True)
+flags.DEFINE_flag(_VerbosityFlag(
+ 'verbosity', -1,
+ 'Logging verbosity level. Messages logged at this level or lower will '
+ 'be included. Set to 1 for debug logging. If the flag was not set or '
+ 'supplied, the value will be changed from the default of -1 (warning) to '
+ '0 (info) after flags are parsed.',
+ short_name='v', allow_hide_cpp=True))
+flags.DEFINE_flag(
+ _LoggerLevelsFlag(
+ 'logger_levels', {},
+ 'Specify log level of loggers. The format is a CSV list of '
+ '`name:level`. Where `name` is the logger name used with '
+ '`logging.getLogger()`, and `level` is a level name (INFO, DEBUG, '
+ 'etc). e.g. `myapp.foo:INFO,other.logger:DEBUG`'))
+flags.DEFINE_flag(_StderrthresholdFlag(
+ 'stderrthreshold', 'fatal',
+ 'log messages at this level, or more severe, to stderr in '
+ 'addition to the logfile. Possible values are '
+ "'debug', 'info', 'warning', 'error', and 'fatal'. "
+ 'Obsoletes --alsologtostderr. Using --alsologtostderr '
+ 'cancels the effect of this flag. Please also note that '
+ 'this flag is subject to --verbosity and requires logfile '
+ 'not be stderr.', allow_hide_cpp=True))
+flags.DEFINE_boolean('showprefixforinfo', True,
+ 'If False, do not prepend prefix to info messages '
+ 'when it\'s logged to stderr, '
+ '--verbosity is set to INFO level, '
+ 'and python logging is used.')
+
+
+def get_verbosity():
+ """Returns the logging verbosity."""
+ return FLAGS['verbosity'].value
+
+
+def set_verbosity(v):
+ """Sets the logging verbosity.
+
+ Causes all messages of level <= v to be logged,
+ and all messages of level > v to be silently discarded.
+
+ Args:
+ v: int|str, the verbosity level as an integer or string. Legal string values
+ are those that can be coerced to an integer as well as case-insensitive
+ 'debug', 'info', 'warning', 'error', and 'fatal'.
+ """
+ try:
+ new_level = int(v)
+ except ValueError:
+ new_level = converter.ABSL_NAMES[v.upper()]
+ FLAGS.verbosity = new_level
+
+
+def set_stderrthreshold(s):
+ """Sets the stderr threshold to the value passed in.
+
+ Args:
+ s: str|int, valid strings values are case-insensitive 'debug',
+ 'info', 'warning', 'error', and 'fatal'; valid integer values are
+ logging.DEBUG|INFO|WARNING|ERROR|FATAL.
+
+ Raises:
+ ValueError: Raised when s is an invalid value.
+ """
+ if s in converter.ABSL_LEVELS:
+ FLAGS.stderrthreshold = converter.ABSL_LEVELS[s]
+ elif isinstance(s, str) and s.upper() in converter.ABSL_NAMES:
+ FLAGS.stderrthreshold = s
+ else:
+ raise ValueError(
+ 'set_stderrthreshold only accepts integer absl logging level '
+ 'from -3 to 1, or case-insensitive string values '
+ "'debug', 'info', 'warning', 'error', and 'fatal'. "
+ 'But found "{}" ({}).'.format(s, type(s)))
+
+
+def fatal(msg, *args, **kwargs):
+ # type: (Any, Any, Any) -> NoReturn
+ """Logs a fatal message."""
+ log(FATAL, msg, *args, **kwargs)
+
+
+def error(msg, *args, **kwargs):
+ """Logs an error message."""
+ log(ERROR, msg, *args, **kwargs)
+
+
+def warning(msg, *args, **kwargs):
+ """Logs a warning message."""
+ log(WARNING, msg, *args, **kwargs)
+
+
+def warn(msg, *args, **kwargs):
+ """Deprecated, use 'warning' instead."""
+ warnings.warn("The 'warn' function is deprecated, use 'warning' instead",
+ DeprecationWarning, 2)
+ log(WARNING, msg, *args, **kwargs)
+
+
+def info(msg, *args, **kwargs):
+ """Logs an info message."""
+ log(INFO, msg, *args, **kwargs)
+
+
+def debug(msg, *args, **kwargs):
+ """Logs a debug message."""
+ log(DEBUG, msg, *args, **kwargs)
+
+
+def exception(msg, *args):
+ """Logs an exception, with traceback and message."""
+ error(msg, *args, exc_info=True)
+
+
+# Counter to keep track of number of log entries per token.
+_log_counter_per_token = {}
+
+
+def _get_next_log_count_per_token(token):
+ """Wrapper for _log_counter_per_token. Thread-safe.
+
+ Args:
+ token: The token for which to look up the count.
+
+ Returns:
+ The number of times this function has been called with
+ *token* as an argument (starting at 0).
+ """
+ # Can't use a defaultdict because defaultdict isn't atomic, whereas
+ # setdefault is.
+ return next(_log_counter_per_token.setdefault(token, itertools.count()))
+
+
+def log_every_n(level, msg, n, *args):
+ """Logs 'msg % args' at level 'level' once per 'n' times.
+
+ Logs the 1st call, (N+1)st call, (2N+1)st call, etc.
+ Not threadsafe.
+
+ Args:
+ level: int, the absl logging level at which to log.
+ msg: str, the message to be logged.
+ n: int, the number of times this should be called before it is logged.
+ *args: The args to be substituted into the msg.
+ """
+ count = _get_next_log_count_per_token(get_absl_logger().findCaller())
+ log_if(level, msg, not (count % n), *args)
+
+
+# Keeps track of the last log time of the given token.
+# Note: must be a dict since set/get is atomic in CPython.
+# Note: entries are never released as their number is expected to be low.
+_log_timer_per_token = {}
+
+
+def _seconds_have_elapsed(token, num_seconds):
+ """Tests if 'num_seconds' have passed since 'token' was requested.
+
+ Not strictly thread-safe - may log with the wrong frequency if called
+ concurrently from multiple threads. Accuracy depends on resolution of
+ 'timeit.default_timer()'.
+
+ Always returns True on the first call for a given 'token'.
+
+ Args:
+ token: The token for which to look up the count.
+ num_seconds: The number of seconds to test for.
+
+ Returns:
+ Whether it has been >= 'num_seconds' since 'token' was last requested.
+ """
+ now = timeit.default_timer()
+ then = _log_timer_per_token.get(token, None)
+ if then is None or (now - then) >= num_seconds:
+ _log_timer_per_token[token] = now
+ return True
+ else:
+ return False
+
+
+def log_every_n_seconds(level, msg, n_seconds, *args):
+ """Logs 'msg % args' at level 'level' iff 'n_seconds' elapsed since last call.
+
+ Logs the first call, logs subsequent calls if 'n' seconds have elapsed since
+ the last logging call from the same call site (file + line). Not thread-safe.
+
+ Args:
+ level: int, the absl logging level at which to log.
+ msg: str, the message to be logged.
+ n_seconds: float or int, seconds which should elapse before logging again.
+ *args: The args to be substituted into the msg.
+ """
+ should_log = _seconds_have_elapsed(get_absl_logger().findCaller(), n_seconds)
+ log_if(level, msg, should_log, *args)
+
+
+def log_first_n(level, msg, n, *args):
+ """Logs 'msg % args' at level 'level' only first 'n' times.
+
+ Not threadsafe.
+
+ Args:
+ level: int, the absl logging level at which to log.
+ msg: str, the message to be logged.
+ n: int, the maximal number of times the message is logged.
+ *args: The args to be substituted into the msg.
+ """
+ count = _get_next_log_count_per_token(get_absl_logger().findCaller())
+ log_if(level, msg, count < n, *args)
+
+
+def log_if(level, msg, condition, *args):
+ """Logs 'msg % args' at level 'level' only if condition is fulfilled."""
+ if condition:
+ log(level, msg, *args)
+
+
+def log(level, msg, *args, **kwargs):
+ """Logs 'msg % args' at absl logging level 'level'.
+
+ If no args are given just print msg, ignoring any interpolation specifiers.
+
+ Args:
+ level: int, the absl logging level at which to log the message
+ (logging.DEBUG|INFO|WARNING|ERROR|FATAL). While some C++ verbose logging
+ level constants are also supported, callers should prefer explicit
+ logging.vlog() calls for such purpose.
+
+ msg: str, the message to be logged.
+ *args: The args to be substituted into the msg.
+ **kwargs: May contain exc_info to add exception traceback to message.
+ """
+ if level > converter.ABSL_DEBUG:
+ # Even though this function supports level that is greater than 1, users
+ # should use logging.vlog instead for such cases.
+ # Treat this as vlog, 1 is equivalent to DEBUG.
+ standard_level = converter.STANDARD_DEBUG - (level - 1)
+ else:
+ if level < converter.ABSL_FATAL:
+ level = converter.ABSL_FATAL
+ standard_level = converter.absl_to_standard(level)
+
+ # Match standard logging's behavior. Before use_absl_handler() and
+ # logging is configured, there is no handler attached on _absl_logger nor
+ # logging.root. So logs go no where.
+ if not logging.root.handlers:
+ logging.basicConfig()
+
+ _absl_logger.log(standard_level, msg, *args, **kwargs)
+
+
+def vlog(level, msg, *args, **kwargs):
+ """Log 'msg % args' at C++ vlog level 'level'.
+
+ Args:
+ level: int, the C++ verbose logging level at which to log the message,
+ e.g. 1, 2, 3, 4... While absl level constants are also supported,
+ callers should prefer logging.log|debug|info|... calls for such purpose.
+ msg: str, the message to be logged.
+ *args: The args to be substituted into the msg.
+ **kwargs: May contain exc_info to add exception traceback to message.
+ """
+ log(level, msg, *args, **kwargs)
+
+
+def vlog_is_on(level):
+ """Checks if vlog is enabled for the given level in caller's source file.
+
+ Args:
+ level: int, the C++ verbose logging level at which to log the message,
+ e.g. 1, 2, 3, 4... While absl level constants are also supported,
+ callers should prefer level_debug|level_info|... calls for
+ checking those.
+
+ Returns:
+ True if logging is turned on for that level.
+ """
+
+ if level > converter.ABSL_DEBUG:
+ # Even though this function supports level that is greater than 1, users
+ # should use logging.vlog instead for such cases.
+ # Treat this as vlog, 1 is equivalent to DEBUG.
+ standard_level = converter.STANDARD_DEBUG - (level - 1)
+ else:
+ if level < converter.ABSL_FATAL:
+ level = converter.ABSL_FATAL
+ standard_level = converter.absl_to_standard(level)
+ return _absl_logger.isEnabledFor(standard_level)
+
+
+def flush():
+ """Flushes all log files."""
+ get_absl_handler().flush()
+
+
+def level_debug():
+ """Returns True if debug logging is turned on."""
+ return get_verbosity() >= DEBUG
+
+
+def level_info():
+ """Returns True if info logging is turned on."""
+ return get_verbosity() >= INFO
+
+
+def level_warning():
+ """Returns True if warning logging is turned on."""
+ return get_verbosity() >= WARNING
+
+
+level_warn = level_warning # Deprecated function.
+
+
+def level_error():
+ """Returns True if error logging is turned on."""
+ return get_verbosity() >= ERROR
+
+
+def get_log_file_name(level=INFO):
+ """Returns the name of the log file.
+
+ For Python logging, only one file is used and level is ignored. And it returns
+ empty string if it logs to stderr/stdout or the log stream has no `name`
+ attribute.
+
+ Args:
+ level: int, the absl.logging level.
+
+ Raises:
+ ValueError: Raised when `level` has an invalid value.
+ """
+ if level not in converter.ABSL_LEVELS:
+ raise ValueError('Invalid absl.logging level {}'.format(level))
+ stream = get_absl_handler().python_handler.stream
+ if (stream == sys.stderr or stream == sys.stdout or
+ not hasattr(stream, 'name')):
+ return ''
+ else:
+ return stream.name
+
+
+def find_log_dir_and_names(program_name=None, log_dir=None):
+ """Computes the directory and filename prefix for log file.
+
+ Args:
+ program_name: str|None, the filename part of the path to the program that
+ is running without its extension. e.g: if your program is called
+ 'usr/bin/foobar.py' this method should probably be called with
+ program_name='foobar' However, this is just a convention, you can
+ pass in any string you want, and it will be used as part of the
+ log filename. If you don't pass in anything, the default behavior
+ is as described in the example. In python standard logging mode,
+ the program_name will be prepended with py_ if it is the program_name
+ argument is omitted.
+ log_dir: str|None, the desired log directory.
+
+ Returns:
+ (log_dir, file_prefix, symlink_prefix)
+
+ Raises:
+ FileNotFoundError: raised in Python 3 when it cannot find a log directory.
+ OSError: raised in Python 2 when it cannot find a log directory.
+ """
+ if not program_name:
+ # Strip the extension (foobar.par becomes foobar, and
+ # fubar.py becomes fubar). We do this so that the log
+ # file names are similar to C++ log file names.
+ program_name = os.path.splitext(os.path.basename(sys.argv[0]))[0]
+
+ # Prepend py_ to files so that python code gets a unique file, and
+ # so that C++ libraries do not try to write to the same log files as us.
+ program_name = 'py_%s' % program_name
+
+ actual_log_dir = find_log_dir(log_dir=log_dir)
+
+ try:
+ username = getpass.getuser()
+ except KeyError:
+ # This can happen, e.g. when running under docker w/o passwd file.
+ if hasattr(os, 'getuid'):
+ # Windows doesn't have os.getuid
+ username = str(os.getuid())
+ else:
+ username = 'unknown'
+ hostname = socket.gethostname()
+ file_prefix = '%s.%s.%s.log' % (program_name, hostname, username)
+
+ return actual_log_dir, file_prefix, program_name
+
+
+def find_log_dir(log_dir=None):
+ """Returns the most suitable directory to put log files into.
+
+ Args:
+ log_dir: str|None, if specified, the logfile(s) will be created in that
+ directory. Otherwise if the --log_dir command-line flag is provided,
+ the logfile will be created in that directory. Otherwise the logfile
+ will be created in a standard location.
+
+ Raises:
+ FileNotFoundError: raised in Python 3 when it cannot find a log directory.
+ OSError: raised in Python 2 when it cannot find a log directory.
+ """
+ # Get a list of possible log dirs (will try to use them in order).
+ if log_dir:
+ # log_dir was explicitly specified as an arg, so use it and it alone.
+ dirs = [log_dir]
+ elif FLAGS['log_dir'].value:
+ # log_dir flag was provided, so use it and it alone (this mimics the
+ # behavior of the same flag in logging.cc).
+ dirs = [FLAGS['log_dir'].value]
+ else:
+ dirs = ['/tmp/', './']
+
+ # Find the first usable log dir.
+ for d in dirs:
+ if os.path.isdir(d) and os.access(d, os.W_OK):
+ return d
+ raise FileNotFoundError(
+ "Can't find a writable directory for logs, tried %s" % dirs)
+
+
+def get_absl_log_prefix(record):
+ """Returns the absl log prefix for the log record.
+
+ Args:
+ record: logging.LogRecord, the record to get prefix for.
+ """
+ created_tuple = time.localtime(record.created)
+ created_microsecond = int(record.created % 1.0 * 1e6)
+
+ critical_prefix = ''
+ level = record.levelno
+ if _is_non_absl_fatal_record(record):
+ # When the level is FATAL, but not logged from absl, lower the level so
+ # it's treated as ERROR.
+ level = logging.ERROR
+ critical_prefix = _CRITICAL_PREFIX
+ severity = converter.get_initial_for_level(level)
+
+ return '%c%02d%02d %02d:%02d:%02d.%06d %5d %s:%d] %s' % (
+ severity,
+ created_tuple.tm_mon,
+ created_tuple.tm_mday,
+ created_tuple.tm_hour,
+ created_tuple.tm_min,
+ created_tuple.tm_sec,
+ created_microsecond,
+ _get_thread_id(),
+ record.filename,
+ record.lineno,
+ critical_prefix)
+
+
+def skip_log_prefix(func):
+ """Skips reporting the prefix of a given function or name by ABSLLogger.
+
+ This is a convenience wrapper function / decorator for
+ `ABSLLogger.register_frame_to_skip`.
+
+ If a callable function is provided, only that function will be skipped.
+ If a function name is provided, all functions with the same name in the
+ file that this is called in will be skipped.
+
+ This can be used as a decorator of the intended function to be skipped.
+
+ Args:
+ func: Callable function or its name as a string.
+
+ Returns:
+ func (the input, unchanged).
+
+ Raises:
+ ValueError: The input is callable but does not have a function code object.
+ TypeError: The input is neither callable nor a string.
+ """
+ if callable(func):
+ func_code = getattr(func, '__code__', None)
+ if func_code is None:
+ raise ValueError('Input callable does not have a function code object.')
+ file_name = func_code.co_filename
+ func_name = func_code.co_name
+ func_lineno = func_code.co_firstlineno
+ elif isinstance(func, str):
+ file_name = get_absl_logger().findCaller()[0]
+ func_name = func
+ func_lineno = None
+ else:
+ raise TypeError('Input is neither callable nor a string.')
+ ABSLLogger.register_frame_to_skip(file_name, func_name, func_lineno)
+ return func
+
+
+def _is_non_absl_fatal_record(log_record):
+ return (log_record.levelno >= logging.FATAL and
+ not log_record.__dict__.get(_ABSL_LOG_FATAL, False))
+
+
+def _is_absl_fatal_record(log_record):
+ return (log_record.levelno >= logging.FATAL and
+ log_record.__dict__.get(_ABSL_LOG_FATAL, False))
+
+
+# Indicates if we still need to warn about pre-init logs going to stderr.
+_warn_preinit_stderr = True
+
+
+class PythonHandler(logging.StreamHandler):
+ """The handler class used by Abseil Python logging implementation."""
+
+ def __init__(self, stream=None, formatter=None):
+ super(PythonHandler, self).__init__(stream)
+ self.setFormatter(formatter or PythonFormatter())
+
+ def start_logging_to_file(self, program_name=None, log_dir=None):
+ """Starts logging messages to files instead of standard error."""
+ FLAGS.logtostderr = False
+
+ actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names(
+ program_name=program_name, log_dir=log_dir)
+
+ basename = '%s.INFO.%s.%d' % (
+ file_prefix,
+ time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time())),
+ os.getpid())
+ filename = os.path.join(actual_log_dir, basename)
+
+ self.stream = open(filename, 'a', encoding='utf-8')
+
+ # os.symlink is not available on Windows Python 2.
+ if getattr(os, 'symlink', None):
+ # Create a symlink to the log file with a canonical name.
+ symlink = os.path.join(actual_log_dir, symlink_prefix + '.INFO')
+ try:
+ if os.path.islink(symlink):
+ os.unlink(symlink)
+ os.symlink(os.path.basename(filename), symlink)
+ except EnvironmentError:
+ # If it fails, we're sad but it's no error. Commonly, this
+ # fails because the symlink was created by another user and so
+ # we can't modify it
+ pass
+
+ def use_absl_log_file(self, program_name=None, log_dir=None):
+ """Conditionally logs to files, based on --logtostderr."""
+ if FLAGS['logtostderr'].value:
+ self.stream = sys.stderr
+ else:
+ self.start_logging_to_file(program_name=program_name, log_dir=log_dir)
+
+ def flush(self):
+ """Flushes all log files."""
+ self.acquire()
+ try:
+ self.stream.flush()
+ except (EnvironmentError, ValueError):
+ # A ValueError is thrown if we try to flush a closed file.
+ pass
+ finally:
+ self.release()
+
+ def _log_to_stderr(self, record):
+ """Emits the record to stderr.
+
+ This temporarily sets the handler stream to stderr, calls
+ StreamHandler.emit, then reverts the stream back.
+
+ Args:
+ record: logging.LogRecord, the record to log.
+ """
+ # emit() is protected by a lock in logging.Handler, so we don't need to
+ # protect here again.
+ old_stream = self.stream
+ self.stream = sys.stderr
+ try:
+ super(PythonHandler, self).emit(record)
+ finally:
+ self.stream = old_stream
+
+ def emit(self, record):
+ """Prints a record out to some streams.
+
+ If FLAGS.logtostderr is set, it will print to sys.stderr ONLY.
+ If FLAGS.alsologtostderr is set, it will print to sys.stderr.
+ If FLAGS.logtostderr is not set, it will log to the stream
+ associated with the current thread.
+
+ Args:
+ record: logging.LogRecord, the record to emit.
+ """
+ # People occasionally call logging functions at import time before
+ # our flags may have even been defined yet, let alone even parsed, as we
+ # rely on the C++ side to define some flags for us and app init to
+ # deal with parsing. Match the C++ library behavior of notify and emit
+ # such messages to stderr. It encourages people to clean-up and does
+ # not hide the message.
+ level = record.levelno
+ if not FLAGS.is_parsed(): # Also implies "before flag has been defined".
+ global _warn_preinit_stderr
+ if _warn_preinit_stderr:
+ sys.stderr.write(
+ 'WARNING: Logging before flag parsing goes to stderr.\n')
+ _warn_preinit_stderr = False
+ self._log_to_stderr(record)
+ elif FLAGS['logtostderr'].value:
+ self._log_to_stderr(record)
+ else:
+ super(PythonHandler, self).emit(record)
+ stderr_threshold = converter.string_to_standard(
+ FLAGS['stderrthreshold'].value)
+ if ((FLAGS['alsologtostderr'].value or level >= stderr_threshold) and
+ self.stream != sys.stderr):
+ self._log_to_stderr(record)
+ # Die when the record is created from ABSLLogger and level is FATAL.
+ if _is_absl_fatal_record(record):
+ self.flush() # Flush the log before dying.
+
+ # In threaded python, sys.exit() from a non-main thread only
+ # exits the thread in question.
+ os.abort()
+
+ def close(self):
+ """Closes the stream to which we are writing."""
+ self.acquire()
+ try:
+ self.flush()
+ try:
+ # Do not close the stream if it's sys.stderr|stdout. They may be
+ # redirected or overridden to files, which should be managed by users
+ # explicitly.
+ user_managed = sys.stderr, sys.stdout, sys.__stderr__, sys.__stdout__
+ if self.stream not in user_managed and (
+ not hasattr(self.stream, 'isatty') or not self.stream.isatty()):
+ self.stream.close()
+ except ValueError:
+ # A ValueError is thrown if we try to run isatty() on a closed file.
+ pass
+ super(PythonHandler, self).close()
+ finally:
+ self.release()
+
+
+class ABSLHandler(logging.Handler):
+ """Abseil Python logging module's log handler."""
+
+ def __init__(self, python_logging_formatter):
+ super(ABSLHandler, self).__init__()
+
+ self._python_handler = PythonHandler(formatter=python_logging_formatter)
+ self.activate_python_handler()
+
+ def format(self, record):
+ return self._current_handler.format(record)
+
+ def setFormatter(self, fmt):
+ self._current_handler.setFormatter(fmt)
+
+ def emit(self, record):
+ self._current_handler.emit(record)
+
+ def flush(self):
+ self._current_handler.flush()
+
+ def close(self):
+ super(ABSLHandler, self).close()
+ self._current_handler.close()
+
+ def handle(self, record):
+ rv = self.filter(record)
+ if rv:
+ return self._current_handler.handle(record)
+ return rv
+
+ @property
+ def python_handler(self):
+ return self._python_handler
+
+ def activate_python_handler(self):
+ """Uses the Python logging handler as the current logging handler."""
+ self._current_handler = self._python_handler
+
+ def use_absl_log_file(self, program_name=None, log_dir=None):
+ self._current_handler.use_absl_log_file(program_name, log_dir)
+
+ def start_logging_to_file(self, program_name=None, log_dir=None):
+ self._current_handler.start_logging_to_file(program_name, log_dir)
+
+
+class PythonFormatter(logging.Formatter):
+ """Formatter class used by PythonHandler."""
+
+ def format(self, record):
+ """Appends the message from the record to the results of the prefix.
+
+ Args:
+ record: logging.LogRecord, the record to be formatted.
+
+ Returns:
+ The formatted string representing the record.
+ """
+ if (not FLAGS['showprefixforinfo'].value and
+ FLAGS['verbosity'].value == converter.ABSL_INFO and
+ record.levelno == logging.INFO and
+ _absl_handler.python_handler.stream == sys.stderr):
+ prefix = ''
+ else:
+ prefix = get_absl_log_prefix(record)
+ return prefix + super(PythonFormatter, self).format(record)
+
+
+class ABSLLogger(logging.getLoggerClass()):
+ """A logger that will create LogRecords while skipping some stack frames.
+
+ This class maintains an internal list of filenames and method names
+ for use when determining who called the currently executing stack
+ frame. Any method names from specific source files are skipped when
+ walking backwards through the stack.
+
+ Client code should use the register_frame_to_skip method to let the
+ ABSLLogger know which method from which file should be
+ excluded from the walk backwards through the stack.
+ """
+ _frames_to_skip = set()
+
+ def findCaller(self, stack_info=False, stacklevel=1):
+ """Finds the frame of the calling method on the stack.
+
+ This method skips any frames registered with the
+ ABSLLogger and any methods from this file, and whatever
+ method is currently being used to generate the prefix for the log
+ line. Then it returns the file name, line number, and method name
+ of the calling method. An optional fourth item may be returned,
+ callers who only need things from the first three are advised to
+ always slice or index the result rather than using direct unpacking
+ assignment.
+
+ Args:
+ stack_info: bool, when True, include the stack trace as a fourth item
+ returned. On Python 3 there are always four items returned - the
+ fourth will be None when this is False. On Python 2 the stdlib
+ base class API only returns three items. We do the same when this
+ new parameter is unspecified or False for compatibility.
+
+ Returns:
+ (filename, lineno, methodname[, sinfo]) of the calling method.
+ """
+ f_to_skip = ABSLLogger._frames_to_skip
+ # Use sys._getframe(2) instead of logging.currentframe(), it's slightly
+ # faster because there is one less frame to traverse.
+ frame = sys._getframe(2) # pylint: disable=protected-access
+
+ while frame:
+ code = frame.f_code
+ if (_LOGGING_FILE_PREFIX not in code.co_filename and
+ (code.co_filename, code.co_name,
+ code.co_firstlineno) not in f_to_skip and
+ (code.co_filename, code.co_name) not in f_to_skip):
+ sinfo = None
+ if stack_info:
+ out = io.StringIO()
+ out.write(u'Stack (most recent call last):\n')
+ traceback.print_stack(frame, file=out)
+ sinfo = out.getvalue().rstrip(u'\n')
+ return (code.co_filename, frame.f_lineno, code.co_name, sinfo)
+ frame = frame.f_back
+
+ def critical(self, msg, *args, **kwargs):
+ """Logs 'msg % args' with severity 'CRITICAL'."""
+ self.log(logging.CRITICAL, msg, *args, **kwargs)
+
+ def fatal(self, msg, *args, **kwargs):
+ """Logs 'msg % args' with severity 'FATAL'."""
+ self.log(logging.FATAL, msg, *args, **kwargs)
+
+ def error(self, msg, *args, **kwargs):
+ """Logs 'msg % args' with severity 'ERROR'."""
+ self.log(logging.ERROR, msg, *args, **kwargs)
+
+ def warn(self, msg, *args, **kwargs):
+ """Logs 'msg % args' with severity 'WARN'."""
+ warnings.warn("The 'warn' method is deprecated, use 'warning' instead",
+ DeprecationWarning, 2)
+ self.log(logging.WARN, msg, *args, **kwargs)
+
+ def warning(self, msg, *args, **kwargs):
+ """Logs 'msg % args' with severity 'WARNING'."""
+ self.log(logging.WARNING, msg, *args, **kwargs)
+
+ def info(self, msg, *args, **kwargs):
+ """Logs 'msg % args' with severity 'INFO'."""
+ self.log(logging.INFO, msg, *args, **kwargs)
+
+ def debug(self, msg, *args, **kwargs):
+ """Logs 'msg % args' with severity 'DEBUG'."""
+ self.log(logging.DEBUG, msg, *args, **kwargs)
+
+ def log(self, level, msg, *args, **kwargs):
+ """Logs a message at a cetain level substituting in the supplied arguments.
+
+ This method behaves differently in python and c++ modes.
+
+ Args:
+ level: int, the standard logging level at which to log the message.
+ msg: str, the text of the message to log.
+ *args: The arguments to substitute in the message.
+ **kwargs: The keyword arguments to substitute in the message.
+ """
+ if level >= logging.FATAL:
+ # Add property to the LogRecord created by this logger.
+ # This will be used by the ABSLHandler to determine whether it should
+ # treat CRITICAL/FATAL logs as really FATAL.
+ extra = kwargs.setdefault('extra', {})
+ extra[_ABSL_LOG_FATAL] = True
+ super(ABSLLogger, self).log(level, msg, *args, **kwargs)
+
+ def handle(self, record):
+ """Calls handlers without checking Logger.disabled.
+
+ Non-root loggers are set to disabled after setup with logging.config if
+ it's not explicitly specified. Historically, absl logging will not be
+ disabled by that. To maintaining this behavior, this function skips
+ checking the Logger.disabled bit.
+
+ This logger can still be disabled by adding a filter that filters out
+ everything.
+
+ Args:
+ record: logging.LogRecord, the record to handle.
+ """
+ if self.filter(record):
+ self.callHandlers(record)
+
+ @classmethod
+ def register_frame_to_skip(cls, file_name, function_name, line_number=None):
+ """Registers a function name to skip when walking the stack.
+
+ The ABSLLogger sometimes skips method calls on the stack
+ to make the log messages meaningful in their appropriate context.
+ This method registers a function from a particular file as one
+ which should be skipped.
+
+ Args:
+ file_name: str, the name of the file that contains the function.
+ function_name: str, the name of the function to skip.
+ line_number: int, if provided, only the function with this starting line
+ number will be skipped. Otherwise, all functions with the same name
+ in the file will be skipped.
+ """
+ if line_number is not None:
+ cls._frames_to_skip.add((file_name, function_name, line_number))
+ else:
+ cls._frames_to_skip.add((file_name, function_name))
+
+
+def _get_thread_id():
+ """Gets id of current thread, suitable for logging as an unsigned quantity.
+
+ If pywrapbase is linked, returns GetTID() for the thread ID to be
+ consistent with C++ logging. Otherwise, returns the numeric thread id.
+ The quantities are made unsigned by masking with 2*sys.maxint + 1.
+
+ Returns:
+ Thread ID unique to this process (unsigned)
+ """
+ thread_id = threading.get_ident()
+ return thread_id & _THREAD_ID_MASK
+
+
+def get_absl_logger():
+ """Returns the absl logger instance."""
+ return _absl_logger
+
+
+def get_absl_handler():
+ """Returns the absl handler instance."""
+ return _absl_handler
+
+
+def use_python_logging(quiet=False):
+ """Uses the python implementation of the logging code.
+
+ Args:
+ quiet: No logging message about switching logging type.
+ """
+ get_absl_handler().activate_python_handler()
+ if not quiet:
+ info('Restoring pure python logging')
+
+
+_attempted_to_remove_stderr_stream_handlers = False
+
+
+def use_absl_handler():
+ """Uses the ABSL logging handler for logging.
+
+ This method is called in app.run() so the absl handler is used in absl apps.
+ """
+ global _attempted_to_remove_stderr_stream_handlers
+ if not _attempted_to_remove_stderr_stream_handlers:
+ # The absl handler logs to stderr by default. To prevent double logging to
+ # stderr, the following code tries its best to remove other handlers that
+ # emit to stderr. Those handlers are most commonly added when
+ # logging.info/debug is called before calling use_absl_handler().
+ handlers = [
+ h for h in logging.root.handlers
+ if isinstance(h, logging.StreamHandler) and h.stream == sys.stderr]
+ for h in handlers:
+ logging.root.removeHandler(h)
+ _attempted_to_remove_stderr_stream_handlers = True
+
+ absl_handler = get_absl_handler()
+ if absl_handler not in logging.root.handlers:
+ logging.root.addHandler(absl_handler)
+ FLAGS['verbosity']._update_logging_levels() # pylint: disable=protected-access
+ FLAGS['logger_levels']._update_logger_levels() # pylint: disable=protected-access
+
+
+def _initialize():
+ """Initializes loggers and handlers."""
+ global _absl_logger, _absl_handler
+
+ if _absl_logger:
+ return
+
+ original_logger_class = logging.getLoggerClass()
+ logging.setLoggerClass(ABSLLogger)
+ _absl_logger = logging.getLogger('absl')
+ logging.setLoggerClass(original_logger_class)
+
+ python_logging_formatter = PythonFormatter()
+ _absl_handler = ABSLHandler(python_logging_formatter)
+
+
+_initialize()
diff --git a/absl/logging/converter.py b/absl/logging/converter.py
new file mode 100644
index 0000000..53dd46d
--- /dev/null
+++ b/absl/logging/converter.py
@@ -0,0 +1,211 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module to convert log levels between Abseil Python, C++, and Python standard.
+
+This converter has to convert (best effort) between three different
+logging level schemes:
+ cpp = The C++ logging level scheme used in Abseil C++.
+ absl = The absl.logging level scheme used in Abseil Python.
+ standard = The python standard library logging level scheme.
+
+Here is a handy ascii chart for easy mental mapping.
+
+ LEVEL | cpp | absl | standard |
+ ---------+-----+--------+----------+
+ DEBUG | 0 | 1 | 10 |
+ INFO | 0 | 0 | 20 |
+ WARNING | 1 | -1 | 30 |
+ ERROR | 2 | -2 | 40 |
+ CRITICAL | 3 | -3 | 50 |
+ FATAL | 3 | -3 | 50 |
+
+Note: standard logging CRITICAL is mapped to absl/cpp FATAL.
+However, only CRITICAL logs from the absl logger (or absl.logging.fatal) will
+terminate the program. CRITICAL logs from non-absl loggers are treated as
+error logs with a message prefix "CRITICAL - ".
+
+Converting from standard to absl or cpp is a lossy conversion.
+Converting back to standard will lose granularity. For this reason,
+users should always try to convert to standard, the richest
+representation, before manipulating the levels, and then only to cpp
+or absl if those level schemes are absolutely necessary.
+"""
+
+import logging
+
+STANDARD_CRITICAL = logging.CRITICAL
+STANDARD_ERROR = logging.ERROR
+STANDARD_WARNING = logging.WARNING
+STANDARD_INFO = logging.INFO
+STANDARD_DEBUG = logging.DEBUG
+
+# These levels are also used to define the constants
+# FATAL, ERROR, WARNING, INFO, and DEBUG in the
+# absl.logging module.
+ABSL_FATAL = -3
+ABSL_ERROR = -2
+ABSL_WARNING = -1
+ABSL_WARN = -1 # Deprecated name.
+ABSL_INFO = 0
+ABSL_DEBUG = 1
+
+ABSL_LEVELS = {ABSL_FATAL: 'FATAL',
+ ABSL_ERROR: 'ERROR',
+ ABSL_WARNING: 'WARNING',
+ ABSL_INFO: 'INFO',
+ ABSL_DEBUG: 'DEBUG'}
+
+# Inverts the ABSL_LEVELS dictionary
+ABSL_NAMES = {'FATAL': ABSL_FATAL,
+ 'ERROR': ABSL_ERROR,
+ 'WARNING': ABSL_WARNING,
+ 'WARN': ABSL_WARNING, # Deprecated name.
+ 'INFO': ABSL_INFO,
+ 'DEBUG': ABSL_DEBUG}
+
+ABSL_TO_STANDARD = {ABSL_FATAL: STANDARD_CRITICAL,
+ ABSL_ERROR: STANDARD_ERROR,
+ ABSL_WARNING: STANDARD_WARNING,
+ ABSL_INFO: STANDARD_INFO,
+ ABSL_DEBUG: STANDARD_DEBUG}
+
+# Inverts the ABSL_TO_STANDARD
+STANDARD_TO_ABSL = dict((v, k) for (k, v) in ABSL_TO_STANDARD.items())
+
+
+def get_initial_for_level(level):
+ """Gets the initial that should start the log line for the given level.
+
+ It returns:
+ - 'I' when: level < STANDARD_WARNING.
+ - 'W' when: STANDARD_WARNING <= level < STANDARD_ERROR.
+ - 'E' when: STANDARD_ERROR <= level < STANDARD_CRITICAL.
+ - 'F' when: level >= STANDARD_CRITICAL.
+
+ Args:
+ level: int, a Python standard logging level.
+
+ Returns:
+ The first initial as it would be logged by the C++ logging module.
+ """
+ if level < STANDARD_WARNING:
+ return 'I'
+ elif level < STANDARD_ERROR:
+ return 'W'
+ elif level < STANDARD_CRITICAL:
+ return 'E'
+ else:
+ return 'F'
+
+
+def absl_to_cpp(level):
+ """Converts an absl log level to a cpp log level.
+
+ Args:
+ level: int, an absl.logging level.
+
+ Raises:
+ TypeError: Raised when level is not an integer.
+
+ Returns:
+ The corresponding integer level for use in Abseil C++.
+ """
+ if not isinstance(level, int):
+ raise TypeError('Expect an int level, found {}'.format(type(level)))
+ if level >= 0:
+ # C++ log levels must be >= 0
+ return 0
+ else:
+ return -level
+
+
+def absl_to_standard(level):
+ """Converts an integer level from the absl value to the standard value.
+
+ Args:
+ level: int, an absl.logging level.
+
+ Raises:
+ TypeError: Raised when level is not an integer.
+
+ Returns:
+ The corresponding integer level for use in standard logging.
+ """
+ if not isinstance(level, int):
+ raise TypeError('Expect an int level, found {}'.format(type(level)))
+ if level < ABSL_FATAL:
+ level = ABSL_FATAL
+ if level <= ABSL_DEBUG:
+ return ABSL_TO_STANDARD[level]
+ # Maps to vlog levels.
+ return STANDARD_DEBUG - level + 1
+
+
+def string_to_standard(level):
+ """Converts a string level to standard logging level value.
+
+ Args:
+ level: str, case-insensitive 'debug', 'info', 'warning', 'error', 'fatal'.
+
+ Returns:
+ The corresponding integer level for use in standard logging.
+ """
+ return absl_to_standard(ABSL_NAMES.get(level.upper()))
+
+
+def standard_to_absl(level):
+ """Converts an integer level from the standard value to the absl value.
+
+ Args:
+ level: int, a Python standard logging level.
+
+ Raises:
+ TypeError: Raised when level is not an integer.
+
+ Returns:
+ The corresponding integer level for use in absl logging.
+ """
+ if not isinstance(level, int):
+ raise TypeError('Expect an int level, found {}'.format(type(level)))
+ if level < 0:
+ level = 0
+ if level < STANDARD_DEBUG:
+ # Maps to vlog levels.
+ return STANDARD_DEBUG - level + 1
+ elif level < STANDARD_INFO:
+ return ABSL_DEBUG
+ elif level < STANDARD_WARNING:
+ return ABSL_INFO
+ elif level < STANDARD_ERROR:
+ return ABSL_WARNING
+ elif level < STANDARD_CRITICAL:
+ return ABSL_ERROR
+ else:
+ return ABSL_FATAL
+
+
+def standard_to_cpp(level):
+ """Converts an integer level from the standard value to the cpp value.
+
+ Args:
+ level: int, a Python standard logging level.
+
+ Raises:
+ TypeError: Raised when level is not an integer.
+
+ Returns:
+ The corresponding integer level for use in cpp logging.
+ """
+ return absl_to_cpp(standard_to_absl(level))
diff --git a/absl/logging/tests/__init__.py b/absl/logging/tests/__init__.py
new file mode 100644
index 0000000..a3bd1cd
--- /dev/null
+++ b/absl/logging/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/absl/logging/tests/converter_test.py b/absl/logging/tests/converter_test.py
new file mode 100644
index 0000000..bdc893a
--- /dev/null
+++ b/absl/logging/tests/converter_test.py
@@ -0,0 +1,135 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for converter.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+
+from absl import logging as absl_logging
+from absl.logging import converter
+from absl.testing import absltest
+
+
+class ConverterTest(absltest.TestCase):
+ """Tests the converter module."""
+
+ def test_absl_to_cpp(self):
+ self.assertEqual(0, converter.absl_to_cpp(absl_logging.DEBUG))
+ self.assertEqual(0, converter.absl_to_cpp(absl_logging.INFO))
+ self.assertEqual(1, converter.absl_to_cpp(absl_logging.WARN))
+ self.assertEqual(2, converter.absl_to_cpp(absl_logging.ERROR))
+ self.assertEqual(3, converter.absl_to_cpp(absl_logging.FATAL))
+
+ with self.assertRaises(TypeError):
+ converter.absl_to_cpp('')
+
+ def test_absl_to_standard(self):
+ self.assertEqual(
+ logging.DEBUG, converter.absl_to_standard(absl_logging.DEBUG))
+ self.assertEqual(
+ logging.INFO, converter.absl_to_standard(absl_logging.INFO))
+ self.assertEqual(
+ logging.WARNING, converter.absl_to_standard(absl_logging.WARN))
+ self.assertEqual(
+ logging.WARN, converter.absl_to_standard(absl_logging.WARN))
+ self.assertEqual(
+ logging.ERROR, converter.absl_to_standard(absl_logging.ERROR))
+ self.assertEqual(
+ logging.FATAL, converter.absl_to_standard(absl_logging.FATAL))
+ self.assertEqual(
+ logging.CRITICAL, converter.absl_to_standard(absl_logging.FATAL))
+ # vlog levels.
+ self.assertEqual(9, converter.absl_to_standard(2))
+ self.assertEqual(8, converter.absl_to_standard(3))
+
+ with self.assertRaises(TypeError):
+ converter.absl_to_standard('')
+
+ def test_standard_to_absl(self):
+ self.assertEqual(
+ absl_logging.DEBUG, converter.standard_to_absl(logging.DEBUG))
+ self.assertEqual(
+ absl_logging.INFO, converter.standard_to_absl(logging.INFO))
+ self.assertEqual(
+ absl_logging.WARN, converter.standard_to_absl(logging.WARN))
+ self.assertEqual(
+ absl_logging.WARN, converter.standard_to_absl(logging.WARNING))
+ self.assertEqual(
+ absl_logging.ERROR, converter.standard_to_absl(logging.ERROR))
+ self.assertEqual(
+ absl_logging.FATAL, converter.standard_to_absl(logging.FATAL))
+ self.assertEqual(
+ absl_logging.FATAL, converter.standard_to_absl(logging.CRITICAL))
+ # vlog levels.
+ self.assertEqual(2, converter.standard_to_absl(logging.DEBUG - 1))
+ self.assertEqual(3, converter.standard_to_absl(logging.DEBUG - 2))
+
+ with self.assertRaises(TypeError):
+ converter.standard_to_absl('')
+
+ def test_standard_to_cpp(self):
+ self.assertEqual(0, converter.standard_to_cpp(logging.DEBUG))
+ self.assertEqual(0, converter.standard_to_cpp(logging.INFO))
+ self.assertEqual(1, converter.standard_to_cpp(logging.WARN))
+ self.assertEqual(1, converter.standard_to_cpp(logging.WARNING))
+ self.assertEqual(2, converter.standard_to_cpp(logging.ERROR))
+ self.assertEqual(3, converter.standard_to_cpp(logging.FATAL))
+ self.assertEqual(3, converter.standard_to_cpp(logging.CRITICAL))
+
+ with self.assertRaises(TypeError):
+ converter.standard_to_cpp('')
+
+ def test_get_initial_for_level(self):
+ self.assertEqual('F', converter.get_initial_for_level(logging.CRITICAL))
+ self.assertEqual('E', converter.get_initial_for_level(logging.ERROR))
+ self.assertEqual('W', converter.get_initial_for_level(logging.WARNING))
+ self.assertEqual('I', converter.get_initial_for_level(logging.INFO))
+ self.assertEqual('I', converter.get_initial_for_level(logging.DEBUG))
+ self.assertEqual('I', converter.get_initial_for_level(logging.NOTSET))
+
+ self.assertEqual('F', converter.get_initial_for_level(51))
+ self.assertEqual('E', converter.get_initial_for_level(49))
+ self.assertEqual('E', converter.get_initial_for_level(41))
+ self.assertEqual('W', converter.get_initial_for_level(39))
+ self.assertEqual('W', converter.get_initial_for_level(31))
+ self.assertEqual('I', converter.get_initial_for_level(29))
+ self.assertEqual('I', converter.get_initial_for_level(21))
+ self.assertEqual('I', converter.get_initial_for_level(19))
+ self.assertEqual('I', converter.get_initial_for_level(11))
+ self.assertEqual('I', converter.get_initial_for_level(9))
+ self.assertEqual('I', converter.get_initial_for_level(1))
+ self.assertEqual('I', converter.get_initial_for_level(-1))
+
+ def test_string_to_standard(self):
+ self.assertEqual(logging.DEBUG, converter.string_to_standard('debug'))
+ self.assertEqual(logging.INFO, converter.string_to_standard('info'))
+ self.assertEqual(logging.WARNING, converter.string_to_standard('warn'))
+ self.assertEqual(logging.WARNING, converter.string_to_standard('warning'))
+ self.assertEqual(logging.ERROR, converter.string_to_standard('error'))
+ self.assertEqual(logging.CRITICAL, converter.string_to_standard('fatal'))
+
+ self.assertEqual(logging.DEBUG, converter.string_to_standard('DEBUG'))
+ self.assertEqual(logging.INFO, converter.string_to_standard('INFO'))
+ self.assertEqual(logging.WARNING, converter.string_to_standard('WARN'))
+ self.assertEqual(logging.WARNING, converter.string_to_standard('WARNING'))
+ self.assertEqual(logging.ERROR, converter.string_to_standard('ERROR'))
+ self.assertEqual(logging.CRITICAL, converter.string_to_standard('FATAL'))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/logging/tests/log_before_import_test.py b/absl/logging/tests/log_before_import_test.py
new file mode 100644
index 0000000..903af16
--- /dev/null
+++ b/absl/logging/tests/log_before_import_test.py
@@ -0,0 +1,127 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test of logging behavior before app.run(), aka flag and logging init()."""
+
+import contextlib
+import io
+import os
+import re
+import sys
+import tempfile
+from unittest import mock
+
+from absl import logging
+from absl.testing import absltest
+
+logging.get_verbosity() # Access --verbosity before flag parsing.
+# Access --logtostderr before flag parsing.
+logging.get_absl_handler().use_absl_log_file()
+
+
+class Error(Exception):
+ pass
+
+
+@contextlib.contextmanager
+def captured_stderr_filename():
+ """Captures stderr and writes them to a temporary file.
+
+ This uses os.dup/os.dup2 to redirect the stderr fd for capturing standard
+ error of logging at import-time. We cannot mock sys.stderr because on the
+ first log call, a default log handler writing to the mock sys.stderr is
+ registered, and it will never be removed and subsequent logs go to the mock
+ in addition to the real stder.
+
+ Yields:
+ The filename of captured stderr.
+ """
+ stderr_capture_file_fd, stderr_capture_file_name = tempfile.mkstemp()
+ original_stderr_fd = os.dup(sys.stderr.fileno())
+ os.dup2(stderr_capture_file_fd, sys.stderr.fileno())
+ try:
+ yield stderr_capture_file_name
+ finally:
+ os.close(stderr_capture_file_fd)
+ os.dup2(original_stderr_fd, sys.stderr.fileno())
+
+
+# Pre-initialization (aka "import" / __main__ time) test.
+with captured_stderr_filename() as before_set_verbosity_filename:
+ # Warnings and above go to stderr.
+ logging.debug('Debug message at parse time.')
+ logging.info('Info message at parse time.')
+ logging.error('Error message at parse time.')
+ logging.warning('Warning message at parse time.')
+ try:
+ raise Error('Exception reason.')
+ except Error:
+ logging.exception('Exception message at parse time.')
+
+
+logging.set_verbosity(logging.ERROR)
+with captured_stderr_filename() as after_set_verbosity_filename:
+ # Verbosity is set to ERROR, errors and above go to stderr.
+ logging.debug('Debug message at parse time.')
+ logging.info('Info message at parse time.')
+ logging.warning('Warning message at parse time.')
+ logging.error('Error message at parse time.')
+
+
+class LoggingInitWarningTest(absltest.TestCase):
+
+ def test_captured_pre_init_warnings(self):
+ with open(before_set_verbosity_filename) as stderr_capture_file:
+ captured_stderr = stderr_capture_file.read()
+ self.assertNotIn('Debug message at parse time.', captured_stderr)
+ self.assertNotIn('Info message at parse time.', captured_stderr)
+
+ traceback_re = re.compile(
+ r'\nTraceback \(most recent call last\):.*?Error: Exception reason.',
+ re.MULTILINE | re.DOTALL)
+ if not traceback_re.search(captured_stderr):
+ self.fail(
+ 'Cannot find traceback message from logging.exception '
+ 'in stderr:\n{}'.format(captured_stderr))
+ # Remove the traceback so the rest of the stderr is deterministic.
+ captured_stderr = traceback_re.sub('', captured_stderr)
+ captured_stderr_lines = captured_stderr.splitlines()
+ self.assertLen(captured_stderr_lines, 3)
+ self.assertIn('Error message at parse time.', captured_stderr_lines[0])
+ self.assertIn('Warning message at parse time.', captured_stderr_lines[1])
+ self.assertIn('Exception message at parse time.', captured_stderr_lines[2])
+
+ def test_set_verbosity_pre_init(self):
+ with open(after_set_verbosity_filename) as stderr_capture_file:
+ captured_stderr = stderr_capture_file.read()
+ captured_stderr_lines = captured_stderr.splitlines()
+
+ self.assertNotIn('Debug message at parse time.', captured_stderr)
+ self.assertNotIn('Info message at parse time.', captured_stderr)
+ self.assertNotIn('Warning message at parse time.', captured_stderr)
+ self.assertLen(captured_stderr_lines, 1)
+ self.assertIn('Error message at parse time.', captured_stderr_lines[0])
+
+ def test_no_more_warnings(self):
+ fake_stderr_type = io.BytesIO if bytes is str else io.StringIO
+ with mock.patch('sys.stderr', new=fake_stderr_type()) as mock_stderr:
+ self.assertMultiLineEqual('', mock_stderr.getvalue())
+ logging.warning('Hello. hello. hello. Is there anybody out there?')
+ self.assertNotIn('Logging before flag parsing goes to stderr',
+ mock_stderr.getvalue())
+ logging.info('A major purpose of this executable is merely not to crash.')
+
+
+if __name__ == '__main__':
+ absltest.main() # This calls the app.run() init equivalent.
diff --git a/absl/logging/tests/logging_functional_test.py b/absl/logging/tests/logging_functional_test.py
new file mode 100644
index 0000000..98a2fab
--- /dev/null
+++ b/absl/logging/tests/logging_functional_test.py
@@ -0,0 +1,732 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functional tests for absl.logging."""
+
+import fnmatch
+import os
+import re
+import shutil
+import subprocess
+import sys
+import tempfile
+
+from absl import logging
+from absl.testing import _bazelize_command
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+_PY_VLOG3_LOG_MESSAGE = """\
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:62] This line is VLOG level 3
+"""
+
+_PY_VLOG2_LOG_MESSAGE = """\
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:64] This line is VLOG level 2
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:64] This line is log level 2
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:64] VLOG level 1, but only if VLOG level 2 is active
+"""
+
+# VLOG1 is the same as DEBUG logs.
+_PY_DEBUG_LOG_MESSAGE = """\
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level 1
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level 1
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:66] This line is DEBUG
+"""
+
+_PY_INFO_LOG_MESSAGE = """\
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level 0
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level 0
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:70] Interesting Stuff\0
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:71] Interesting Stuff with Arguments: 42
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:73] Interesting Stuff with Dictionary
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times.
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times.
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times.
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times.
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times.
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:76] Info first 1 of 2
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:77] Info 1 (every 3)
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:76] Info first 2 of 2
+I1231 23:59:59.000000 12345 logging_functional_test_helper.py:77] Info 4 (every 3)
+"""
+
+_PY_INFO_LOG_MESSAGE_NOPREFIX = """\
+This line is VLOG level 0
+This line is log level 0
+Interesting Stuff\0
+Interesting Stuff with Arguments: 42
+Interesting Stuff with Dictionary
+This should appear 5 times.
+This should appear 5 times.
+This should appear 5 times.
+This should appear 5 times.
+This should appear 5 times.
+Info first 1 of 2
+Info 1 (every 3)
+Info first 2 of 2
+Info 4 (every 3)
+"""
+
+_PY_WARNING_LOG_MESSAGE = """\
+W1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level -1
+W1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level -1
+W1231 23:59:59.000000 12345 logging_functional_test_helper.py:79] Worrying Stuff
+W0000 23:59:59.000000 12345 logging_functional_test_helper.py:81] Warn first 1 of 2
+W0000 23:59:59.000000 12345 logging_functional_test_helper.py:82] Warn 1 (every 3)
+W0000 23:59:59.000000 12345 logging_functional_test_helper.py:81] Warn first 2 of 2
+W0000 23:59:59.000000 12345 logging_functional_test_helper.py:82] Warn 4 (every 3)
+"""
+
+if sys.version_info[0:2] == (3, 4):
+ _FAKE_ERROR_EXTRA_MESSAGE = """\
+Traceback (most recent call last):
+ File "logging_functional_test_helper.py", line 456, in _test_do_logging
+ raise OSError('Fake Error')
+"""
+else:
+ _FAKE_ERROR_EXTRA_MESSAGE = ''
+
+_PY_ERROR_LOG_MESSAGE = """\
+E1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level -2
+E1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level -2
+E1231 23:59:59.000000 12345 logging_functional_test_helper.py:87] An Exception %s
+Traceback (most recent call last):
+ File "logging_functional_test_helper.py", line 456, in _test_do_logging
+ raise OSError('Fake Error')
+OSError: Fake Error
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Once more, just because
+Traceback (most recent call last):
+ File "./logging_functional_test_helper.py", line 78, in _test_do_logging
+ raise OSError('Fake Error')
+OSError: Fake Error
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Exception 2 %s
+Traceback (most recent call last):
+ File "logging_functional_test_helper.py", line 456, in _test_do_logging
+ raise OSError('Fake Error')
+OSError: Fake Error
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Non-exception
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Exception 3
+Traceback (most recent call last):
+ File "logging_functional_test_helper.py", line 456, in _test_do_logging
+ raise OSError('Fake Error')
+OSError: Fake Error
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] No traceback
+{fake_error_extra}OSError: Fake Error
+E1231 23:59:59.000000 12345 logging_functional_test_helper.py:90] Alarming Stuff
+E0000 23:59:59.000000 12345 logging_functional_test_helper.py:92] Error first 1 of 2
+E0000 23:59:59.000000 12345 logging_functional_test_helper.py:93] Error 1 (every 3)
+E0000 23:59:59.000000 12345 logging_functional_test_helper.py:92] Error first 2 of 2
+E0000 23:59:59.000000 12345 logging_functional_test_helper.py:93] Error 4 (every 3)
+""".format(fake_error_extra=_FAKE_ERROR_EXTRA_MESSAGE)
+
+
+_CRITICAL_DOWNGRADE_TO_ERROR_MESSAGE = """\
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] CRITICAL - A critical message
+"""
+
+
+_VERBOSITY_FLAG_TEST_PARAMETERS = (
+ ('fatal', logging.FATAL),
+ ('error', logging.ERROR),
+ ('warning', logging.WARN),
+ ('info', logging.INFO),
+ ('debug', logging.DEBUG),
+ ('vlog1', 1),
+ ('vlog2', 2),
+ ('vlog3', 3))
+
+
+def _get_fatal_log_expectation(testcase, message, include_stacktrace):
+ """Returns the expectation for fatal logging tests.
+
+ Args:
+ testcase: The TestCase instance.
+ message: The extra fatal logging message.
+ include_stacktrace: Whether or not to include stacktrace.
+
+ Returns:
+ A callable, the expectation for fatal logging tests. It will be passed to
+ FunctionalTest._exec_test as third items in the expected_logs list.
+ See _exec_test's docstring for more information.
+ """
+ def assert_logs(logs):
+ if os.name == 'nt':
+ # On Windows, it also dumps extra information at the end, something like:
+ # This application has requested the Runtime to terminate it in an
+ # unusual way. Please contact the application's support team for more
+ # information.
+ logs = '\n'.join(logs.split('\n')[:-3])
+ format_string = (
+ 'F1231 23:59:59.000000 12345 logging_functional_test_helper.py:175] '
+ '%s message\n')
+ expected_logs = format_string % message
+ if include_stacktrace:
+ expected_logs += 'Stack trace:\n'
+ faulthandler_start = 'Fatal Python error: Aborted'
+ testcase.assertIn(faulthandler_start, logs)
+ log_message = logs.split(faulthandler_start)[0]
+ testcase.assertEqual(_munge_log(expected_logs), _munge_log(log_message))
+
+ return assert_logs
+
+
+def _munge_log(buf):
+ """Remove timestamps, thread ids, filenames and line numbers from logs."""
+
+ # Remove all messages produced before the output to be tested.
+ buf = re.sub(r'(?:.|\n)*START OF TEST HELPER LOGS: IGNORE PREVIOUS.\n',
+ r'',
+ buf)
+
+ # Greeting
+ buf = re.sub(r'(?m)^Log file created at: .*\n',
+ '',
+ buf)
+ buf = re.sub(r'(?m)^Running on machine: .*\n',
+ '',
+ buf)
+ buf = re.sub(r'(?m)^Binary: .*\n',
+ '',
+ buf)
+ buf = re.sub(r'(?m)^Log line format: .*\n',
+ '',
+ buf)
+
+ # Verify thread id is logged as a non-negative quantity.
+ matched = re.match(r'(?m)^(\w)(\d\d\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d) '
+ r'([ ]*-?[0-9a-fA-f]+ )?([a-zA-Z<][\w._<>-]+):(\d+)',
+ buf)
+ if matched:
+ threadid = matched.group(3)
+ if int(threadid) < 0:
+ raise AssertionError("Negative threadid '%s' in '%s'" % (threadid, buf))
+
+ # Timestamp
+ buf = re.sub(r'(?m)' + logging.ABSL_LOGGING_PREFIX_REGEX,
+ r'\g<severity>0000 00:00:00.000000 12345 \g<filename>:123',
+ buf)
+
+ # Traceback
+ buf = re.sub(r'(?m)^ File "(.*/)?([^"/]+)", line (\d+),',
+ r' File "\g<2>", line 456,',
+ buf)
+
+ # Stack trace is too complicated for re, just assume it extends to end of
+ # output
+ buf = re.sub(r'(?sm)^Stack trace:\n.*',
+ r'Stack trace:\n',
+ buf)
+ buf = re.sub(r'(?sm)^\*\*\* Signal 6 received by PID.*\n.*',
+ r'Stack trace:\n',
+ buf)
+ buf = re.sub((r'(?sm)^\*\*\* ([A-Z]+) received by PID (\d+) '
+ r'\(TID 0x([0-9a-f]+)\)'
+ r'( from PID \d+)?; stack trace: \*\*\*\n.*'),
+ r'Stack trace:\n',
+ buf)
+ buf = re.sub(r'(?sm)^\*\*\* Check failure stack trace: \*\*\*\n.*',
+ r'Stack trace:\n',
+ buf)
+
+ if os.name == 'nt':
+ # On windows, we calls Python interpreter explicitly, so the file names
+ # include the full path. Strip them.
+ buf = re.sub(r'( File ").*(logging_functional_test_helper\.py", line )',
+ r'\1\2',
+ buf)
+
+ return buf
+
+
+def _verify_status(expected, actual, output):
+ if expected != actual:
+ raise AssertionError(
+ 'Test exited with unexpected status code %d (expected %d). '
+ 'Output was:\n%s' % (actual, expected, output))
+
+
+def _verify_ok(status, output):
+ """Check that helper exited with no errors."""
+ _verify_status(0, status, output)
+
+
+def _verify_fatal(status, output):
+ """Check that helper died as expected."""
+ # os.abort generates a SIGABRT signal (-6). On Windows, the process
+ # immediately returns an exit code of 3.
+ # See https://docs.python.org/3.6/library/os.html#os.abort.
+ expected_exit_code = 3 if os.name == 'nt' else -6
+ _verify_status(expected_exit_code, status, output)
+
+
+def _verify_assert(status, output):
+ """.Check that helper failed with assertion."""
+ _verify_status(1, status, output)
+
+
+class FunctionalTest(parameterized.TestCase):
+ """Functional tests using the logging_functional_test_helper script."""
+
+ def _get_helper(self):
+ helper_name = 'absl/logging/tests/logging_functional_test_helper'
+ return _bazelize_command.get_executable_path(helper_name)
+
+ def _get_logs(self,
+ verbosity,
+ include_info_prefix=True):
+ logs = []
+ if verbosity >= 3:
+ logs.append(_PY_VLOG3_LOG_MESSAGE)
+ if verbosity >= 2:
+ logs.append(_PY_VLOG2_LOG_MESSAGE)
+ if verbosity >= logging.DEBUG:
+ logs.append(_PY_DEBUG_LOG_MESSAGE)
+
+ if verbosity >= logging.INFO:
+ if include_info_prefix:
+ logs.append(_PY_INFO_LOG_MESSAGE)
+ else:
+ logs.append(_PY_INFO_LOG_MESSAGE_NOPREFIX)
+ if verbosity >= logging.WARN:
+ logs.append(_PY_WARNING_LOG_MESSAGE)
+ if verbosity >= logging.ERROR:
+ logs.append(_PY_ERROR_LOG_MESSAGE)
+
+ expected_logs = ''.join(logs)
+ expected_logs = expected_logs.replace(
+ "<type 'exceptions.OSError'>", "<class 'OSError'>")
+ return expected_logs
+
+ def setUp(self):
+ super(FunctionalTest, self).setUp()
+ self._log_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+
+ def tearDown(self):
+ shutil.rmtree(self._log_dir)
+ super(FunctionalTest, self).tearDown()
+
+ def _exec_test(self,
+ verify_exit_fn,
+ expected_logs,
+ test_name='do_logging',
+ pass_logtostderr=False,
+ use_absl_log_file=False,
+ show_info_prefix=1,
+ call_dict_config=False,
+ extra_args=()):
+ """Execute the helper script and verify its output.
+
+ Args:
+ verify_exit_fn: A function taking (status, output).
+ expected_logs: List of tuples, or None if output shouldn't be checked.
+ Tuple is (log prefix, log type, expected contents):
+ - log prefix: A program name, or 'stderr'.
+ - log type: 'INFO', 'ERROR', etc.
+ - expected: Can be the following:
+ - A string
+ - A callable, called with the logs as a single argument
+ - None, means don't check contents of log file
+ test_name: Name to pass to helper.
+ pass_logtostderr: Pass --logtostderr to the helper script if True.
+ use_absl_log_file: If True, call
+ logging.get_absl_handler().use_absl_log_file() before test_fn in
+ logging_functional_test_helper.
+ show_info_prefix: --showprefixforinfo value passed to the helper script.
+ call_dict_config: True if helper script should call
+ logging.config.dictConfig.
+ extra_args: Iterable of str (optional, defaults to ()) - extra arguments
+ to pass to the helper script.
+
+ Raises:
+ AssertionError: Assertion error when test fails.
+ """
+ args = ['--log_dir=%s' % self._log_dir]
+ if pass_logtostderr:
+ args.append('--logtostderr')
+ if not show_info_prefix:
+ args.append('--noshowprefixforinfo')
+ args += extra_args
+
+ # Execute helper in subprocess.
+ env = os.environ.copy()
+ env.update({
+ 'TEST_NAME': test_name,
+ 'USE_ABSL_LOG_FILE': '%d' % (use_absl_log_file,),
+ 'CALL_DICT_CONFIG': '%d' % (call_dict_config,),
+ })
+ cmd = [self._get_helper()] + args
+
+ print('env: %s' % env, file=sys.stderr)
+ print('cmd: %s' % cmd, file=sys.stderr)
+ process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env,
+ universal_newlines=True)
+ output, _ = process.communicate()
+ status = process.returncode
+
+ # Verify exit status.
+ verify_exit_fn(status, output)
+
+ # Check outputs?
+ if expected_logs is None:
+ return
+
+ # Get list of log files.
+ logs = os.listdir(self._log_dir)
+ logs = fnmatch.filter(logs, '*.log.*')
+ logs.append('stderr')
+
+ # Look for a log matching each expected pattern.
+ matched = []
+ unmatched = []
+ unexpected = logs[:]
+ for log_prefix, log_type, expected in expected_logs:
+ # What pattern?
+ if log_prefix == 'stderr':
+ assert log_type is None
+ pattern = 'stderr'
+ else:
+ pattern = r'%s[.].*[.]log[.]%s[.][\d.-]*$' % (log_prefix, log_type)
+
+ # Is it there
+ for basename in logs:
+ if re.match(pattern, basename):
+ matched.append([expected, basename])
+ unexpected.remove(basename)
+ break
+ else:
+ unmatched.append(pattern)
+
+ # Mismatch?
+ errors = ''
+ if unmatched:
+ errors += 'The following log files were expected but not found: %s' % (
+ '\n '.join(unmatched))
+ if unexpected:
+ if errors:
+ errors += '\n'
+ errors += 'The following log files were not expected: %s' % (
+ '\n '.join(unexpected))
+ if errors:
+ raise AssertionError(errors)
+
+ # Compare contents of matches.
+ for (expected, basename) in matched:
+ if expected is None:
+ continue
+
+ if basename == 'stderr':
+ actual = output
+ else:
+ path = os.path.join(self._log_dir, basename)
+ with open(path, encoding='utf-8') as f:
+ actual = f.read()
+
+ if callable(expected):
+ try:
+ expected(actual)
+ except AssertionError:
+ print('expected_logs assertion failed, actual {} log:\n{}'.format(
+ basename, actual), file=sys.stderr)
+ raise
+ elif isinstance(expected, str):
+ self.assertMultiLineEqual(_munge_log(expected), _munge_log(actual),
+ '%s differs' % basename)
+ else:
+ self.fail(
+ 'Invalid value found for expected logs: {}, type: {}'.format(
+ expected, type(expected)))
+
+ @parameterized.named_parameters(
+ ('', False),
+ ('logtostderr', True))
+ def test_py_logging(self, logtostderr):
+ # Python logging by default logs to stderr.
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, self._get_logs(logging.INFO)]],
+ pass_logtostderr=logtostderr)
+
+ def test_py_logging_use_absl_log_file(self):
+ # Python logging calling use_absl_log_file causes also log to files.
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, ''],
+ ['absl_log_file', 'INFO', self._get_logs(logging.INFO)]],
+ use_absl_log_file=True)
+
+ def test_py_logging_use_absl_log_file_logtostderr(self):
+ # Python logging asked to log to stderr even though use_absl_log_file
+ # is called.
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, self._get_logs(logging.INFO)]],
+ pass_logtostderr=True,
+ use_absl_log_file=True)
+
+ @parameterized.named_parameters(
+ ('', False),
+ ('logtostderr', True))
+ def test_py_logging_noshowprefixforinfo(self, logtostderr):
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, self._get_logs(logging.INFO,
+ include_info_prefix=False)]],
+ pass_logtostderr=logtostderr,
+ show_info_prefix=0)
+
+ def test_py_logging_noshowprefixforinfo_use_absl_log_file(self):
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, ''],
+ ['absl_log_file', 'INFO', self._get_logs(logging.INFO)]],
+ show_info_prefix=0,
+ use_absl_log_file=True)
+
+ def test_py_logging_noshowprefixforinfo_use_absl_log_file_logtostderr(self):
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, self._get_logs(logging.INFO,
+ include_info_prefix=False)]],
+ pass_logtostderr=True,
+ show_info_prefix=0,
+ use_absl_log_file=True)
+
+ def test_py_logging_noshowprefixforinfo_verbosity(self):
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, self._get_logs(logging.DEBUG)]],
+ pass_logtostderr=True,
+ show_info_prefix=0,
+ use_absl_log_file=True,
+ extra_args=['-v=1'])
+
+ def test_py_logging_fatal_main_thread_only(self):
+ self._exec_test(
+ _verify_fatal,
+ [['stderr', None, _get_fatal_log_expectation(
+ self, 'fatal_main_thread_only', False)]],
+ test_name='fatal_main_thread_only')
+
+ def test_py_logging_fatal_with_other_threads(self):
+ self._exec_test(
+ _verify_fatal,
+ [['stderr', None, _get_fatal_log_expectation(
+ self, 'fatal_with_other_threads', False)]],
+ test_name='fatal_with_other_threads')
+
+ def test_py_logging_fatal_non_main_thread(self):
+ self._exec_test(
+ _verify_fatal,
+ [['stderr', None, _get_fatal_log_expectation(
+ self, 'fatal_non_main_thread', False)]],
+ test_name='fatal_non_main_thread')
+
+ def test_py_logging_critical_non_absl(self):
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, _CRITICAL_DOWNGRADE_TO_ERROR_MESSAGE]],
+ test_name='critical_from_non_absl_logger')
+
+ def test_py_logging_skip_log_prefix(self):
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, '']],
+ test_name='register_frame_to_skip')
+
+ def test_py_logging_flush(self):
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, '']],
+ test_name='flush')
+
+ @parameterized.named_parameters(*_VERBOSITY_FLAG_TEST_PARAMETERS)
+ def test_py_logging_verbosity_stderr(self, verbosity):
+ """Tests -v/--verbosity flag with python logging to stderr."""
+ v_flag = '-v=%d' % verbosity
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, self._get_logs(verbosity)]],
+ extra_args=[v_flag])
+
+ @parameterized.named_parameters(*_VERBOSITY_FLAG_TEST_PARAMETERS)
+ def test_py_logging_verbosity_file(self, verbosity):
+ """Tests -v/--verbosity flag with Python logging to stderr."""
+ v_flag = '-v=%d' % verbosity
+ self._exec_test(
+ _verify_ok,
+ [['stderr', None, ''],
+ # When using python logging, it only creates a file named INFO,
+ # unlike C++ it also creates WARNING and ERROR files.
+ ['absl_log_file', 'INFO', self._get_logs(verbosity)]],
+ use_absl_log_file=True,
+ extra_args=[v_flag])
+
+ def test_stderrthreshold_py_logging(self):
+ """Tests --stderrthreshold."""
+
+ stderr_logs = '''\
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, debug log
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, info log
+W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, warning log
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, error log
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=info, info log
+W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=info, warning log
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=info, error log
+W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=warning, warning log
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=warning, error log
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=error, error log
+'''
+
+ expected_logs = [
+ ['stderr', None, stderr_logs],
+ ['absl_log_file', 'INFO', None],
+ ]
+ # Set verbosity to debug to test stderrthreshold == debug.
+ extra_args = ['-v=1']
+
+ self._exec_test(
+ _verify_ok,
+ expected_logs,
+ test_name='stderrthreshold',
+ extra_args=extra_args,
+ use_absl_log_file=True)
+
+ def test_std_logging_py_logging(self):
+ """Tests logs from std logging."""
+ stderr_logs = '''\
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std debug log
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std info log
+W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std warning log
+E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std error log
+'''
+ expected_logs = [['stderr', None, stderr_logs]]
+
+ extra_args = ['-v=1', '--logtostderr']
+ self._exec_test(
+ _verify_ok,
+ expected_logs,
+ test_name='std_logging',
+ extra_args=extra_args)
+
+ def test_bad_exc_info_py_logging(self):
+
+ def assert_stderr(stderr):
+ # The exact message differs among different Python versions. So it just
+ # asserts some certain information is there.
+ self.assertIn('Traceback (most recent call last):', stderr)
+ self.assertIn('IndexError', stderr)
+
+ expected_logs = [
+ ['stderr', None, assert_stderr],
+ ['absl_log_file', 'INFO', '']]
+
+ self._exec_test(
+ _verify_ok,
+ expected_logs,
+ test_name='bad_exc_info',
+ use_absl_log_file=True)
+
+ def test_verbosity_logger_levels_flag_ordering(self):
+ """Make sure last-specified flag wins."""
+
+ def assert_error_level_logged(stderr):
+ lines = stderr.splitlines()
+ for line in lines:
+ self.assertIn('std error log', line)
+
+ self._exec_test(
+ _verify_ok,
+ test_name='std_logging',
+ expected_logs=[('stderr', None, assert_error_level_logged)],
+ extra_args=['-v=1', '--logger_levels=:ERROR'])
+
+ def assert_debug_level_logged(stderr):
+ lines = stderr.splitlines()
+ for line in lines:
+ self.assertRegex(line, 'std (debug|info|warning|error) log')
+
+ self._exec_test(
+ _verify_ok,
+ test_name='std_logging',
+ expected_logs=[('stderr', None, assert_debug_level_logged)],
+ extra_args=['--logger_levels=:ERROR', '-v=1'])
+
+ def test_none_exc_info_py_logging(self):
+
+ expected_stderr = ''
+ expected_info = '''\
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] None exc_info
+'''
+ expected_info += 'NoneType: None\n'
+
+ expected_logs = [
+ ['stderr', None, expected_stderr],
+ ['absl_log_file', 'INFO', expected_info]]
+
+ self._exec_test(
+ _verify_ok,
+ expected_logs,
+ test_name='none_exc_info',
+ use_absl_log_file=True)
+
+ def test_unicode_py_logging(self):
+
+ def get_stderr_message(stderr, name):
+ match = re.search(
+ '-- begin {} --\n(.*)-- end {} --'.format(name, name),
+ stderr, re.MULTILINE | re.DOTALL)
+ self.assertTrue(
+ match, 'Cannot find stderr message for test {}'.format(name))
+ return match.group(1)
+
+ def assert_stderr(stderr):
+ """Verifies that it writes correct information to stderr for Python 3.
+
+ There are no unicode errors in Python 3.
+
+ Args:
+ stderr: the message from stderr.
+ """
+ # Successful logs:
+ for name in (
+ 'unicode', 'unicode % unicode', 'bytes % bytes', 'unicode % bytes',
+ 'bytes % unicode', 'unicode % iso8859-15', 'str % exception',
+ 'str % exception'):
+ logging.info('name = %s', name)
+ self.assertEqual('', get_stderr_message(stderr, name))
+
+ expected_logs = [['stderr', None, assert_stderr]]
+
+ info_log = u'''\
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: Ch\u00e2tonnaye
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: Ch\u00e2tonnaye
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] b'G\\xc3\\xaete: b'Ch\\xc3\\xa2tonnaye''
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: b'Ch\\xc3\\xa2tonnaye'
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] b'G\\xc3\\xaete: Ch\u00e2tonnaye'
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: b'Ch\\xe2tonnaye'
+I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] exception: Ch\u00e2tonnaye
+'''
+ expected_logs.append(['absl_log_file', 'INFO', info_log])
+
+ self._exec_test(
+ _verify_ok,
+ expected_logs,
+ test_name='unicode',
+ use_absl_log_file=True)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/logging/tests/logging_functional_test_helper.py b/absl/logging/tests/logging_functional_test_helper.py
new file mode 100644
index 0000000..b95647b
--- /dev/null
+++ b/absl/logging/tests/logging_functional_test_helper.py
@@ -0,0 +1,312 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper script for logging_functional_test."""
+
+import logging as std_logging
+import logging.config as std_logging_config
+import os
+import sys
+import threading
+import time
+import timeit
+from unittest import mock
+
+from absl import app
+from absl import flags
+from absl import logging
+
+FLAGS = flags.FLAGS
+
+
+class VerboseDel(object):
+ """Dummy class to test __del__ running."""
+
+ def __init__(self, msg):
+ self._msg = msg
+
+ def __del__(self):
+ sys.stderr.write(self._msg)
+ sys.stderr.flush()
+
+
+def _test_do_logging():
+ """Do some log operations."""
+ logging.vlog(3, 'This line is VLOG level 3')
+ logging.vlog(2, 'This line is VLOG level 2')
+ logging.log(2, 'This line is log level 2')
+ if logging.vlog_is_on(2):
+ logging.log(1, 'VLOG level 1, but only if VLOG level 2 is active')
+
+ logging.vlog(1, 'This line is VLOG level 1')
+ logging.log(1, 'This line is log level 1')
+ logging.debug('This line is DEBUG')
+
+ logging.vlog(0, 'This line is VLOG level 0')
+ logging.log(0, 'This line is log level 0')
+ logging.info('Interesting Stuff\0')
+ logging.info('Interesting Stuff with Arguments: %d', 42)
+ logging.info('%(a)s Stuff with %(b)s',
+ {'a': 'Interesting', 'b': 'Dictionary'})
+
+ with mock.patch.object(timeit, 'default_timer') as mock_timer:
+ mock_timer.return_value = 0
+ while timeit.default_timer() < 9:
+ logging.log_every_n_seconds(logging.INFO, 'This should appear 5 times.',
+ 2)
+ mock_timer.return_value = mock_timer() + .2
+
+ for i in range(1, 5):
+ logging.log_first_n(logging.INFO, 'Info first %d of %d', 2, i, 2)
+ logging.log_every_n(logging.INFO, 'Info %d (every %d)', 3, i, 3)
+
+ logging.vlog(-1, 'This line is VLOG level -1')
+ logging.log(-1, 'This line is log level -1')
+ logging.warning('Worrying Stuff')
+ for i in range(1, 5):
+ logging.log_first_n(logging.WARNING, 'Warn first %d of %d', 2, i, 2)
+ logging.log_every_n(logging.WARNING, 'Warn %d (every %d)', 3, i, 3)
+
+ logging.vlog(-2, 'This line is VLOG level -2')
+ logging.log(-2, 'This line is log level -2')
+ try:
+ raise OSError('Fake Error')
+ except OSError:
+ saved_exc_info = sys.exc_info()
+ logging.exception('An Exception %s')
+ logging.exception('Once more, %(reason)s', {'reason': 'just because'})
+ logging.error('Exception 2 %s', exc_info=True)
+ logging.error('Non-exception', exc_info=False)
+
+ try:
+ sys.exc_clear()
+ except AttributeError:
+ # No sys.exc_clear() in Python 3, but this will clear sys.exc_info() too.
+ pass
+
+ logging.error('Exception %s', '3', exc_info=saved_exc_info)
+ logging.error('No traceback', exc_info=saved_exc_info[:2] + (None,))
+
+ logging.error('Alarming Stuff')
+ for i in range(1, 5):
+ logging.log_first_n(logging.ERROR, 'Error first %d of %d', 2, i, 2)
+ logging.log_every_n(logging.ERROR, 'Error %d (every %d)', 3, i, 3)
+ logging.flush()
+
+
+def _test_fatal_main_thread_only():
+ """Test logging.fatal from main thread, no other threads running."""
+ v = VerboseDel('fatal_main_thread_only main del called\n')
+ try:
+ logging.fatal('fatal_main_thread_only message')
+ finally:
+ del v
+
+
+def _test_fatal_with_other_threads():
+ """Test logging.fatal from main thread, other threads running."""
+
+ lock = threading.Lock()
+ lock.acquire()
+
+ def sleep_forever(lock=lock):
+ v = VerboseDel('fatal_with_other_threads non-main del called\n')
+ try:
+ lock.release()
+ while True:
+ time.sleep(10000)
+ finally:
+ del v
+
+ v = VerboseDel('fatal_with_other_threads main del called\n')
+ try:
+ # Start new thread
+ t = threading.Thread(target=sleep_forever)
+ t.start()
+
+ # Wait for other thread
+ lock.acquire()
+ lock.release()
+
+ # Die
+ logging.fatal('fatal_with_other_threads message')
+ while True:
+ time.sleep(10000)
+ finally:
+ del v
+
+
+def _test_fatal_non_main_thread():
+ """Test logging.fatal from non main thread."""
+
+ lock = threading.Lock()
+ lock.acquire()
+
+ def die_soon(lock=lock):
+ v = VerboseDel('fatal_non_main_thread non-main del called\n')
+ try:
+ # Wait for signal from other thread
+ lock.acquire()
+ lock.release()
+ logging.fatal('fatal_non_main_thread message')
+ while True:
+ time.sleep(10000)
+ finally:
+ del v
+
+ v = VerboseDel('fatal_non_main_thread main del called\n')
+ try:
+ # Start new thread
+ t = threading.Thread(target=die_soon)
+ t.start()
+
+ # Signal other thread
+ lock.release()
+
+ # Wait for it to die
+ while True:
+ time.sleep(10000)
+ finally:
+ del v
+
+
+def _test_critical_from_non_absl_logger():
+ """Test CRITICAL logs from non-absl loggers."""
+
+ std_logging.critical('A critical message')
+
+
+def _test_register_frame_to_skip():
+ """Test skipping frames for line number reporting."""
+
+ def _getline():
+
+ def _getline_inner():
+ return logging.get_absl_logger().findCaller()[1]
+
+ return _getline_inner()
+
+ # Check register_frame_to_skip function to see if log frame skipping works.
+ line1 = _getline()
+ line2 = _getline()
+ logging.get_absl_logger().register_frame_to_skip(__file__, '_getline')
+ line3 = _getline()
+ # Both should be line number of the _getline_inner() call.
+ assert (line1 == line2), (line1, line2)
+ # line3 should be a line number in this function.
+ assert (line2 != line3), (line2, line3)
+
+
+def _test_flush():
+ """Test flush in various difficult cases."""
+ # Flush, but one of the logfiles is closed
+ log_filename = os.path.join(FLAGS.log_dir, 'a_thread_with_logfile.txt')
+ with open(log_filename, 'w') as log_file:
+ logging.get_absl_handler().python_handler.stream = log_file
+ logging.flush()
+
+
+def _test_stderrthreshold():
+ """Tests modifying --stderrthreshold after flag parsing will work."""
+
+ def log_things():
+ logging.debug('FLAGS.stderrthreshold=%s, debug log', FLAGS.stderrthreshold)
+ logging.info('FLAGS.stderrthreshold=%s, info log', FLAGS.stderrthreshold)
+ logging.warning('FLAGS.stderrthreshold=%s, warning log',
+ FLAGS.stderrthreshold)
+ logging.error('FLAGS.stderrthreshold=%s, error log', FLAGS.stderrthreshold)
+
+ FLAGS.stderrthreshold = 'debug'
+ log_things()
+ FLAGS.stderrthreshold = 'info'
+ log_things()
+ FLAGS.stderrthreshold = 'warning'
+ log_things()
+ FLAGS.stderrthreshold = 'error'
+ log_things()
+
+
+def _test_std_logging():
+ """Tests logs from std logging."""
+ std_logging.debug('std debug log')
+ std_logging.info('std info log')
+ std_logging.warning('std warning log')
+ std_logging.error('std error log')
+
+
+def _test_bad_exc_info():
+ """Tests when a bad exc_info valud is provided."""
+ logging.info('Bad exc_info', exc_info=(None, None))
+
+
+def _test_none_exc_info():
+ """Tests when exc_info is requested but not available."""
+ # Clear exc_info first.
+ try:
+ sys.exc_clear()
+ except AttributeError:
+ # No sys.exc_clear() in Python 3, but this will clear sys.exc_info() too.
+ pass
+ logging.info('None exc_info', exc_info=True)
+
+
+def _test_unicode():
+ """Tests unicode handling."""
+
+ test_names = []
+
+ def log(name, msg, *args):
+ """Logs the message, and ensures the same name is not logged again."""
+ assert name not in test_names, ('test_unicode expects unique names to work,'
+ ' found existing name {}').format(name)
+ test_names.append(name)
+
+ # Add line separators so that tests can verify the output for each log
+ # message.
+ sys.stderr.write('-- begin {} --\n'.format(name))
+ logging.info(msg, *args)
+ sys.stderr.write('-- end {} --\n'.format(name))
+
+ log('unicode', u'G\u00eete: Ch\u00e2tonnaye')
+ log('unicode % unicode', u'G\u00eete: %s', u'Ch\u00e2tonnaye')
+ log('bytes % bytes', u'G\u00eete: %s'.encode('utf-8'),
+ u'Ch\u00e2tonnaye'.encode('utf-8'))
+ log('unicode % bytes', u'G\u00eete: %s', u'Ch\u00e2tonnaye'.encode('utf-8'))
+ log('bytes % unicode', u'G\u00eete: %s'.encode('utf-8'), u'Ch\u00e2tonnaye')
+ log('unicode % iso8859-15', u'G\u00eete: %s',
+ u'Ch\u00e2tonnaye'.encode('iso-8859-15'))
+ log('str % exception', 'exception: %s', Exception(u'Ch\u00e2tonnaye'))
+
+
+def main(argv):
+ del argv # Unused.
+
+ test_name = os.environ.get('TEST_NAME', None)
+ test_fn = globals().get('_test_%s' % test_name)
+ if test_fn is None:
+ raise AssertionError('TEST_NAME must be set to a valid value')
+ # Flush so previous messages are written to file before we switch to a new
+ # file with use_absl_log_file.
+ logging.flush()
+ if os.environ.get('USE_ABSL_LOG_FILE') == '1':
+ logging.get_absl_handler().use_absl_log_file('absl_log_file', FLAGS.log_dir)
+
+ test_fn()
+
+
+if __name__ == '__main__':
+ sys.argv[0] = 'py_argv_0'
+ if os.environ.get('CALL_DICT_CONFIG') == '1':
+ std_logging_config.dictConfig({'version': 1})
+ app.run(main)
diff --git a/absl/logging/tests/logging_test.py b/absl/logging/tests/logging_test.py
new file mode 100644
index 0000000..e5c4fcc
--- /dev/null
+++ b/absl/logging/tests/logging_test.py
@@ -0,0 +1,1002 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for absl.logging."""
+
+import contextlib
+import functools
+import getpass
+import io
+import logging as std_logging
+import os
+import re
+import socket
+import sys
+import tempfile
+import threading
+import time
+import traceback
+import unittest
+from unittest import mock
+
+from absl import flags
+from absl import logging
+from absl.testing import absltest
+from absl.testing import flagsaver
+from absl.testing import parameterized
+
+FLAGS = flags.FLAGS
+
+
+class ConfigurationTest(absltest.TestCase):
+ """Tests the initial logging configuration."""
+
+ def test_logger_and_handler(self):
+ absl_logger = std_logging.getLogger('absl')
+ self.assertIs(absl_logger, logging.get_absl_logger())
+ self.assertIsInstance(absl_logger, logging.ABSLLogger)
+ self.assertIsInstance(
+ logging.get_absl_handler().python_handler.formatter,
+ logging.PythonFormatter)
+
+
+class LoggerLevelsTest(parameterized.TestCase):
+
+ def setUp(self):
+ super(LoggerLevelsTest, self).setUp()
+ # Since these tests muck with the flag, always save/restore in case the
+ # tests forget to clean up properly.
+ # enter_context() is py3-only, but manually enter/exit should suffice.
+ cm = self.set_logger_levels({})
+ cm.__enter__()
+ self.addCleanup(lambda: cm.__exit__(None, None, None))
+
+ @contextlib.contextmanager
+ def set_logger_levels(self, levels):
+ original_levels = {
+ name: std_logging.getLogger(name).level for name in levels
+ }
+
+ try:
+ with flagsaver.flagsaver(logger_levels=levels):
+ yield
+ finally:
+ for name, level in original_levels.items():
+ std_logging.getLogger(name).setLevel(level)
+
+ def assert_logger_level(self, name, expected_level):
+ logger = std_logging.getLogger(name)
+ self.assertEqual(logger.level, expected_level)
+
+ def assert_logged(self, logger_name, expected_msgs):
+ logger = std_logging.getLogger(logger_name)
+ # NOTE: assertLogs() sets the logger to INFO if not specified.
+ with self.assertLogs(logger, logger.level) as cm:
+ logger.debug('debug')
+ logger.info('info')
+ logger.warning('warning')
+ logger.error('error')
+ logger.critical('critical')
+
+ actual = {r.getMessage() for r in cm.records}
+ self.assertEqual(set(expected_msgs), actual)
+
+ def test_setting_levels(self):
+ # Other tests change the root logging level, so we can't
+ # assume it's the default.
+ orig_root_level = std_logging.root.getEffectiveLevel()
+ with self.set_logger_levels({'foo': 'ERROR', 'bar': 'DEBUG'}):
+
+ self.assert_logger_level('foo', std_logging.ERROR)
+ self.assert_logger_level('bar', std_logging.DEBUG)
+ self.assert_logger_level('', orig_root_level)
+
+ self.assert_logged('foo', {'error', 'critical'})
+ self.assert_logged('bar',
+ {'debug', 'info', 'warning', 'error', 'critical'})
+
+ @parameterized.named_parameters(
+ ('empty', ''),
+ ('one_value', 'one:INFO'),
+ ('two_values', 'one.a:INFO,two.b:ERROR'),
+ ('whitespace_ignored', ' one : DEBUG , two : INFO'),
+ )
+ def test_serialize_parse(self, levels_str):
+ fl = FLAGS['logger_levels']
+ fl.parse(levels_str)
+ expected = levels_str.replace(' ', '')
+ actual = fl.serialize()
+ self.assertEqual('--logger_levels={}'.format(expected), actual)
+
+ def test_invalid_value(self):
+ with self.assertRaisesRegex(ValueError, 'Unknown level.*10'):
+ FLAGS['logger_levels'].parse('foo:10')
+
+
+class PythonHandlerTest(absltest.TestCase):
+ """Tests the PythonHandler class."""
+
+ def setUp(self):
+ super().setUp()
+ (year, month, day, hour, minute, sec,
+ dunno, dayofyear, dst_flag) = (1979, 10, 21, 18, 17, 16, 3, 15, 0)
+ self.now_tuple = (year, month, day, hour, minute, sec,
+ dunno, dayofyear, dst_flag)
+ self.python_handler = logging.PythonHandler()
+
+ def tearDown(self):
+ mock.patch.stopall()
+ super().tearDown()
+
+ @flagsaver.flagsaver(logtostderr=False)
+ def test_set_google_log_file_no_log_to_stderr(self):
+ with mock.patch.object(self.python_handler, 'start_logging_to_file'):
+ self.python_handler.use_absl_log_file()
+ self.python_handler.start_logging_to_file.assert_called_once_with(
+ program_name=None, log_dir=None)
+
+ @flagsaver.flagsaver(logtostderr=True)
+ def test_set_google_log_file_with_log_to_stderr(self):
+ self.python_handler.stream = None
+ self.python_handler.use_absl_log_file()
+ self.assertEqual(sys.stderr, self.python_handler.stream)
+
+ @mock.patch.object(logging, 'find_log_dir_and_names')
+ @mock.patch.object(logging.time, 'localtime')
+ @mock.patch.object(logging.time, 'time')
+ @mock.patch.object(os.path, 'islink')
+ @mock.patch.object(os, 'unlink')
+ @mock.patch.object(os, 'getpid')
+ def test_start_logging_to_file(
+ self, mock_getpid, mock_unlink, mock_islink, mock_time,
+ mock_localtime, mock_find_log_dir_and_names):
+ mock_find_log_dir_and_names.return_value = ('here', 'prog1', 'prog1')
+ mock_time.return_value = '12345'
+ mock_localtime.return_value = self.now_tuple
+ mock_getpid.return_value = 4321
+ symlink = os.path.join('here', 'prog1.INFO')
+ mock_islink.return_value = True
+ with mock.patch.object(
+ logging, 'open', return_value=sys.stdout, create=True):
+ if getattr(os, 'symlink', None):
+ with mock.patch.object(os, 'symlink'):
+ self.python_handler.start_logging_to_file()
+ mock_unlink.assert_called_once_with(symlink)
+ os.symlink.assert_called_once_with(
+ 'prog1.INFO.19791021-181716.4321', symlink)
+ else:
+ self.python_handler.start_logging_to_file()
+
+ def test_log_file(self):
+ handler = logging.PythonHandler()
+ self.assertEqual(sys.stderr, handler.stream)
+
+ stream = mock.Mock()
+ handler = logging.PythonHandler(stream)
+ self.assertEqual(stream, handler.stream)
+
+ def test_flush(self):
+ stream = mock.Mock()
+ handler = logging.PythonHandler(stream)
+ handler.flush()
+ stream.flush.assert_called_once()
+
+ def test_flush_with_value_error(self):
+ stream = mock.Mock()
+ stream.flush.side_effect = ValueError
+ handler = logging.PythonHandler(stream)
+ handler.flush()
+ stream.flush.assert_called_once()
+
+ def test_flush_with_environment_error(self):
+ stream = mock.Mock()
+ stream.flush.side_effect = EnvironmentError
+ handler = logging.PythonHandler(stream)
+ handler.flush()
+ stream.flush.assert_called_once()
+
+ def test_flush_with_assertion_error(self):
+ stream = mock.Mock()
+ stream.flush.side_effect = AssertionError
+ handler = logging.PythonHandler(stream)
+ with self.assertRaises(AssertionError):
+ handler.flush()
+
+ def test_log_to_std_err(self):
+ record = std_logging.LogRecord(
+ 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
+ with mock.patch.object(std_logging.StreamHandler, 'emit'):
+ self.python_handler._log_to_stderr(record)
+ std_logging.StreamHandler.emit.assert_called_once_with(record)
+
+ @flagsaver.flagsaver(logtostderr=True)
+ def test_emit_log_to_stderr(self):
+ record = std_logging.LogRecord(
+ 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
+ with mock.patch.object(self.python_handler, '_log_to_stderr'):
+ self.python_handler.emit(record)
+ self.python_handler._log_to_stderr.assert_called_once_with(record)
+
+ def test_emit(self):
+ stream = io.StringIO()
+ handler = logging.PythonHandler(stream)
+ handler.stderr_threshold = std_logging.FATAL
+ record = std_logging.LogRecord(
+ 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
+ handler.emit(record)
+ self.assertEqual(1, stream.getvalue().count('logging_msg'))
+
+ @flagsaver.flagsaver(stderrthreshold='debug')
+ def test_emit_and_stderr_threshold(self):
+ mock_stderr = io.StringIO()
+ stream = io.StringIO()
+ handler = logging.PythonHandler(stream)
+ record = std_logging.LogRecord(
+ 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
+ with mock.patch.object(sys, 'stderr', new=mock_stderr) as mock_stderr:
+ handler.emit(record)
+ self.assertEqual(1, stream.getvalue().count('logging_msg'))
+ self.assertEqual(1, mock_stderr.getvalue().count('logging_msg'))
+
+ @flagsaver.flagsaver(alsologtostderr=True)
+ def test_emit_also_log_to_stderr(self):
+ mock_stderr = io.StringIO()
+ stream = io.StringIO()
+ handler = logging.PythonHandler(stream)
+ handler.stderr_threshold = std_logging.FATAL
+ record = std_logging.LogRecord(
+ 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
+ with mock.patch.object(sys, 'stderr', new=mock_stderr) as mock_stderr:
+ handler.emit(record)
+ self.assertEqual(1, stream.getvalue().count('logging_msg'))
+ self.assertEqual(1, mock_stderr.getvalue().count('logging_msg'))
+
+ def test_emit_on_stderr(self):
+ mock_stderr = io.StringIO()
+ with mock.patch.object(sys, 'stderr', new=mock_stderr) as mock_stderr:
+ handler = logging.PythonHandler()
+ handler.stderr_threshold = std_logging.INFO
+ record = std_logging.LogRecord(
+ 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
+ handler.emit(record)
+ self.assertEqual(1, mock_stderr.getvalue().count('logging_msg'))
+
+ def test_emit_fatal_absl(self):
+ stream = io.StringIO()
+ handler = logging.PythonHandler(stream)
+ record = std_logging.LogRecord(
+ 'name', std_logging.FATAL, 'path', 12, 'logging_msg', [], False)
+ record.__dict__[logging._ABSL_LOG_FATAL] = True
+ with mock.patch.object(handler, 'flush') as mock_flush:
+ with mock.patch.object(os, 'abort') as mock_abort:
+ handler.emit(record)
+ mock_abort.assert_called_once()
+ mock_flush.assert_called() # flush is also called by super class.
+
+ def test_emit_fatal_non_absl(self):
+ stream = io.StringIO()
+ handler = logging.PythonHandler(stream)
+ record = std_logging.LogRecord(
+ 'name', std_logging.FATAL, 'path', 12, 'logging_msg', [], False)
+ with mock.patch.object(os, 'abort') as mock_abort:
+ handler.emit(record)
+ mock_abort.assert_not_called()
+
+ def test_close(self):
+ stream = mock.Mock()
+ stream.isatty.return_value = True
+ handler = logging.PythonHandler(stream)
+ with mock.patch.object(handler, 'flush') as mock_flush:
+ with mock.patch.object(std_logging.StreamHandler, 'close') as super_close:
+ handler.close()
+ mock_flush.assert_called_once()
+ super_close.assert_called_once()
+ stream.close.assert_not_called()
+
+ def test_close_afile(self):
+ stream = mock.Mock()
+ stream.isatty.return_value = False
+ stream.close.side_effect = ValueError
+ handler = logging.PythonHandler(stream)
+ with mock.patch.object(handler, 'flush') as mock_flush:
+ with mock.patch.object(std_logging.StreamHandler, 'close') as super_close:
+ handler.close()
+ mock_flush.assert_called_once()
+ super_close.assert_called_once()
+
+ def test_close_stderr(self):
+ with mock.patch.object(sys, 'stderr') as mock_stderr:
+ mock_stderr.isatty.return_value = False
+ handler = logging.PythonHandler(sys.stderr)
+ handler.close()
+ mock_stderr.close.assert_not_called()
+
+ def test_close_stdout(self):
+ with mock.patch.object(sys, 'stdout') as mock_stdout:
+ mock_stdout.isatty.return_value = False
+ handler = logging.PythonHandler(sys.stdout)
+ handler.close()
+ mock_stdout.close.assert_not_called()
+
+ def test_close_original_stderr(self):
+ with mock.patch.object(sys, '__stderr__') as mock_original_stderr:
+ mock_original_stderr.isatty.return_value = False
+ handler = logging.PythonHandler(sys.__stderr__)
+ handler.close()
+ mock_original_stderr.close.assert_not_called()
+
+ def test_close_original_stdout(self):
+ with mock.patch.object(sys, '__stdout__') as mock_original_stdout:
+ mock_original_stdout.isatty.return_value = False
+ handler = logging.PythonHandler(sys.__stdout__)
+ handler.close()
+ mock_original_stdout.close.assert_not_called()
+
+ def test_close_fake_file(self):
+
+ class FakeFile(object):
+ """A file-like object that does not implement "isatty"."""
+
+ def __init__(self):
+ self.closed = False
+
+ def close(self):
+ self.closed = True
+
+ def flush(self):
+ pass
+
+ fake_file = FakeFile()
+ handler = logging.PythonHandler(fake_file)
+ handler.close()
+ self.assertTrue(fake_file.closed)
+
+
+class ABSLHandlerTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ formatter = logging.PythonFormatter()
+ self.absl_handler = logging.ABSLHandler(formatter)
+
+ def test_activate_python_handler(self):
+ self.absl_handler.activate_python_handler()
+ self.assertEqual(
+ self.absl_handler._current_handler, self.absl_handler.python_handler)
+
+
+class ABSLLoggerTest(absltest.TestCase):
+ """Tests the ABSLLogger class."""
+
+ def set_up_mock_frames(self):
+ """Sets up mock frames for use with the testFindCaller methods."""
+ logging_file = os.path.join('absl', 'logging', '__init__.py')
+
+ # Set up mock frame 0
+ mock_frame_0 = mock.Mock()
+ mock_code_0 = mock.Mock()
+ mock_code_0.co_filename = logging_file
+ mock_code_0.co_name = 'LoggingLog'
+ mock_code_0.co_firstlineno = 124
+ mock_frame_0.f_code = mock_code_0
+ mock_frame_0.f_lineno = 125
+
+ # Set up mock frame 1
+ mock_frame_1 = mock.Mock()
+ mock_code_1 = mock.Mock()
+ mock_code_1.co_filename = 'myfile.py'
+ mock_code_1.co_name = 'Method1'
+ mock_code_1.co_firstlineno = 124
+ mock_frame_1.f_code = mock_code_1
+ mock_frame_1.f_lineno = 125
+
+ # Set up mock frame 2
+ mock_frame_2 = mock.Mock()
+ mock_code_2 = mock.Mock()
+ mock_code_2.co_filename = 'myfile.py'
+ mock_code_2.co_name = 'Method2'
+ mock_code_2.co_firstlineno = 124
+ mock_frame_2.f_code = mock_code_2
+ mock_frame_2.f_lineno = 125
+
+ # Set up mock frame 3
+ mock_frame_3 = mock.Mock()
+ mock_code_3 = mock.Mock()
+ mock_code_3.co_filename = 'myfile.py'
+ mock_code_3.co_name = 'Method3'
+ mock_code_3.co_firstlineno = 124
+ mock_frame_3.f_code = mock_code_3
+ mock_frame_3.f_lineno = 125
+
+ # Set up mock frame 4 that has the same function name as frame 2.
+ mock_frame_4 = mock.Mock()
+ mock_code_4 = mock.Mock()
+ mock_code_4.co_filename = 'myfile.py'
+ mock_code_4.co_name = 'Method2'
+ mock_code_4.co_firstlineno = 248
+ mock_frame_4.f_code = mock_code_4
+ mock_frame_4.f_lineno = 249
+
+ # Tie them together.
+ mock_frame_4.f_back = None
+ mock_frame_3.f_back = mock_frame_4
+ mock_frame_2.f_back = mock_frame_3
+ mock_frame_1.f_back = mock_frame_2
+ mock_frame_0.f_back = mock_frame_1
+
+ mock.patch.object(sys, '_getframe').start()
+ sys._getframe.return_value = mock_frame_0
+
+ def setUp(self):
+ super().setUp()
+ self.message = 'Hello Nurse'
+ self.logger = logging.ABSLLogger('')
+
+ def tearDown(self):
+ mock.patch.stopall()
+ self.logger._frames_to_skip.clear()
+ super().tearDown()
+
+ def test_constructor_without_level(self):
+ self.logger = logging.ABSLLogger('')
+ self.assertEqual(std_logging.NOTSET, self.logger.getEffectiveLevel())
+
+ def test_constructor_with_level(self):
+ self.logger = logging.ABSLLogger('', std_logging.DEBUG)
+ self.assertEqual(std_logging.DEBUG, self.logger.getEffectiveLevel())
+
+ def test_find_caller_normal(self):
+ self.set_up_mock_frames()
+ expected_name = 'Method1'
+ self.assertEqual(expected_name, self.logger.findCaller()[2])
+
+ def test_find_caller_skip_method1(self):
+ self.set_up_mock_frames()
+ self.logger.register_frame_to_skip('myfile.py', 'Method1')
+ expected_name = 'Method2'
+ self.assertEqual(expected_name, self.logger.findCaller()[2])
+
+ def test_find_caller_skip_method1_and_method2(self):
+ self.set_up_mock_frames()
+ self.logger.register_frame_to_skip('myfile.py', 'Method1')
+ self.logger.register_frame_to_skip('myfile.py', 'Method2')
+ expected_name = 'Method3'
+ self.assertEqual(expected_name, self.logger.findCaller()[2])
+
+ def test_find_caller_skip_method1_and_method3(self):
+ self.set_up_mock_frames()
+ self.logger.register_frame_to_skip('myfile.py', 'Method1')
+ # Skipping Method3 should change nothing since Method2 should be hit.
+ self.logger.register_frame_to_skip('myfile.py', 'Method3')
+ expected_name = 'Method2'
+ self.assertEqual(expected_name, self.logger.findCaller()[2])
+
+ def test_find_caller_skip_method1_and_method4(self):
+ self.set_up_mock_frames()
+ self.logger.register_frame_to_skip('myfile.py', 'Method1')
+ # Skipping frame 4's Method2 should change nothing for frame 2's Method2.
+ self.logger.register_frame_to_skip('myfile.py', 'Method2', 248)
+ expected_name = 'Method2'
+ expected_frame_lineno = 125
+ self.assertEqual(expected_name, self.logger.findCaller()[2])
+ self.assertEqual(expected_frame_lineno, self.logger.findCaller()[1])
+
+ def test_find_caller_skip_method1_method2_and_method3(self):
+ self.set_up_mock_frames()
+ self.logger.register_frame_to_skip('myfile.py', 'Method1')
+ self.logger.register_frame_to_skip('myfile.py', 'Method2', 124)
+ self.logger.register_frame_to_skip('myfile.py', 'Method3')
+ expected_name = 'Method2'
+ expected_frame_lineno = 249
+ self.assertEqual(expected_name, self.logger.findCaller()[2])
+ self.assertEqual(expected_frame_lineno, self.logger.findCaller()[1])
+
+ def test_find_caller_stack_info(self):
+ self.set_up_mock_frames()
+ self.logger.register_frame_to_skip('myfile.py', 'Method1')
+ with mock.patch.object(traceback, 'print_stack') as print_stack:
+ self.assertEqual(
+ ('myfile.py', 125, 'Method2', 'Stack (most recent call last):'),
+ self.logger.findCaller(stack_info=True))
+ print_stack.assert_called_once()
+
+ def test_critical(self):
+ with mock.patch.object(self.logger, 'log'):
+ self.logger.critical(self.message)
+ self.logger.log.assert_called_once_with(
+ std_logging.CRITICAL, self.message)
+
+ def test_fatal(self):
+ with mock.patch.object(self.logger, 'log'):
+ self.logger.fatal(self.message)
+ self.logger.log.assert_called_once_with(std_logging.FATAL, self.message)
+
+ def test_error(self):
+ with mock.patch.object(self.logger, 'log'):
+ self.logger.error(self.message)
+ self.logger.log.assert_called_once_with(std_logging.ERROR, self.message)
+
+ def test_warn(self):
+ with mock.patch.object(self.logger, 'log'):
+ self.logger.warn(self.message)
+ self.logger.log.assert_called_once_with(std_logging.WARN, self.message)
+
+ def test_warning(self):
+ with mock.patch.object(self.logger, 'log'):
+ self.logger.warning(self.message)
+ self.logger.log.assert_called_once_with(std_logging.WARNING, self.message)
+
+ def test_info(self):
+ with mock.patch.object(self.logger, 'log'):
+ self.logger.info(self.message)
+ self.logger.log.assert_called_once_with(std_logging.INFO, self.message)
+
+ def test_debug(self):
+ with mock.patch.object(self.logger, 'log'):
+ self.logger.debug(self.message)
+ self.logger.log.assert_called_once_with(std_logging.DEBUG, self.message)
+
+ def test_log_debug_with_python(self):
+ with mock.patch.object(self.logger, 'log'):
+ FLAGS.verbosity = 1
+ self.logger.debug(self.message)
+ self.logger.log.assert_called_once_with(std_logging.DEBUG, self.message)
+
+ def test_log_fatal_with_python(self):
+ with mock.patch.object(self.logger, 'log'):
+ self.logger.fatal(self.message)
+ self.logger.log.assert_called_once_with(std_logging.FATAL, self.message)
+
+ def test_register_frame_to_skip(self):
+ # This is basically just making sure that if I put something in a
+ # list, it actually appears in that list.
+ frame_tuple = ('file', 'method')
+ self.logger.register_frame_to_skip(*frame_tuple)
+ self.assertIn(frame_tuple, self.logger._frames_to_skip)
+
+ def test_register_frame_to_skip_with_lineno(self):
+ frame_tuple = ('file', 'method', 123)
+ self.logger.register_frame_to_skip(*frame_tuple)
+ self.assertIn(frame_tuple, self.logger._frames_to_skip)
+
+ def test_logger_cannot_be_disabled(self):
+ self.logger.disabled = True
+ record = self.logger.makeRecord(
+ 'name', std_logging.INFO, 'fn', 20, 'msg', [], False)
+ with mock.patch.object(self.logger, 'callHandlers') as mock_call_handlers:
+ self.logger.handle(record)
+ mock_call_handlers.assert_called_once()
+
+
+class ABSLLogPrefixTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.record = std_logging.LogRecord(
+ 'name', std_logging.INFO, 'path/to/source.py', 13, 'log message',
+ None, None)
+
+ @parameterized.named_parameters(
+ ('debug', std_logging.DEBUG, 'I'),
+ ('info', std_logging.INFO, 'I'),
+ ('warning', std_logging.WARNING, 'W'),
+ ('error', std_logging.ERROR, 'E'),
+ )
+ def test_default_prefixes(self, levelno, level_prefix):
+ self.record.levelno = levelno
+ self.record.created = 1494293880.378885
+ thread_id = '{: >5}'.format(logging._get_thread_id())
+ # Use UTC so the test passes regardless of the local time zone.
+ with mock.patch.object(time, 'localtime', side_effect=time.gmtime):
+ self.assertEqual(
+ '{}0509 01:38:00.378885 {} source.py:13] '.format(
+ level_prefix, thread_id),
+ logging.get_absl_log_prefix(self.record))
+ time.localtime.assert_called_once_with(self.record.created)
+
+ def test_absl_prefix_regex(self):
+ self.record.created = 1226888258.0521369
+ # Use UTC so the test passes regardless of the local time zone.
+ with mock.patch.object(time, 'localtime', side_effect=time.gmtime):
+ prefix = logging.get_absl_log_prefix(self.record)
+
+ match = re.search(logging.ABSL_LOGGING_PREFIX_REGEX, prefix)
+ self.assertTrue(match)
+
+ expect = {'severity': 'I',
+ 'month': '11',
+ 'day': '17',
+ 'hour': '02',
+ 'minute': '17',
+ 'second': '38',
+ 'microsecond': '052136',
+ 'thread_id': str(logging._get_thread_id()),
+ 'filename': 'source.py',
+ 'line': '13',
+ }
+ actual = {name: match.group(name) for name in expect}
+ self.assertEqual(expect, actual)
+
+ def test_critical_absl(self):
+ self.record.levelno = std_logging.CRITICAL
+ self.record.created = 1494293880.378885
+ self.record._absl_log_fatal = True
+ thread_id = '{: >5}'.format(logging._get_thread_id())
+ # Use UTC so the test passes regardless of the local time zone.
+ with mock.patch.object(time, 'localtime', side_effect=time.gmtime):
+ self.assertEqual(
+ 'F0509 01:38:00.378885 {} source.py:13] '.format(thread_id),
+ logging.get_absl_log_prefix(self.record))
+ time.localtime.assert_called_once_with(self.record.created)
+
+ def test_critical_non_absl(self):
+ self.record.levelno = std_logging.CRITICAL
+ self.record.created = 1494293880.378885
+ thread_id = '{: >5}'.format(logging._get_thread_id())
+ # Use UTC so the test passes regardless of the local time zone.
+ with mock.patch.object(time, 'localtime', side_effect=time.gmtime):
+ self.assertEqual(
+ 'E0509 01:38:00.378885 {} source.py:13] CRITICAL - '.format(
+ thread_id),
+ logging.get_absl_log_prefix(self.record))
+ time.localtime.assert_called_once_with(self.record.created)
+
+
+class LogCountTest(absltest.TestCase):
+
+ def test_counter_threadsafe(self):
+ threads_start = threading.Event()
+ counts = set()
+ k = object()
+
+ def t():
+ threads_start.wait()
+ counts.add(logging._get_next_log_count_per_token(k))
+
+ threads = [threading.Thread(target=t) for _ in range(100)]
+ for thread in threads:
+ thread.start()
+ threads_start.set()
+ for thread in threads:
+ thread.join()
+ self.assertEqual(counts, {i for i in range(100)})
+
+
+class LoggingTest(absltest.TestCase):
+
+ def test_fatal(self):
+ with mock.patch.object(os, 'abort') as mock_abort:
+ logging.fatal('Die!')
+ mock_abort.assert_called_once()
+
+ def test_find_log_dir_with_arg(self):
+ with mock.patch.object(os, 'access'), \
+ mock.patch.object(os.path, 'isdir'):
+ os.path.isdir.return_value = True
+ os.access.return_value = True
+ log_dir = logging.find_log_dir(log_dir='./')
+ self.assertEqual('./', log_dir)
+
+ @flagsaver.flagsaver(log_dir='./')
+ def test_find_log_dir_with_flag(self):
+ with mock.patch.object(os, 'access'), \
+ mock.patch.object(os.path, 'isdir'):
+ os.path.isdir.return_value = True
+ os.access.return_value = True
+ log_dir = logging.find_log_dir()
+ self.assertEqual('./', log_dir)
+
+ @flagsaver.flagsaver(log_dir='')
+ def test_find_log_dir_with_hda_tmp(self):
+ with mock.patch.object(os, 'access'), \
+ mock.patch.object(os.path, 'exists'), \
+ mock.patch.object(os.path, 'isdir'):
+ os.path.exists.return_value = True
+ os.path.isdir.return_value = True
+ os.access.return_value = True
+ log_dir = logging.find_log_dir()
+ self.assertEqual('/tmp/', log_dir)
+
+ @flagsaver.flagsaver(log_dir='')
+ def test_find_log_dir_with_tmp(self):
+ with mock.patch.object(os, 'access'), \
+ mock.patch.object(os.path, 'exists'), \
+ mock.patch.object(os.path, 'isdir'):
+ os.path.exists.return_value = False
+ os.path.isdir.side_effect = lambda path: path == '/tmp/'
+ os.access.return_value = True
+ log_dir = logging.find_log_dir()
+ self.assertEqual('/tmp/', log_dir)
+
+ def test_find_log_dir_with_nothing(self):
+ with mock.patch.object(os.path, 'exists'), \
+ mock.patch.object(os.path, 'isdir'):
+ os.path.exists.return_value = False
+ os.path.isdir.return_value = False
+ with self.assertRaises(FileNotFoundError):
+ logging.find_log_dir()
+
+ def test_find_log_dir_and_names_with_args(self):
+ user = 'test_user'
+ host = 'test_host'
+ log_dir = 'here'
+ program_name = 'prog1'
+ with mock.patch.object(getpass, 'getuser'), \
+ mock.patch.object(logging, 'find_log_dir') as mock_find_log_dir, \
+ mock.patch.object(socket, 'gethostname') as mock_gethostname:
+ getpass.getuser.return_value = user
+ mock_gethostname.return_value = host
+ mock_find_log_dir.return_value = log_dir
+
+ prefix = '%s.%s.%s.log' % (program_name, host, user)
+ self.assertEqual((log_dir, prefix, program_name),
+ logging.find_log_dir_and_names(
+ program_name=program_name, log_dir=log_dir))
+
+ def test_find_log_dir_and_names_without_args(self):
+ user = 'test_user'
+ host = 'test_host'
+ log_dir = 'here'
+ py_program_name = 'py_prog1'
+ sys.argv[0] = 'path/to/prog1'
+ with mock.patch.object(getpass, 'getuser'), \
+ mock.patch.object(logging, 'find_log_dir') as mock_find_log_dir, \
+ mock.patch.object(socket, 'gethostname'):
+ getpass.getuser.return_value = user
+ socket.gethostname.return_value = host
+ mock_find_log_dir.return_value = log_dir
+ prefix = '%s.%s.%s.log' % (py_program_name, host, user)
+ self.assertEqual((log_dir, prefix, py_program_name),
+ logging.find_log_dir_and_names())
+
+ def test_find_log_dir_and_names_wo_username(self):
+ # Windows doesn't have os.getuid at all
+ if hasattr(os, 'getuid'):
+ mock_getuid = mock.patch.object(os, 'getuid')
+ uid = 100
+ logged_uid = '100'
+ else:
+ # The function doesn't exist, but our test code still tries to mock
+ # it, so just use a fake thing.
+ mock_getuid = _mock_windows_os_getuid()
+ uid = -1
+ logged_uid = 'unknown'
+
+ host = 'test_host'
+ log_dir = 'here'
+ program_name = 'prog1'
+ with mock.patch.object(getpass, 'getuser'), \
+ mock_getuid as getuid, \
+ mock.patch.object(logging, 'find_log_dir') as mock_find_log_dir, \
+ mock.patch.object(socket, 'gethostname') as mock_gethostname:
+ getpass.getuser.side_effect = KeyError()
+ getuid.return_value = uid
+ mock_gethostname.return_value = host
+ mock_find_log_dir.return_value = log_dir
+
+ prefix = '%s.%s.%s.log' % (program_name, host, logged_uid)
+ self.assertEqual((log_dir, prefix, program_name),
+ logging.find_log_dir_and_names(
+ program_name=program_name, log_dir=log_dir))
+
+ def test_errors_in_logging(self):
+ with mock.patch.object(sys, 'stderr', new=io.StringIO()) as stderr:
+ logging.info('not enough args: %s %s', 'foo') # pylint: disable=logging-too-few-args
+ self.assertIn('Traceback (most recent call last):', stderr.getvalue())
+ self.assertIn('TypeError', stderr.getvalue())
+
+ def test_dict_arg(self):
+ # Tests that passing a dictionary as a single argument does not crash.
+ logging.info('%(test)s', {'test': 'Hello world!'})
+
+ def test_exception_dict_format(self):
+ # Just verify that this doesn't raise a TypeError.
+ logging.exception('%(test)s', {'test': 'Hello world!'})
+
+ def test_logging_levels(self):
+ old_level = logging.get_verbosity()
+
+ logging.set_verbosity(logging.DEBUG)
+ self.assertEqual(logging.get_verbosity(), logging.DEBUG)
+ self.assertTrue(logging.level_debug())
+ self.assertTrue(logging.level_info())
+ self.assertTrue(logging.level_warning())
+ self.assertTrue(logging.level_error())
+
+ logging.set_verbosity(logging.INFO)
+ self.assertEqual(logging.get_verbosity(), logging.INFO)
+ self.assertFalse(logging.level_debug())
+ self.assertTrue(logging.level_info())
+ self.assertTrue(logging.level_warning())
+ self.assertTrue(logging.level_error())
+
+ logging.set_verbosity(logging.WARNING)
+ self.assertEqual(logging.get_verbosity(), logging.WARNING)
+ self.assertFalse(logging.level_debug())
+ self.assertFalse(logging.level_info())
+ self.assertTrue(logging.level_warning())
+ self.assertTrue(logging.level_error())
+
+ logging.set_verbosity(logging.ERROR)
+ self.assertEqual(logging.get_verbosity(), logging.ERROR)
+ self.assertFalse(logging.level_debug())
+ self.assertFalse(logging.level_info())
+ self.assertTrue(logging.level_error())
+
+ logging.set_verbosity(old_level)
+
+ def test_set_verbosity_strings(self):
+ old_level = logging.get_verbosity()
+
+ # Lowercase names.
+ logging.set_verbosity('debug')
+ self.assertEqual(logging.get_verbosity(), logging.DEBUG)
+ logging.set_verbosity('info')
+ self.assertEqual(logging.get_verbosity(), logging.INFO)
+ logging.set_verbosity('warning')
+ self.assertEqual(logging.get_verbosity(), logging.WARNING)
+ logging.set_verbosity('warn')
+ self.assertEqual(logging.get_verbosity(), logging.WARNING)
+ logging.set_verbosity('error')
+ self.assertEqual(logging.get_verbosity(), logging.ERROR)
+ logging.set_verbosity('fatal')
+
+ # Uppercase names.
+ self.assertEqual(logging.get_verbosity(), logging.FATAL)
+ logging.set_verbosity('DEBUG')
+ self.assertEqual(logging.get_verbosity(), logging.DEBUG)
+ logging.set_verbosity('INFO')
+ self.assertEqual(logging.get_verbosity(), logging.INFO)
+ logging.set_verbosity('WARNING')
+ self.assertEqual(logging.get_verbosity(), logging.WARNING)
+ logging.set_verbosity('WARN')
+ self.assertEqual(logging.get_verbosity(), logging.WARNING)
+ logging.set_verbosity('ERROR')
+ self.assertEqual(logging.get_verbosity(), logging.ERROR)
+ logging.set_verbosity('FATAL')
+ self.assertEqual(logging.get_verbosity(), logging.FATAL)
+
+ # Integers as strings.
+ logging.set_verbosity(str(logging.DEBUG))
+ self.assertEqual(logging.get_verbosity(), logging.DEBUG)
+ logging.set_verbosity(str(logging.INFO))
+ self.assertEqual(logging.get_verbosity(), logging.INFO)
+ logging.set_verbosity(str(logging.WARNING))
+ self.assertEqual(logging.get_verbosity(), logging.WARNING)
+ logging.set_verbosity(str(logging.ERROR))
+ self.assertEqual(logging.get_verbosity(), logging.ERROR)
+ logging.set_verbosity(str(logging.FATAL))
+ self.assertEqual(logging.get_verbosity(), logging.FATAL)
+
+ logging.set_verbosity(old_level)
+
+ def test_key_flags(self):
+ key_flags = FLAGS.get_key_flags_for_module(logging)
+ key_flag_names = [flag.name for flag in key_flags]
+ self.assertIn('stderrthreshold', key_flag_names)
+ self.assertIn('verbosity', key_flag_names)
+
+ def test_get_absl_logger(self):
+ self.assertIsInstance(
+ logging.get_absl_logger(), logging.ABSLLogger)
+
+ def test_get_absl_handler(self):
+ self.assertIsInstance(
+ logging.get_absl_handler(), logging.ABSLHandler)
+
+
+@mock.patch.object(logging.ABSLLogger, 'register_frame_to_skip')
+class LogSkipPrefixTest(absltest.TestCase):
+ """Tests for logging.skip_log_prefix."""
+
+ def _log_some_info(self):
+ """Logging helper function for LogSkipPrefixTest."""
+ logging.info('info')
+
+ def _log_nested_outer(self):
+ """Nested logging helper functions for LogSkipPrefixTest."""
+ def _log_nested_inner():
+ logging.info('info nested')
+ return _log_nested_inner
+
+ def test_skip_log_prefix_with_name(self, mock_skip_register):
+ retval = logging.skip_log_prefix('_log_some_info')
+ mock_skip_register.assert_called_once_with(__file__, '_log_some_info', None)
+ self.assertEqual(retval, '_log_some_info')
+
+ def test_skip_log_prefix_with_func(self, mock_skip_register):
+ retval = logging.skip_log_prefix(self._log_some_info)
+ mock_skip_register.assert_called_once_with(
+ __file__, '_log_some_info', mock.ANY)
+ self.assertEqual(retval, self._log_some_info)
+
+ def test_skip_log_prefix_with_functools_partial(self, mock_skip_register):
+ partial_input = functools.partial(self._log_some_info)
+ with self.assertRaises(ValueError):
+ _ = logging.skip_log_prefix(partial_input)
+ mock_skip_register.assert_not_called()
+
+ def test_skip_log_prefix_with_lambda(self, mock_skip_register):
+ lambda_input = lambda _: self._log_some_info()
+ retval = logging.skip_log_prefix(lambda_input)
+ mock_skip_register.assert_called_once_with(__file__, '<lambda>', mock.ANY)
+ self.assertEqual(retval, lambda_input)
+
+ def test_skip_log_prefix_with_bad_input(self, mock_skip_register):
+ dict_input = {1: 2, 2: 3}
+ with self.assertRaises(TypeError):
+ _ = logging.skip_log_prefix(dict_input)
+ mock_skip_register.assert_not_called()
+
+ def test_skip_log_prefix_with_nested_func(self, mock_skip_register):
+ nested_input = self._log_nested_outer()
+ retval = logging.skip_log_prefix(nested_input)
+ mock_skip_register.assert_called_once_with(
+ __file__, '_log_nested_inner', mock.ANY)
+ self.assertEqual(retval, nested_input)
+
+ def test_skip_log_prefix_decorator(self, mock_skip_register):
+
+ @logging.skip_log_prefix
+ def _log_decorated():
+ logging.info('decorated')
+
+ del _log_decorated
+ mock_skip_register.assert_called_once_with(
+ __file__, '_log_decorated', mock.ANY)
+
+
+@contextlib.contextmanager
+def override_python_handler_stream(stream):
+ handler = logging.get_absl_handler().python_handler
+ old_stream = handler.stream
+ handler.stream = stream
+ try:
+ yield
+ finally:
+ handler.stream = old_stream
+
+
+class GetLogFileNameTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('err', sys.stderr),
+ ('out', sys.stdout),
+ )
+ def test_get_log_file_name_py_std(self, stream):
+ with override_python_handler_stream(stream):
+ self.assertEqual('', logging.get_log_file_name())
+
+ def test_get_log_file_name_py_no_name(self):
+
+ class FakeFile(object):
+ pass
+
+ with override_python_handler_stream(FakeFile()):
+ self.assertEqual('', logging.get_log_file_name())
+
+ def test_get_log_file_name_py_file(self):
+ _, filename = tempfile.mkstemp(dir=absltest.TEST_TMPDIR.value)
+ with open(filename, 'a') as stream:
+ with override_python_handler_stream(stream):
+ self.assertEqual(filename, logging.get_log_file_name())
+
+
+@contextlib.contextmanager
+def _mock_windows_os_getuid():
+ yield mock.MagicMock()
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/logging/tests/verbosity_flag_test.py b/absl/logging/tests/verbosity_flag_test.py
new file mode 100644
index 0000000..4609e64
--- /dev/null
+++ b/absl/logging/tests/verbosity_flag_test.py
@@ -0,0 +1,56 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests -v/--verbosity flag and logging.root level's sync behavior."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+
+assert logging.root.getEffectiveLevel() == logging.WARN, (
+ 'default logging.root level should be WARN, but found {}'.format(
+ logging.root.getEffectiveLevel()))
+
+# This is here to test importing logging won't change the level.
+logging.root.setLevel(logging.ERROR)
+
+assert logging.root.getEffectiveLevel() == logging.ERROR, (
+ 'logging.root level should be changed to ERROR, but found {}'.format(
+ logging.root.getEffectiveLevel()))
+
+from absl import flags
+from absl import logging as _ # pylint: disable=unused-import
+from absl.testing import absltest
+
+FLAGS = flags.FLAGS
+
+assert FLAGS['verbosity'].value == -1, (
+ '-v/--verbosity should be -1 before flags are parsed.')
+
+assert logging.root.getEffectiveLevel() == logging.ERROR, (
+ 'logging.root level should be kept to ERROR, but found {}'.format(
+ logging.root.getEffectiveLevel()))
+
+
+class VerbosityFlagTest(absltest.TestCase):
+
+ def test_default_value_after_init(self):
+ self.assertEqual(0, FLAGS.verbosity)
+ self.assertEqual(logging.INFO, logging.root.getEffectiveLevel())
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/BUILD b/absl/testing/BUILD
new file mode 100644
index 0000000..b608c8c
--- /dev/null
+++ b/absl/testing/BUILD
@@ -0,0 +1,254 @@
+licenses(["notice"])
+
+py_library(
+ name = "absltest",
+ srcs = ["absltest.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":_pretty_print_reporter",
+ ":xml_reporter",
+ "//absl:app",
+ "//absl/flags",
+ "//absl/logging",
+ ],
+)
+
+py_library(
+ name = "flagsaver",
+ srcs = ["flagsaver.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//absl/flags",
+ ],
+)
+
+py_library(
+ name = "parameterized",
+ srcs = [
+ "parameterized.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":absltest",
+ ],
+)
+
+py_library(
+ name = "xml_reporter",
+ srcs = ["xml_reporter.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":_pretty_print_reporter",
+ ],
+)
+
+py_library(
+ name = "_bazelize_command",
+ testonly = 1,
+ srcs = ["_bazelize_command.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//:__subpackages__"],
+ deps = [
+ "//absl/flags",
+ ],
+)
+
+py_library(
+ name = "_pretty_print_reporter",
+ srcs = ["_pretty_print_reporter.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "tests/absltest_env",
+ testonly = True,
+ srcs = ["tests/absltest_env.py"],
+)
+
+py_test(
+ name = "tests/absltest_filtering_test",
+ size = "medium",
+ srcs = ["tests/absltest_filtering_test.py"],
+ data = [":tests/absltest_filtering_test_helper"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_bazelize_command",
+ ":absltest",
+ ":parameterized",
+ ":tests/absltest_env",
+ "//absl/logging",
+ ],
+)
+
+py_binary(
+ name = "tests/absltest_filtering_test_helper",
+ testonly = 1,
+ srcs = ["tests/absltest_filtering_test_helper.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":absltest",
+ ":parameterized",
+ "//absl:app",
+ ],
+)
+
+py_test(
+ name = "tests/absltest_fail_fast_test",
+ size = "small",
+ srcs = ["tests/absltest_fail_fast_test.py"],
+ data = [":tests/absltest_fail_fast_test_helper"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_bazelize_command",
+ ":absltest",
+ ":parameterized",
+ ":tests/absltest_env",
+ "//absl/logging",
+ ],
+)
+
+py_binary(
+ name = "tests/absltest_fail_fast_test_helper",
+ testonly = 1,
+ srcs = ["tests/absltest_fail_fast_test_helper.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":absltest",
+ "//absl:app",
+ ],
+)
+
+py_test(
+ name = "tests/absltest_randomization_test",
+ size = "medium",
+ srcs = ["tests/absltest_randomization_test.py"],
+ data = [":tests/absltest_randomization_testcase"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_bazelize_command",
+ ":absltest",
+ ":parameterized",
+ ":tests/absltest_env",
+ "//absl/flags",
+ ],
+)
+
+py_binary(
+ name = "tests/absltest_randomization_testcase",
+ testonly = 1,
+ srcs = ["tests/absltest_randomization_testcase.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":absltest",
+ ],
+)
+
+py_test(
+ name = "tests/absltest_sharding_test",
+ size = "small",
+ srcs = ["tests/absltest_sharding_test.py"],
+ data = [":tests/absltest_sharding_test_helper"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_bazelize_command",
+ ":absltest",
+ ":tests/absltest_env",
+ ],
+)
+
+py_binary(
+ name = "tests/absltest_sharding_test_helper",
+ testonly = 1,
+ srcs = ["tests/absltest_sharding_test_helper.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [":absltest"],
+)
+
+py_test(
+ name = "tests/absltest_test",
+ size = "small",
+ srcs = ["tests/absltest_test.py"],
+ data = [":tests/absltest_test_helper"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_bazelize_command",
+ ":absltest",
+ ":parameterized",
+ ":tests/absltest_env",
+ ],
+)
+
+py_binary(
+ name = "tests/absltest_test_helper",
+ testonly = 1,
+ srcs = ["tests/absltest_test_helper.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":absltest",
+ "//absl/flags",
+ ],
+)
+
+py_test(
+ name = "tests/flagsaver_test",
+ srcs = ["tests/flagsaver_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":absltest",
+ ":flagsaver",
+ "//absl/flags",
+ ],
+)
+
+py_test(
+ name = "tests/parameterized_test",
+ srcs = ["tests/parameterized_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":absltest",
+ ":parameterized",
+ ],
+)
+
+py_test(
+ name = "tests/xml_reporter_test",
+ srcs = ["tests/xml_reporter_test.py"],
+ data = [":tests/xml_reporter_helper_test"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_bazelize_command",
+ ":absltest",
+ ":parameterized",
+ ":xml_reporter",
+ "//absl/logging",
+ ],
+)
+
+py_binary(
+ name = "tests/xml_reporter_helper_test",
+ testonly = 1,
+ srcs = ["tests/xml_reporter_helper_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":absltest",
+ "//absl/flags",
+ ],
+)
diff --git a/absl/testing/__init__.py b/absl/testing/__init__.py
new file mode 100644
index 0000000..a3bd1cd
--- /dev/null
+++ b/absl/testing/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/absl/testing/_bazelize_command.py b/absl/testing/_bazelize_command.py
new file mode 100644
index 0000000..fdf6eb6
--- /dev/null
+++ b/absl/testing/_bazelize_command.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Internal helper for running tests on Windows Bazel."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+
+FLAGS = flags.FLAGS
+
+
+def get_executable_path(py_binary_name):
+ """Returns the executable path of a py_binary.
+
+ This returns the executable path of a py_binary that is in another Bazel
+ target's data dependencies.
+
+ On Linux/macOS, the path and __file__ has the same root directory.
+ On Windows, bazel builds an .exe file and we need to use the MANIFEST file
+ the location the actual binary.
+
+ Args:
+ py_binary_name: string, the name of a py_binary that is in another Bazel
+ target's data dependencies.
+
+ Raises:
+ RuntimeError: Raised when it cannot locate the executable path.
+ """
+
+ if os.name == 'nt':
+ py_binary_name += '.exe'
+ manifest_file = os.path.join(FLAGS.test_srcdir, 'MANIFEST')
+ workspace_name = os.environ['TEST_WORKSPACE']
+ manifest_entry = '{}/{}'.format(workspace_name, py_binary_name)
+ with open(manifest_file, 'r') as manifest_fd:
+ for line in manifest_fd:
+ tokens = line.strip().split(' ')
+ if len(tokens) != 2:
+ continue
+ if manifest_entry == tokens[0]:
+ return tokens[1]
+ raise RuntimeError(
+ 'Cannot locate executable path for {}, MANIFEST file: {}.'.format(
+ py_binary_name, manifest_file))
+ else:
+ # NOTE: __file__ may be .py or .pyc, depending on how the module was
+ # loaded and executed.
+ path = __file__
+
+ # Use the package name to find the root directory: every dot is
+ # a directory, plus one for ourselves.
+ for _ in range(__name__.count('.') + 1):
+ path = os.path.dirname(path)
+
+ root_directory = path
+ return os.path.join(root_directory, py_binary_name)
diff --git a/absl/testing/_pretty_print_reporter.py b/absl/testing/_pretty_print_reporter.py
new file mode 100644
index 0000000..ef03934
--- /dev/null
+++ b/absl/testing/_pretty_print_reporter.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TestResult implementing default output for test execution status."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+
+class TextTestResult(unittest.TextTestResult):
+ """TestResult class that provides the default text result formatting."""
+
+ def __init__(self, stream, descriptions, verbosity):
+ # Disable the verbose per-test output from the superclass, since it would
+ # conflict with our customized output.
+ super(TextTestResult, self).__init__(stream, descriptions, 0)
+ self._per_test_output = verbosity > 0
+
+ def _print_status(self, tag, test):
+ if self._per_test_output:
+ test_id = test.id()
+ if test_id.startswith('__main__.'):
+ test_id = test_id[len('__main__.'):]
+ print('[%s] %s' % (tag, test_id), file=self.stream)
+ self.stream.flush()
+
+ def startTest(self, test):
+ super(TextTestResult, self).startTest(test)
+ self._print_status(' RUN ', test)
+
+ def addSuccess(self, test):
+ super(TextTestResult, self).addSuccess(test)
+ self._print_status(' OK ', test)
+
+ def addError(self, test, err):
+ super(TextTestResult, self).addError(test, err)
+ self._print_status(' FAILED ', test)
+
+ def addFailure(self, test, err):
+ super(TextTestResult, self).addFailure(test, err)
+ self._print_status(' FAILED ', test)
+
+ def addSkip(self, test, reason):
+ super(TextTestResult, self).addSkip(test, reason)
+ self._print_status(' SKIPPED ', test)
+
+ def addExpectedFailure(self, test, err):
+ super(TextTestResult, self).addExpectedFailure(test, err)
+ self._print_status(' OK ', test)
+
+ def addUnexpectedSuccess(self, test):
+ super(TextTestResult, self).addUnexpectedSuccess(test)
+ self._print_status(' FAILED ', test)
+
+
+class TextTestRunner(unittest.TextTestRunner):
+ """A test runner that produces formatted text results."""
+
+ _TEST_RESULT_CLASS = TextTestResult
+
+ # Set this to true at the class or instance level to run tests using a
+ # debug-friendly method (e.g, one that doesn't catch exceptions and interacts
+ # better with debuggers).
+ # Usually this is set using --pdb_post_mortem.
+ run_for_debugging = False
+
+ def run(self, test):
+ # type: (TestCase) -> TestResult
+ if self.run_for_debugging:
+ return self._run_debug(test)
+ else:
+ return super(TextTestRunner, self).run(test)
+
+ def _run_debug(self, test):
+ # type: (TestCase) -> TestResult
+ test.debug()
+ # Return an empty result to indicate success.
+ return self._makeResult()
+
+ def _makeResult(self):
+ return TextTestResult(self.stream, self.descriptions, self.verbosity)
diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py
new file mode 100644
index 0000000..cebb1ca
--- /dev/null
+++ b/absl/testing/absltest.py
@@ -0,0 +1,2554 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base functionality for Abseil Python tests.
+
+This module contains base classes and high-level functions for Abseil-style
+tests.
+"""
+
+from collections import abc
+import contextlib
+import difflib
+import enum
+import errno
+import getpass
+import inspect
+import io
+import itertools
+import json
+import os
+import random
+import re
+import shlex
+import shutil
+import signal
+import stat
+import subprocess
+import sys
+import tempfile
+import textwrap
+import unittest
+from unittest import mock # pylint: disable=unused-import Allow absltest.mock.
+from urllib import parse
+
+try:
+ # The faulthandler module isn't always available, and pytype doesn't
+ # understand that we're catching ImportError, so suppress the error.
+ # pytype: disable=import-error
+ import faulthandler
+ # pytype: enable=import-error
+except ImportError:
+ # We use faulthandler if it is available.
+ faulthandler = None
+
+from absl import app
+from absl import flags
+from absl import logging
+from absl.testing import _pretty_print_reporter
+from absl.testing import xml_reporter
+
+# Make typing an optional import to avoid it being a required dependency
+# in Python 2. Type checkers will still understand the imports.
+try:
+ # pylint: disable=unused-import
+ import typing
+ from typing import Any, AnyStr, BinaryIO, Callable, ContextManager, IO, Iterator, List, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Text, TextIO, Tuple, Type, Union
+ # pylint: enable=unused-import
+except ImportError:
+ pass
+else:
+ # Use an if-type-checking block to prevent leakage of type-checking only
+ # symbols. We don't want people relying on these at runtime.
+ if typing.TYPE_CHECKING:
+ # Unbounded TypeVar for general usage
+ _T = typing.TypeVar('_T')
+
+ import unittest.case
+ _OutcomeType = unittest.case._Outcome # pytype: disable=module-attr
+
+
+
+# Re-export a bunch of unittest functions we support so that people don't
+# have to import unittest to get them
+# pylint: disable=invalid-name
+skip = unittest.skip
+skipIf = unittest.skipIf
+skipUnless = unittest.skipUnless
+SkipTest = unittest.SkipTest
+expectedFailure = unittest.expectedFailure
+# pylint: enable=invalid-name
+
+# End unittest re-exports
+
+FLAGS = flags.FLAGS
+
+_TEXT_OR_BINARY_TYPES = (str, bytes)
+
+# Suppress surplus entries in AssertionError stack traces.
+__unittest = True # pylint: disable=invalid-name
+
+
+def expectedFailureIf(condition, reason): # pylint: disable=invalid-name
+ """Expects the test to fail if the run condition is True.
+
+ Example usage:
+ @expectedFailureIf(sys.version.major == 2, "Not yet working in py2")
+ def test_foo(self):
+ ...
+
+ Args:
+ condition: bool, whether to expect failure or not.
+ reason: Text, the reason to expect failure.
+ Returns:
+ Decorator function
+ """
+ del reason # Unused
+ if condition:
+ return unittest.expectedFailure
+ else:
+ return lambda f: f
+
+
+class TempFileCleanup(enum.Enum):
+ # Always cleanup temp files when the test completes.
+ ALWAYS = 'always'
+ # Only cleanup temp file if the test passes. This allows easier inspection
+ # of tempfile contents on test failure. absltest.TEST_TMPDIR.value determines
+ # where tempfiles are created.
+ SUCCESS = 'success'
+ # Never cleanup temp files.
+ OFF = 'never'
+
+
+# Many of the methods in this module have names like assertSameElements.
+# This kind of name does not comply with PEP8 style,
+# but it is consistent with the naming of methods in unittest.py.
+# pylint: disable=invalid-name
+
+
+def _get_default_test_random_seed():
+ # type: () -> int
+ random_seed = 301
+ value = os.environ.get('TEST_RANDOM_SEED', '')
+ try:
+ random_seed = int(value)
+ except ValueError:
+ pass
+ return random_seed
+
+
+def get_default_test_srcdir():
+ # type: () -> Text
+ """Returns default test source dir."""
+ return os.environ.get('TEST_SRCDIR', '')
+
+
+def get_default_test_tmpdir():
+ # type: () -> Text
+ """Returns default test temp dir."""
+ tmpdir = os.environ.get('TEST_TMPDIR', '')
+ if not tmpdir:
+ tmpdir = os.path.join(tempfile.gettempdir(), 'absl_testing')
+
+ return tmpdir
+
+
+def _get_default_randomize_ordering_seed():
+ # type: () -> int
+ """Returns default seed to use for randomizing test order.
+
+ This function first checks the --test_randomize_ordering_seed flag, and then
+ the TEST_RANDOMIZE_ORDERING_SEED environment variable. If the first value
+ we find is:
+ * (not set): disable test randomization
+ * 0: disable test randomization
+ * 'random': choose a random seed in [1, 4294967295] for test order
+ randomization
+ * positive integer: use this seed for test order randomization
+
+ (The values used are patterned after
+ https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED).
+
+ In principle, it would be simpler to return None if no override is provided;
+ however, the python random module has no `get_seed()`, only `getstate()`,
+ which returns far more data than we want to pass via an environment variable
+ or flag.
+
+ Returns:
+ A default value for test case randomization (int). 0 means do not randomize.
+
+ Raises:
+ ValueError: Raised when the flag or env value is not one of the options
+ above.
+ """
+ if FLAGS['test_randomize_ordering_seed'].present:
+ randomize = FLAGS.test_randomize_ordering_seed
+ elif 'TEST_RANDOMIZE_ORDERING_SEED' in os.environ:
+ randomize = os.environ['TEST_RANDOMIZE_ORDERING_SEED']
+ else:
+ randomize = ''
+ if not randomize:
+ return 0
+ if randomize == 'random':
+ return random.Random().randint(1, 4294967295)
+ if randomize == '0':
+ return 0
+ try:
+ seed = int(randomize)
+ if seed > 0:
+ return seed
+ except ValueError:
+ pass
+ raise ValueError(
+ 'Unknown test randomization seed value: {}'.format(randomize))
+
+
+TEST_SRCDIR = flags.DEFINE_string(
+ 'test_srcdir',
+ get_default_test_srcdir(),
+ 'Root of directory tree where source files live',
+ allow_override_cpp=True)
+TEST_TMPDIR = flags.DEFINE_string(
+ 'test_tmpdir',
+ get_default_test_tmpdir(),
+ 'Directory for temporary testing files',
+ allow_override_cpp=True)
+
+flags.DEFINE_integer(
+ 'test_random_seed',
+ _get_default_test_random_seed(),
+ 'Random seed for testing. Some test frameworks may '
+ 'change the default value of this flag between runs, so '
+ 'it is not appropriate for seeding probabilistic tests.',
+ allow_override_cpp=True)
+flags.DEFINE_string(
+ 'test_randomize_ordering_seed',
+ '',
+ 'If positive, use this as a seed to randomize the '
+ 'execution order for test cases. If "random", pick a '
+ 'random seed to use. If 0 or not set, do not randomize '
+ 'test case execution order. This flag also overrides '
+ 'the TEST_RANDOMIZE_ORDERING_SEED environment variable.',
+ allow_override_cpp=True)
+flags.DEFINE_string('xml_output_file', '', 'File to store XML test results')
+
+
+# We might need to monkey-patch TestResult so that it stops considering an
+# unexpected pass as a as a "successful result". For details, see
+# http://bugs.python.org/issue20165
+def _monkey_patch_test_result_for_unexpected_passes():
+ # type: () -> None
+ """Workaround for <http://bugs.python.org/issue20165>."""
+
+ def wasSuccessful(self):
+ # type: () -> bool
+ """Tells whether or not this result was a success.
+
+ Any unexpected pass is to be counted as a non-success.
+
+ Args:
+ self: The TestResult instance.
+
+ Returns:
+ Whether or not this result was a success.
+ """
+ return (len(self.failures) == len(self.errors) ==
+ len(self.unexpectedSuccesses) == 0)
+
+ test_result = unittest.TestResult()
+ test_result.addUnexpectedSuccess(unittest.FunctionTestCase(lambda: None))
+ if test_result.wasSuccessful(): # The bug is present.
+ unittest.TestResult.wasSuccessful = wasSuccessful
+ if test_result.wasSuccessful(): # Warn the user if our hot-fix failed.
+ sys.stderr.write('unittest.result.TestResult monkey patch to report'
+ ' unexpected passes as failures did not work.\n')
+
+
+_monkey_patch_test_result_for_unexpected_passes()
+
+
+def _open(filepath, mode, _open_func=open):
+ # type: (Text, Text, Callable[..., IO]) -> IO
+ """Opens a file.
+
+ Like open(), but ensure that we can open real files even if tests stub out
+ open().
+
+ Args:
+ filepath: A filepath.
+ mode: A mode.
+ _open_func: A built-in open() function.
+
+ Returns:
+ The opened file object.
+ """
+ return _open_func(filepath, mode, encoding='utf-8')
+
+
+class _TempDir(object):
+ """Represents a temporary directory for tests.
+
+ Creation of this class is internal. Using its public methods is OK.
+
+ This class implements the `os.PathLike` interface (specifically,
+ `os.PathLike[str]`). This means, in Python 3, it can be directly passed
+ to e.g. `os.path.join()`.
+ """
+
+ def __init__(self, path):
+ # type: (Text) -> None
+ """Module-private: do not instantiate outside module."""
+ self._path = path
+
+ @property
+ def full_path(self):
+ # type: () -> Text
+ """Returns the path, as a string, for the directory.
+
+ TIP: Instead of e.g. `os.path.join(temp_dir.full_path)`, you can simply
+ do `os.path.join(temp_dir)` because `__fspath__()` is implemented.
+ """
+ return self._path
+
+ def __fspath__(self):
+ # type: () -> Text
+ """See os.PathLike."""
+ return self.full_path
+
+ def create_file(self, file_path=None, content=None, mode='w', encoding='utf8',
+ errors='strict'):
+ # type: (Optional[Text], Optional[AnyStr], Text, Text, Text) -> _TempFile
+ """Create a file in the directory.
+
+ NOTE: If the file already exists, it will be made writable and overwritten.
+
+ Args:
+ file_path: Optional file path for the temp file. If not given, a unique
+ file name will be generated and used. Slashes are allowed in the name;
+ any missing intermediate directories will be created. NOTE: This path
+ is the path that will be cleaned up, including any directories in the
+ path, e.g., 'foo/bar/baz.txt' will `rm -r foo`
+ content: Optional string or bytes to initially write to the file. If not
+ specified, then an empty file is created.
+ mode: Mode string to use when writing content. Only used if `content` is
+ non-empty.
+ encoding: Encoding to use when writing string content. Only used if
+ `content` is text.
+ errors: How to handle text to bytes encoding errors. Only used if
+ `content` is text.
+
+ Returns:
+ A _TempFile representing the created file.
+ """
+ tf, _ = _TempFile._create(self._path, file_path, content, mode, encoding,
+ errors)
+ return tf
+
+ def mkdir(self, dir_path=None):
+ # type: (Optional[Text]) -> _TempDir
+ """Create a directory in the directory.
+
+ Args:
+ dir_path: Optional path to the directory to create. If not given,
+ a unique name will be generated and used.
+
+ Returns:
+ A _TempDir representing the created directory.
+ """
+ if dir_path:
+ path = os.path.join(self._path, dir_path)
+ else:
+ path = tempfile.mkdtemp(dir=self._path)
+
+ # Note: there's no need to clear the directory since the containing
+ # dir was cleared by the tempdir() function.
+ os.makedirs(path, exist_ok=True)
+ return _TempDir(path)
+
+
+class _TempFile(object):
+ """Represents a tempfile for tests.
+
+ Creation of this class is internal. Using its public methods is OK.
+
+ This class implements the `os.PathLike` interface (specifically,
+ `os.PathLike[str]`). This means, in Python 3, it can be directly passed
+ to e.g. `os.path.join()`.
+ """
+
+ def __init__(self, path):
+ # type: (Text) -> None
+ """Private: use _create instead."""
+ self._path = path
+
+ # pylint: disable=line-too-long
+ @classmethod
+ def _create(cls, base_path, file_path, content, mode, encoding, errors):
+ # type: (Text, Optional[Text], AnyStr, Text, Text, Text) -> Tuple[_TempFile, Text]
+ # pylint: enable=line-too-long
+ """Module-private: create a tempfile instance."""
+ if file_path:
+ cleanup_path = os.path.join(base_path, _get_first_part(file_path))
+ path = os.path.join(base_path, file_path)
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ # The file may already exist, in which case, ensure it's writable so that
+ # it can be truncated.
+ if os.path.exists(path) and not os.access(path, os.W_OK):
+ stat_info = os.stat(path)
+ os.chmod(path, stat_info.st_mode | stat.S_IWUSR)
+ else:
+ os.makedirs(base_path, exist_ok=True)
+ fd, path = tempfile.mkstemp(dir=str(base_path))
+ os.close(fd)
+ cleanup_path = path
+
+ tf = cls(path)
+
+ if content:
+ if isinstance(content, str):
+ tf.write_text(content, mode=mode, encoding=encoding, errors=errors)
+ else:
+ tf.write_bytes(content, mode)
+
+ else:
+ tf.write_bytes(b'')
+
+ return tf, cleanup_path
+
+ @property
+ def full_path(self):
+ # type: () -> Text
+ """Returns the path, as a string, for the file.
+
+ TIP: Instead of e.g. `os.path.join(temp_file.full_path)`, you can simply
+ do `os.path.join(temp_file)` because `__fspath__()` is implemented.
+ """
+ return self._path
+
+ def __fspath__(self):
+ # type: () -> Text
+ """See os.PathLike."""
+ return self.full_path
+
+ def read_text(self, encoding='utf8', errors='strict'):
+ # type: (Text, Text) -> Text
+ """Return the contents of the file as text."""
+ with self.open_text(encoding=encoding, errors=errors) as fp:
+ return fp.read()
+
+ def read_bytes(self):
+ # type: () -> bytes
+ """Return the content of the file as bytes."""
+ with self.open_bytes() as fp:
+ return fp.read()
+
+ def write_text(self, text, mode='w', encoding='utf8', errors='strict'):
+ # type: (Text, Text, Text, Text) -> None
+ """Write text to the file.
+
+ Args:
+ text: Text to write. In Python 2, it can be bytes, which will be
+ decoded using the `encoding` arg (this is as an aid for code that
+ is 2 and 3 compatible).
+ mode: The mode to open the file for writing.
+ encoding: The encoding to use when writing the text to the file.
+ errors: The error handling strategy to use when converting text to bytes.
+ """
+ with self.open_text(mode, encoding=encoding, errors=errors) as fp:
+ fp.write(text)
+
+ def write_bytes(self, data, mode='wb'):
+ # type: (bytes, Text) -> None
+ """Write bytes to the file.
+
+ Args:
+ data: bytes to write.
+ mode: Mode to open the file for writing. The "b" flag is implicit if
+ not already present. It must not have the "t" flag.
+ """
+ with self.open_bytes(mode) as fp:
+ fp.write(data)
+
+ def open_text(self, mode='rt', encoding='utf8', errors='strict'):
+ # type: (Text, Text, Text) -> ContextManager[TextIO]
+ """Return a context manager for opening the file in text mode.
+
+ Args:
+ mode: The mode to open the file in. The "t" flag is implicit if not
+ already present. It must not have the "b" flag.
+ encoding: The encoding to use when opening the file.
+ errors: How to handle decoding errors.
+
+ Returns:
+ Context manager that yields an open file.
+
+ Raises:
+ ValueError: if invalid inputs are provided.
+ """
+ if 'b' in mode:
+ raise ValueError('Invalid mode {!r}: "b" flag not allowed when opening '
+ 'file in text mode'.format(mode))
+ if 't' not in mode:
+ mode += 't'
+ cm = self._open(mode, encoding, errors)
+ return cm
+
+ def open_bytes(self, mode='rb'):
+ # type: (Text) -> ContextManager[BinaryIO]
+ """Return a context manager for opening the file in binary mode.
+
+ Args:
+ mode: The mode to open the file in. The "b" mode is implicit if not
+ already present. It must not have the "t" flag.
+
+ Returns:
+ Context manager that yields an open file.
+
+ Raises:
+ ValueError: if invalid inputs are provided.
+ """
+ if 't' in mode:
+ raise ValueError('Invalid mode {!r}: "t" flag not allowed when opening '
+ 'file in binary mode'.format(mode))
+ if 'b' not in mode:
+ mode += 'b'
+ cm = self._open(mode, encoding=None, errors=None)
+ return cm
+
+ # TODO(b/123775699): Once pytype supports typing.Literal, use overload and
+ # Literal to express more precise return types. The contained type is
+ # currently `Any` to avoid [bad-return-type] errors in the open_* methods.
+ @contextlib.contextmanager
+ def _open(self, mode, encoding='utf8', errors='strict'):
+ # type: (Text, Text, Text) -> Iterator[Any]
+ with io.open(
+ self.full_path, mode=mode, encoding=encoding, errors=errors) as fp:
+ yield fp
+
+
+class _method(object):
+ """A decorator that supports both instance and classmethod invocations.
+
+ Using similar semantics to the @property builtin, this decorator can augment
+ an instance method to support conditional logic when invoked on a class
+ object. This breaks support for invoking an instance method via the class
+ (e.g. Cls.method(self, ...)) but is still situationally useful.
+ """
+
+ def __init__(self, finstancemethod):
+ # type: (Callable[..., Any]) -> None
+ self._finstancemethod = finstancemethod
+ self._fclassmethod = None
+
+ def classmethod(self, fclassmethod):
+ # type: (Callable[..., Any]) -> _method
+ self._fclassmethod = classmethod(fclassmethod)
+ return self
+
+ def __doc__(self):
+ # type: () -> str
+ if getattr(self._finstancemethod, '__doc__'):
+ return self._finstancemethod.__doc__
+ elif getattr(self._fclassmethod, '__doc__'):
+ return self._fclassmethod.__doc__
+ return ''
+
+ def __get__(self, obj, type_):
+ # type: (Optional[Any], Optional[Type[Any]]) -> Callable[..., Any]
+ func = self._fclassmethod if obj is None else self._finstancemethod
+ return func.__get__(obj, type_) # pytype: disable=attribute-error
+
+
+class TestCase(unittest.TestCase):
+ """Extension of unittest.TestCase providing more power."""
+
+ # When to cleanup files/directories created by our `create_tempfile()` and
+ # `create_tempdir()` methods after each test case completes. This does *not*
+ # affect e.g., files created outside of those methods, e.g., using the stdlib
+ # tempfile module. This can be overridden at the class level, instance level,
+ # or with the `cleanup` arg of `create_tempfile()` and `create_tempdir()`. See
+ # `TempFileCleanup` for details on the different values.
+ # TODO(b/70517332): Remove the type comment and the disable once pytype has
+ # better support for enums.
+ tempfile_cleanup = TempFileCleanup.ALWAYS # type: TempFileCleanup # pytype: disable=annotation-type-mismatch
+
+ maxDiff = 80 * 20
+ longMessage = True
+
+ # Exit stacks for per-test and per-class scopes.
+ _exit_stack = None
+ _cls_exit_stack = None
+
+ def __init__(self, *args, **kwargs):
+ super(TestCase, self).__init__(*args, **kwargs)
+ # This is to work around missing type stubs in unittest.pyi
+ self._outcome = getattr(self, '_outcome') # type: Optional[_OutcomeType]
+
+ def setUp(self):
+ super(TestCase, self).setUp()
+ # NOTE: Only Python 3 contextlib has ExitStack
+ if hasattr(contextlib, 'ExitStack'):
+ self._exit_stack = contextlib.ExitStack()
+ self.addCleanup(self._exit_stack.close)
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestCase, cls).setUpClass()
+ # NOTE: Only Python 3 contextlib has ExitStack and only Python 3.8+ has
+ # addClassCleanup.
+ if hasattr(contextlib, 'ExitStack') and hasattr(cls, 'addClassCleanup'):
+ cls._cls_exit_stack = contextlib.ExitStack()
+ cls.addClassCleanup(cls._cls_exit_stack.close)
+
+ def create_tempdir(self, name=None, cleanup=None):
+ # type: (Optional[Text], Optional[TempFileCleanup]) -> _TempDir
+ """Create a temporary directory specific to the test.
+
+ NOTE: The directory and its contents will be recursively cleared before
+ creation. This ensures that there is no pre-existing state.
+
+ This creates a named directory on disk that is isolated to this test, and
+ will be properly cleaned up by the test. This avoids several pitfalls of
+ creating temporary directories for test purposes, as well as makes it easier
+ to setup directories and verify their contents. For example:
+
+ def test_foo(self):
+ out_dir = self.create_tempdir()
+ out_log = out_dir.create_file('output.log')
+ expected_outputs = [
+ os.path.join(out_dir, 'data-0.txt'),
+ os.path.join(out_dir, 'data-1.txt'),
+ ]
+ code_under_test(out_dir)
+ self.assertTrue(os.path.exists(expected_paths[0]))
+ self.assertTrue(os.path.exists(expected_paths[1]))
+ self.assertEqual('foo', out_log.read_text())
+
+ See also: `create_tempfile()` for creating temporary files.
+
+ Args:
+ name: Optional name of the directory. If not given, a unique
+ name will be generated and used.
+ cleanup: Optional cleanup policy on when/if to remove the directory (and
+ all its contents) at the end of the test. If None, then uses
+ `self.tempfile_cleanup`.
+
+ Returns:
+ A _TempDir representing the created directory; see _TempDir class docs
+ for usage.
+ """
+ test_path = self._get_tempdir_path_test()
+
+ if name:
+ path = os.path.join(test_path, name)
+ cleanup_path = os.path.join(test_path, _get_first_part(name))
+ else:
+ os.makedirs(test_path, exist_ok=True)
+ path = tempfile.mkdtemp(dir=test_path)
+ cleanup_path = path
+
+ _rmtree_ignore_errors(cleanup_path)
+ os.makedirs(path, exist_ok=True)
+
+ self._maybe_add_temp_path_cleanup(cleanup_path, cleanup)
+
+ return _TempDir(path)
+
+ # pylint: disable=line-too-long
+ def create_tempfile(self, file_path=None, content=None, mode='w',
+ encoding='utf8', errors='strict', cleanup=None):
+ # type: (Optional[Text], Optional[AnyStr], Text, Text, Text, Optional[TempFileCleanup]) -> _TempFile
+ # pylint: enable=line-too-long
+ """Create a temporary file specific to the test.
+
+ This creates a named file on disk that is isolated to this test, and will
+ be properly cleaned up by the test. This avoids several pitfalls of
+ creating temporary files for test purposes, as well as makes it easier
+ to setup files, their data, read them back, and inspect them when
+ a test fails. For example:
+
+ def test_foo(self):
+ output = self.create_tempfile()
+ code_under_test(output)
+ self.assertGreater(os.path.getsize(output), 0)
+ self.assertEqual('foo', output.read_text())
+
+ NOTE: This will zero-out the file. This ensures there is no pre-existing
+ state.
+ NOTE: If the file already exists, it will be made writable and overwritten.
+
+ See also: `create_tempdir()` for creating temporary directories, and
+ `_TempDir.create_file` for creating files within a temporary directory.
+
+ Args:
+ file_path: Optional file path for the temp file. If not given, a unique
+ file name will be generated and used. Slashes are allowed in the name;
+ any missing intermediate directories will be created. NOTE: This path is
+ the path that will be cleaned up, including any directories in the path,
+ e.g., 'foo/bar/baz.txt' will `rm -r foo`.
+ content: Optional string or
+ bytes to initially write to the file. If not
+ specified, then an empty file is created.
+ mode: Mode string to use when writing content. Only used if `content` is
+ non-empty.
+ encoding: Encoding to use when writing string content. Only used if
+ `content` is text.
+ errors: How to handle text to bytes encoding errors. Only used if
+ `content` is text.
+ cleanup: Optional cleanup policy on when/if to remove the directory (and
+ all its contents) at the end of the test. If None, then uses
+ `self.tempfile_cleanup`.
+
+ Returns:
+ A _TempFile representing the created file; see _TempFile class docs for
+ usage.
+ """
+ test_path = self._get_tempdir_path_test()
+ tf, cleanup_path = _TempFile._create(test_path, file_path, content=content,
+ mode=mode, encoding=encoding,
+ errors=errors)
+ self._maybe_add_temp_path_cleanup(cleanup_path, cleanup)
+ return tf
+
+ @_method
+ def enter_context(self, manager):
+ # type: (ContextManager[_T]) -> _T
+ """Returns the CM's value after registering it with the exit stack.
+
+ Entering a context pushes it onto a stack of contexts. When `enter_context`
+ is called on the test instance (e.g. `self.enter_context`), the context is
+ exited after the test case's tearDown call. When called on the test class
+ (e.g. `TestCase.enter_context`), the context is exited after the test
+ class's tearDownClass call.
+
+ Contexts are are exited in the reverse order of entering. They will always
+ be exited, regardless of test failure/success.
+
+ This is useful to eliminate per-test boilerplate when context managers
+ are used. For example, instead of decorating every test with `@mock.patch`,
+ simply do `self.foo = self.enter_context(mock.patch(...))' in `setUp()`.
+
+ NOTE: The context managers will always be exited without any error
+ information. This is an unfortunate implementation detail due to some
+ internals of how unittest runs tests.
+
+ Args:
+ manager: The context manager to enter.
+ """
+ if not self._exit_stack:
+ raise AssertionError(
+ 'self._exit_stack is not set: enter_context is Py3-only; also make '
+ 'sure that AbslTest.setUp() is called.')
+ return self._exit_stack.enter_context(manager)
+
+ @enter_context.classmethod
+ def enter_context(cls, manager): # pylint: disable=no-self-argument
+ # type: (ContextManager[_T]) -> _T
+ if not cls._cls_exit_stack:
+ raise AssertionError(
+ 'cls._cls_exit_stack is not set: cls.enter_context requires '
+ 'Python 3.8+; also make sure that AbslTest.setUpClass() is called.')
+ return cls._cls_exit_stack.enter_context(manager)
+
+ @classmethod
+ def _get_tempdir_path_cls(cls):
+ # type: () -> Text
+ return os.path.join(TEST_TMPDIR.value,
+ cls.__qualname__.replace('__main__.', ''))
+
+ def _get_tempdir_path_test(self):
+ # type: () -> Text
+ return os.path.join(self._get_tempdir_path_cls(), self._testMethodName)
+
+ def _get_tempfile_cleanup(self, override):
+ # type: (Optional[TempFileCleanup]) -> TempFileCleanup
+ if override is not None:
+ return override
+ return self.tempfile_cleanup
+
+ def _maybe_add_temp_path_cleanup(self, path, cleanup):
+ # type: (Text, Optional[TempFileCleanup]) -> None
+ cleanup = self._get_tempfile_cleanup(cleanup)
+ if cleanup == TempFileCleanup.OFF:
+ return
+ elif cleanup == TempFileCleanup.ALWAYS:
+ self.addCleanup(_rmtree_ignore_errors, path)
+ elif cleanup == TempFileCleanup.SUCCESS:
+ self._internal_cleanup_on_success(_rmtree_ignore_errors, path)
+ else:
+ raise AssertionError('Unexpected cleanup value: {}'.format(cleanup))
+
+ def _internal_cleanup_on_success(self, function, *args, **kwargs):
+ # type: (Callable[..., object], Any, Any) -> None
+ def _call_cleaner_on_success(*args, **kwargs):
+ if not self._ran_and_passed():
+ return
+ function(*args, **kwargs)
+ self.addCleanup(_call_cleaner_on_success, *args, **kwargs)
+
+ def _ran_and_passed(self):
+ # type: () -> bool
+ outcome = self._outcome
+ result = self.defaultTestResult()
+ self._feedErrorsToResult(result, outcome.errors) # pytype: disable=attribute-error
+ return result.wasSuccessful()
+
+ def shortDescription(self):
+ # type: () -> Text
+ """Formats both the test method name and the first line of its docstring.
+
+ If no docstring is given, only returns the method name.
+
+ This method overrides unittest.TestCase.shortDescription(), which
+ only returns the first line of the docstring, obscuring the name
+ of the test upon failure.
+
+ Returns:
+ desc: A short description of a test method.
+ """
+ desc = self.id()
+
+ # Omit the main name so that test name can be directly copy/pasted to
+ # the command line.
+ if desc.startswith('__main__.'):
+ desc = desc[len('__main__.'):]
+
+ # NOTE: super() is used here instead of directly invoking
+ # unittest.TestCase.shortDescription(self), because of the
+ # following line that occurs later on:
+ # unittest.TestCase = TestCase
+ # Because of this, direct invocation of what we think is the
+ # superclass will actually cause infinite recursion.
+ doc_first_line = super(TestCase, self).shortDescription()
+ if doc_first_line is not None:
+ desc = '\n'.join((desc, doc_first_line))
+ return desc
+
+ def assertStartsWith(self, actual, expected_start, msg=None):
+ """Asserts that actual.startswith(expected_start) is True.
+
+ Args:
+ actual: str
+ expected_start: str
+ msg: Optional message to report on failure.
+ """
+ if not actual.startswith(expected_start):
+ self.fail('%r does not start with %r' % (actual, expected_start), msg)
+
+ def assertNotStartsWith(self, actual, unexpected_start, msg=None):
+ """Asserts that actual.startswith(unexpected_start) is False.
+
+ Args:
+ actual: str
+ unexpected_start: str
+ msg: Optional message to report on failure.
+ """
+ if actual.startswith(unexpected_start):
+ self.fail('%r does start with %r' % (actual, unexpected_start), msg)
+
+ def assertEndsWith(self, actual, expected_end, msg=None):
+ """Asserts that actual.endswith(expected_end) is True.
+
+ Args:
+ actual: str
+ expected_end: str
+ msg: Optional message to report on failure.
+ """
+ if not actual.endswith(expected_end):
+ self.fail('%r does not end with %r' % (actual, expected_end), msg)
+
+ def assertNotEndsWith(self, actual, unexpected_end, msg=None):
+ """Asserts that actual.endswith(unexpected_end) is False.
+
+ Args:
+ actual: str
+ unexpected_end: str
+ msg: Optional message to report on failure.
+ """
+ if actual.endswith(unexpected_end):
+ self.fail('%r does end with %r' % (actual, unexpected_end), msg)
+
+ def assertSequenceStartsWith(self, prefix, whole, msg=None):
+ """An equality assertion for the beginning of ordered sequences.
+
+ If prefix is an empty sequence, it will raise an error unless whole is also
+ an empty sequence.
+
+ If prefix is not a sequence, it will raise an error if the first element of
+ whole does not match.
+
+ Args:
+ prefix: A sequence expected at the beginning of the whole parameter.
+ whole: The sequence in which to look for prefix.
+ msg: Optional message to report on failure.
+ """
+ try:
+ prefix_len = len(prefix)
+ except (TypeError, NotImplementedError):
+ prefix = [prefix]
+ prefix_len = 1
+
+ try:
+ whole_len = len(whole)
+ except (TypeError, NotImplementedError):
+ self.fail('For whole: len(%s) is not supported, it appears to be type: '
+ '%s' % (whole, type(whole)), msg)
+
+ assert prefix_len <= whole_len, self._formatMessage(
+ msg,
+ 'Prefix length (%d) is longer than whole length (%d).' %
+ (prefix_len, whole_len)
+ )
+
+ if not prefix_len and whole_len:
+ self.fail('Prefix length is 0 but whole length is %d: %s' %
+ (len(whole), whole), msg)
+
+ try:
+ self.assertSequenceEqual(prefix, whole[:prefix_len], msg)
+ except AssertionError:
+ self.fail('prefix: %s not found at start of whole: %s.' %
+ (prefix, whole), msg)
+
+ def assertEmpty(self, container, msg=None):
+ """Asserts that an object has zero length.
+
+ Args:
+ container: Anything that implements the collections.abc.Sized interface.
+ msg: Optional message to report on failure.
+ """
+ if not isinstance(container, abc.Sized):
+ self.fail('Expected a Sized object, got: '
+ '{!r}'.format(type(container).__name__), msg)
+
+ # explicitly check the length since some Sized objects (e.g. numpy.ndarray)
+ # have strange __nonzero__/__bool__ behavior.
+ if len(container): # pylint: disable=g-explicit-length-test
+ self.fail('{!r} has length of {}.'.format(container, len(container)), msg)
+
+ def assertNotEmpty(self, container, msg=None):
+ """Asserts that an object has non-zero length.
+
+ Args:
+ container: Anything that implements the collections.abc.Sized interface.
+ msg: Optional message to report on failure.
+ """
+ if not isinstance(container, abc.Sized):
+ self.fail('Expected a Sized object, got: '
+ '{!r}'.format(type(container).__name__), msg)
+
+ # explicitly check the length since some Sized objects (e.g. numpy.ndarray)
+ # have strange __nonzero__/__bool__ behavior.
+ if not len(container): # pylint: disable=g-explicit-length-test
+ self.fail('{!r} has length of 0.'.format(container), msg)
+
+ def assertLen(self, container, expected_len, msg=None):
+ """Asserts that an object has the expected length.
+
+ Args:
+ container: Anything that implements the collections.abc.Sized interface.
+ expected_len: The expected length of the container.
+ msg: Optional message to report on failure.
+ """
+ if not isinstance(container, abc.Sized):
+ self.fail('Expected a Sized object, got: '
+ '{!r}'.format(type(container).__name__), msg)
+ if len(container) != expected_len:
+ container_repr = unittest.util.safe_repr(container) # pytype: disable=module-attr
+ self.fail('{} has length of {}, expected {}.'.format(
+ container_repr, len(container), expected_len), msg)
+
+ def assertSequenceAlmostEqual(self, expected_seq, actual_seq, places=None,
+ msg=None, delta=None):
+ """An approximate equality assertion for ordered sequences.
+
+ Fail if the two sequences are unequal as determined by their value
+ differences rounded to the given number of decimal places (default 7) and
+ comparing to zero, or by comparing that the difference between each value
+ in the two sequences is more than the given delta.
+
+ Note that decimal places (from zero) are usually not the same as significant
+ digits (measured from the most significant digit).
+
+ If the two sequences compare equal then they will automatically compare
+ almost equal.
+
+ Args:
+ expected_seq: A sequence containing elements we are expecting.
+ actual_seq: The sequence that we are testing.
+ places: The number of decimal places to compare.
+ msg: The message to be printed if the test fails.
+ delta: The OK difference between compared values.
+ """
+ if len(expected_seq) != len(actual_seq):
+ self.fail('Sequence size mismatch: {} vs {}'.format(
+ len(expected_seq), len(actual_seq)), msg)
+
+ err_list = []
+ for idx, (exp_elem, act_elem) in enumerate(zip(expected_seq, actual_seq)):
+ try:
+ # assertAlmostEqual should be called with at most one of `places` and
+ # `delta`. However, it's okay for assertSequenceAlmostEqual to pass
+ # both because we want the latter to fail if the former does.
+ # pytype: disable=wrong-keyword-args
+ self.assertAlmostEqual(exp_elem, act_elem, places=places, msg=msg,
+ delta=delta)
+ # pytype: enable=wrong-keyword-args
+ except self.failureException as err:
+ err_list.append('At index {}: {}'.format(idx, err))
+
+ if err_list:
+ if len(err_list) > 30:
+ err_list = err_list[:30] + ['...']
+ msg = self._formatMessage(msg, '\n'.join(err_list))
+ self.fail(msg)
+
+ def assertContainsSubset(self, expected_subset, actual_set, msg=None):
+ """Checks whether actual iterable is a superset of expected iterable."""
+ missing = set(expected_subset) - set(actual_set)
+ if not missing:
+ return
+
+ self.fail('Missing elements %s\nExpected: %s\nActual: %s' % (
+ missing, expected_subset, actual_set), msg)
+
+ def assertNoCommonElements(self, expected_seq, actual_seq, msg=None):
+ """Checks whether actual iterable and expected iterable are disjoint."""
+ common = set(expected_seq) & set(actual_seq)
+ if not common:
+ return
+
+ self.fail('Common elements %s\nExpected: %s\nActual: %s' % (
+ common, expected_seq, actual_seq), msg)
+
+ def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
+ """Deprecated, please use assertCountEqual instead.
+
+ This is equivalent to assertCountEqual.
+
+ Args:
+ expected_seq: A sequence containing elements we are expecting.
+ actual_seq: The sequence that we are testing.
+ msg: The message to be printed if the test fails.
+ """
+ super().assertCountEqual(expected_seq, actual_seq, msg)
+
+ def assertSameElements(self, expected_seq, actual_seq, msg=None):
+ """Asserts that two sequences have the same elements (in any order).
+
+ This method, unlike assertCountEqual, doesn't care about any
+ duplicates in the expected and actual sequences.
+
+ >> assertSameElements([1, 1, 1, 0, 0, 0], [0, 1])
+ # Doesn't raise an AssertionError
+
+ If possible, you should use assertCountEqual instead of
+ assertSameElements.
+
+ Args:
+ expected_seq: A sequence containing elements we are expecting.
+ actual_seq: The sequence that we are testing.
+ msg: The message to be printed if the test fails.
+ """
+ # `unittest2.TestCase` used to have assertSameElements, but it was
+ # removed in favor of assertItemsEqual. As there's a unit test
+ # that explicitly checks this behavior, I am leaving this method
+ # alone.
+ # Fail on strings: empirically, passing strings to this test method
+ # is almost always a bug. If comparing the character sets of two strings
+ # is desired, cast the inputs to sets or lists explicitly.
+ if (isinstance(expected_seq, _TEXT_OR_BINARY_TYPES) or
+ isinstance(actual_seq, _TEXT_OR_BINARY_TYPES)):
+ self.fail('Passing string/bytes to assertSameElements is usually a bug. '
+ 'Did you mean to use assertEqual?\n'
+ 'Expected: %s\nActual: %s' % (expected_seq, actual_seq))
+ try:
+ expected = dict([(element, None) for element in expected_seq])
+ actual = dict([(element, None) for element in actual_seq])
+ missing = [element for element in expected if element not in actual]
+ unexpected = [element for element in actual if element not in expected]
+ missing.sort()
+ unexpected.sort()
+ except TypeError:
+ # Fall back to slower list-compare if any of the objects are
+ # not hashable.
+ expected = list(expected_seq)
+ actual = list(actual_seq)
+ expected.sort()
+ actual.sort()
+ missing, unexpected = _sorted_list_difference(expected, actual)
+ errors = []
+ if msg:
+ errors.extend((msg, ':\n'))
+ if missing:
+ errors.append('Expected, but missing:\n %r\n' % missing)
+ if unexpected:
+ errors.append('Unexpected, but present:\n %r\n' % unexpected)
+ if missing or unexpected:
+ self.fail(''.join(errors))
+
+ # unittest.TestCase.assertMultiLineEqual works very similarly, but it
+ # has a different error format. However, I find this slightly more readable.
+ def assertMultiLineEqual(self, first, second, msg=None, **kwargs):
+ """Asserts that two multi-line strings are equal."""
+ assert isinstance(first,
+ str), ('First argument is not a string: %r' % (first,))
+ assert isinstance(second,
+ str), ('Second argument is not a string: %r' % (second,))
+ line_limit = kwargs.pop('line_limit', 0)
+ if kwargs:
+ raise TypeError('Unexpected keyword args {}'.format(tuple(kwargs)))
+
+ if first == second:
+ return
+ if msg:
+ failure_message = [msg + ':\n']
+ else:
+ failure_message = ['\n']
+ if line_limit:
+ line_limit += len(failure_message)
+ for line in difflib.ndiff(first.splitlines(True), second.splitlines(True)):
+ failure_message.append(line)
+ if not line.endswith('\n'):
+ failure_message.append('\n')
+ if line_limit and len(failure_message) > line_limit:
+ n_omitted = len(failure_message) - line_limit
+ failure_message = failure_message[:line_limit]
+ failure_message.append(
+ '(... and {} more delta lines omitted for brevity.)\n'.format(
+ n_omitted))
+
+ raise self.failureException(''.join(failure_message))
+
+ def assertBetween(self, value, minv, maxv, msg=None):
+ """Asserts that value is between minv and maxv (inclusive)."""
+ msg = self._formatMessage(msg,
+ '"%r" unexpectedly not between "%r" and "%r"' %
+ (value, minv, maxv))
+ self.assertTrue(minv <= value, msg)
+ self.assertTrue(maxv >= value, msg)
+
+ def assertRegexMatch(self, actual_str, regexes, message=None):
+ r"""Asserts that at least one regex in regexes matches str.
+
+ If possible you should use `assertRegex`, which is a simpler
+ version of this method. `assertRegex` takes a single regular
+ expression (a string or re compiled object) instead of a list.
+
+ Notes:
+ 1. This function uses substring matching, i.e. the matching
+ succeeds if *any* substring of the error message matches *any*
+ regex in the list. This is more convenient for the user than
+ full-string matching.
+
+ 2. If regexes is the empty list, the matching will always fail.
+
+ 3. Use regexes=[''] for a regex that will always pass.
+
+ 4. '.' matches any single character *except* the newline. To
+ match any character, use '(.|\n)'.
+
+ 5. '^' matches the beginning of each line, not just the beginning
+ of the string. Similarly, '$' matches the end of each line.
+
+ 6. An exception will be thrown if regexes contains an invalid
+ regex.
+
+ Args:
+ actual_str: The string we try to match with the items in regexes.
+ regexes: The regular expressions we want to match against str.
+ See "Notes" above for detailed notes on how this is interpreted.
+ message: The message to be printed if the test fails.
+ """
+ if isinstance(regexes, _TEXT_OR_BINARY_TYPES):
+ self.fail('regexes is string or bytes; use assertRegex instead.',
+ message)
+ if not regexes:
+ self.fail('No regexes specified.', message)
+
+ regex_type = type(regexes[0])
+ for regex in regexes[1:]:
+ if type(regex) is not regex_type: # pylint: disable=unidiomatic-typecheck
+ self.fail('regexes list must all be the same type.', message)
+
+ if regex_type is bytes and isinstance(actual_str, str):
+ regexes = [regex.decode('utf-8') for regex in regexes]
+ regex_type = str
+ elif regex_type is str and isinstance(actual_str, bytes):
+ regexes = [regex.encode('utf-8') for regex in regexes]
+ regex_type = bytes
+
+ if regex_type is str:
+ regex = u'(?:%s)' % u')|(?:'.join(regexes)
+ elif regex_type is bytes:
+ regex = b'(?:' + (b')|(?:'.join(regexes)) + b')'
+ else:
+ self.fail('Only know how to deal with unicode str or bytes regexes.',
+ message)
+
+ if not re.search(regex, actual_str, re.MULTILINE):
+ self.fail('"%s" does not contain any of these regexes: %s.' %
+ (actual_str, regexes), message)
+
+ def assertCommandSucceeds(self, command, regexes=(b'',), env=None,
+ close_fds=True, msg=None):
+ """Asserts that a shell command succeeds (i.e. exits with code 0).
+
+ Args:
+ command: List or string representing the command to run.
+ regexes: List of regular expression byte strings that match success.
+ env: Dictionary of environment variable settings. If None, no environment
+ variables will be set for the child process. This is to make tests
+ more hermetic. NOTE: this behavior is different than the standard
+ subprocess module.
+ close_fds: Whether or not to close all open fd's in the child after
+ forking.
+ msg: Optional message to report on failure.
+ """
+ (ret_code, err) = get_command_stderr(command, env, close_fds)
+
+ # We need bytes regexes here because `err` is bytes.
+ # Accommodate code which listed their output regexes w/o the b'' prefix by
+ # converting them to bytes for the user.
+ if isinstance(regexes[0], str):
+ regexes = [regex.encode('utf-8') for regex in regexes]
+
+ command_string = get_command_string(command)
+ self.assertEqual(
+ ret_code, 0,
+ self._formatMessage(msg,
+ 'Running command\n'
+ '%s failed with error code %s and message\n'
+ '%s' % (_quote_long_string(command_string),
+ ret_code,
+ _quote_long_string(err)))
+ )
+ self.assertRegexMatch(
+ err,
+ regexes,
+ message=self._formatMessage(
+ msg,
+ 'Running command\n'
+ '%s failed with error code %s and message\n'
+ '%s which matches no regex in %s' % (
+ _quote_long_string(command_string),
+ ret_code,
+ _quote_long_string(err),
+ regexes)))
+
+ def assertCommandFails(self, command, regexes, env=None, close_fds=True,
+ msg=None):
+ """Asserts a shell command fails and the error matches a regex in a list.
+
+ Args:
+ command: List or string representing the command to run.
+ regexes: the list of regular expression strings.
+ env: Dictionary of environment variable settings. If None, no environment
+ variables will be set for the child process. This is to make tests
+ more hermetic. NOTE: this behavior is different than the standard
+ subprocess module.
+ close_fds: Whether or not to close all open fd's in the child after
+ forking.
+ msg: Optional message to report on failure.
+ """
+ (ret_code, err) = get_command_stderr(command, env, close_fds)
+
+ # We need bytes regexes here because `err` is bytes.
+ # Accommodate code which listed their output regexes w/o the b'' prefix by
+ # converting them to bytes for the user.
+ if isinstance(regexes[0], str):
+ regexes = [regex.encode('utf-8') for regex in regexes]
+
+ command_string = get_command_string(command)
+ self.assertNotEqual(
+ ret_code, 0,
+ self._formatMessage(msg, 'The following command succeeded '
+ 'while expected to fail:\n%s' %
+ _quote_long_string(command_string)))
+ self.assertRegexMatch(
+ err,
+ regexes,
+ message=self._formatMessage(
+ msg,
+ 'Running command\n'
+ '%s failed with error code %s and message\n'
+ '%s which matches no regex in %s' % (
+ _quote_long_string(command_string),
+ ret_code,
+ _quote_long_string(err),
+ regexes)))
+
+ class _AssertRaisesContext(object):
+
+ def __init__(self, expected_exception, test_case, test_func, msg=None):
+ self.expected_exception = expected_exception
+ self.test_case = test_case
+ self.test_func = test_func
+ self.msg = msg
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ if exc_type is None:
+ self.test_case.fail(self.expected_exception.__name__ + ' not raised',
+ self.msg)
+ if not issubclass(exc_type, self.expected_exception):
+ return False
+ self.test_func(exc_value)
+ return True
+
+ @typing.overload
+ def assertRaisesWithPredicateMatch(
+ self, expected_exception, predicate) -> _AssertRaisesContext:
+ # The purpose of this return statement is to work around
+ # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
+ return self._AssertRaisesContext(None, None, None)
+
+ @typing.overload
+ def assertRaisesWithPredicateMatch(
+ self, expected_exception, predicate, callable_obj: Callable[..., Any],
+ *args, **kwargs) -> None:
+ # The purpose of this return statement is to work around
+ # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
+ return self._AssertRaisesContext(None, None, None)
+
+ def assertRaisesWithPredicateMatch(self, expected_exception, predicate,
+ callable_obj=None, *args, **kwargs):
+ """Asserts that exception is thrown and predicate(exception) is true.
+
+ Args:
+ expected_exception: Exception class expected to be raised.
+ predicate: Function of one argument that inspects the passed-in exception
+ and returns True (success) or False (please fail the test).
+ callable_obj: Function to be called.
+ *args: Extra args.
+ **kwargs: Extra keyword args.
+
+ Returns:
+ A context manager if callable_obj is None. Otherwise, None.
+
+ Raises:
+ self.failureException if callable_obj does not raise a matching exception.
+ """
+ def Check(err):
+ self.assertTrue(predicate(err),
+ '%r does not match predicate %r' % (err, predicate))
+
+ context = self._AssertRaisesContext(expected_exception, self, Check)
+ if callable_obj is None:
+ return context
+ with context:
+ callable_obj(*args, **kwargs)
+
+ @typing.overload
+ def assertRaisesWithLiteralMatch(
+ self, expected_exception, expected_exception_message
+ ) -> _AssertRaisesContext:
+ # The purpose of this return statement is to work around
+ # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
+ return self._AssertRaisesContext(None, None, None)
+
+ @typing.overload
+ def assertRaisesWithLiteralMatch(
+ self, expected_exception, expected_exception_message,
+ callable_obj: Callable[..., Any], *args, **kwargs) -> None:
+ # The purpose of this return statement is to work around
+ # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
+ return self._AssertRaisesContext(None, None, None)
+
+ def assertRaisesWithLiteralMatch(self, expected_exception,
+ expected_exception_message,
+ callable_obj=None, *args, **kwargs):
+ """Asserts that the message in a raised exception equals the given string.
+
+ Unlike assertRaisesRegex, this method takes a literal string, not
+ a regular expression.
+
+ with self.assertRaisesWithLiteralMatch(ExType, 'message'):
+ DoSomething()
+
+ Args:
+ expected_exception: Exception class expected to be raised.
+ expected_exception_message: String message expected in the raised
+ exception. For a raise exception e, expected_exception_message must
+ equal str(e).
+ callable_obj: Function to be called, or None to return a context.
+ *args: Extra args.
+ **kwargs: Extra kwargs.
+
+ Returns:
+ A context manager if callable_obj is None. Otherwise, None.
+
+ Raises:
+ self.failureException if callable_obj does not raise a matching exception.
+ """
+ def Check(err):
+ actual_exception_message = str(err)
+ self.assertTrue(expected_exception_message == actual_exception_message,
+ 'Exception message does not match.\n'
+ 'Expected: %r\n'
+ 'Actual: %r' % (expected_exception_message,
+ actual_exception_message))
+
+ context = self._AssertRaisesContext(expected_exception, self, Check)
+ if callable_obj is None:
+ return context
+ with context:
+ callable_obj(*args, **kwargs)
+
+ def assertContainsInOrder(self, strings, target, msg=None):
+ """Asserts that the strings provided are found in the target in order.
+
+ This may be useful for checking HTML output.
+
+ Args:
+ strings: A list of strings, such as [ 'fox', 'dog' ]
+ target: A target string in which to look for the strings, such as
+ 'The quick brown fox jumped over the lazy dog'.
+ msg: Optional message to report on failure.
+ """
+ if isinstance(strings, (bytes, unicode if str is bytes else str)):
+ strings = (strings,)
+
+ current_index = 0
+ last_string = None
+ for string in strings:
+ index = target.find(str(string), current_index)
+ if index == -1 and current_index == 0:
+ self.fail("Did not find '%s' in '%s'" %
+ (string, target), msg)
+ elif index == -1:
+ self.fail("Did not find '%s' after '%s' in '%s'" %
+ (string, last_string, target), msg)
+ last_string = string
+ current_index = index
+
+ def assertContainsSubsequence(self, container, subsequence, msg=None):
+ """Asserts that "container" contains "subsequence" as a subsequence.
+
+ Asserts that "container" contains all the elements of "subsequence", in
+ order, but possibly with other elements interspersed. For example, [1, 2, 3]
+ is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0].
+
+ Args:
+ container: the list we're testing for subsequence inclusion.
+ subsequence: the list we hope will be a subsequence of container.
+ msg: Optional message to report on failure.
+ """
+ first_nonmatching = None
+ reversed_container = list(reversed(container))
+ subsequence = list(subsequence)
+
+ for e in subsequence:
+ if e not in reversed_container:
+ first_nonmatching = e
+ break
+ while e != reversed_container.pop():
+ pass
+
+ if first_nonmatching is not None:
+ self.fail('%s not a subsequence of %s. First non-matching element: %s' %
+ (subsequence, container, first_nonmatching), msg)
+
+ def assertContainsExactSubsequence(self, container, subsequence, msg=None):
+ """Asserts that "container" contains "subsequence" as an exact subsequence.
+
+ Asserts that "container" contains all the elements of "subsequence", in
+ order, and without other elements interspersed. For example, [1, 2, 3] is an
+ exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0].
+
+ Args:
+ container: the list we're testing for subsequence inclusion.
+ subsequence: the list we hope will be an exact subsequence of container.
+ msg: Optional message to report on failure.
+ """
+ container = list(container)
+ subsequence = list(subsequence)
+ longest_match = 0
+
+ for start in range(1 + len(container) - len(subsequence)):
+ if longest_match == len(subsequence):
+ break
+ index = 0
+ while (index < len(subsequence) and
+ subsequence[index] == container[start + index]):
+ index += 1
+ longest_match = max(longest_match, index)
+
+ if longest_match < len(subsequence):
+ self.fail('%s not an exact subsequence of %s. '
+ 'Longest matching prefix: %s' %
+ (subsequence, container, subsequence[:longest_match]), msg)
+
+ def assertTotallyOrdered(self, *groups, **kwargs):
+ """Asserts that total ordering has been implemented correctly.
+
+ For example, say you have a class A that compares only on its attribute x.
+ Comparators other than __lt__ are omitted for brevity.
+
+ class A(object):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __hash__(self):
+ return hash(self.x)
+
+ def __lt__(self, other):
+ try:
+ return self.x < other.x
+ except AttributeError:
+ return NotImplemented
+
+ assertTotallyOrdered will check that instances can be ordered correctly.
+ For example,
+
+ self.assertTotallyOrdered(
+ [None], # None should come before everything else.
+ [1], # Integers sort earlier.
+ [A(1, 'a')],
+ [A(2, 'b')], # 2 is after 1.
+ [A(3, 'c'), A(3, 'd')], # The second argument is irrelevant.
+ [A(4, 'z')],
+ ['foo']) # Strings sort last.
+
+ Args:
+ *groups: A list of groups of elements. Each group of elements is a list
+ of objects that are equal. The elements in each group must be less
+ than the elements in the group after it. For example, these groups are
+ totally ordered: [None], [1], [2, 2], [3].
+ **kwargs: optional msg keyword argument can be passed.
+ """
+
+ def CheckOrder(small, big):
+ """Ensures small is ordered before big."""
+ self.assertFalse(small == big,
+ self._formatMessage(msg, '%r unexpectedly equals %r' %
+ (small, big)))
+ self.assertTrue(small != big,
+ self._formatMessage(msg, '%r unexpectedly equals %r' %
+ (small, big)))
+ self.assertLess(small, big, msg)
+ self.assertFalse(big < small,
+ self._formatMessage(msg,
+ '%r unexpectedly less than %r' %
+ (big, small)))
+ self.assertLessEqual(small, big, msg)
+ self.assertFalse(big <= small, self._formatMessage(
+ '%r unexpectedly less than or equal to %r' % (big, small), msg
+ ))
+ self.assertGreater(big, small, msg)
+ self.assertFalse(small > big,
+ self._formatMessage(msg,
+ '%r unexpectedly greater than %r' %
+ (small, big)))
+ self.assertGreaterEqual(big, small)
+ self.assertFalse(small >= big, self._formatMessage(
+ msg,
+ '%r unexpectedly greater than or equal to %r' % (small, big)))
+
+ def CheckEqual(a, b):
+ """Ensures that a and b are equal."""
+ self.assertEqual(a, b, msg)
+ self.assertFalse(a != b,
+ self._formatMessage(msg, '%r unexpectedly unequals %r' %
+ (a, b)))
+
+ # Objects that compare equal must hash to the same value, but this only
+ # applies if both objects are hashable.
+ if (isinstance(a, abc.Hashable) and
+ isinstance(b, abc.Hashable)):
+ self.assertEqual(
+ hash(a), hash(b),
+ self._formatMessage(
+ msg, 'hash %d of %r unexpectedly not equal to hash %d of %r' %
+ (hash(a), a, hash(b), b)))
+
+ self.assertFalse(a < b,
+ self._formatMessage(msg,
+ '%r unexpectedly less than %r' %
+ (a, b)))
+ self.assertFalse(b < a,
+ self._formatMessage(msg,
+ '%r unexpectedly less than %r' %
+ (b, a)))
+ self.assertLessEqual(a, b, msg)
+ self.assertLessEqual(b, a, msg) # pylint: disable=arguments-out-of-order
+ self.assertFalse(a > b,
+ self._formatMessage(msg,
+ '%r unexpectedly greater than %r' %
+ (a, b)))
+ self.assertFalse(b > a,
+ self._formatMessage(msg,
+ '%r unexpectedly greater than %r' %
+ (b, a)))
+ self.assertGreaterEqual(a, b, msg)
+ self.assertGreaterEqual(b, a, msg) # pylint: disable=arguments-out-of-order
+
+ msg = kwargs.get('msg')
+
+ # For every combination of elements, check the order of every pair of
+ # elements.
+ for elements in itertools.product(*groups):
+ elements = list(elements)
+ for index, small in enumerate(elements[:-1]):
+ for big in elements[index + 1:]:
+ CheckOrder(small, big)
+
+ # Check that every element in each group is equal.
+ for group in groups:
+ for a in group:
+ CheckEqual(a, a)
+ for a, b in itertools.product(group, group):
+ CheckEqual(a, b)
+
+ def assertDictEqual(self, a, b, msg=None):
+ """Raises AssertionError if a and b are not equal dictionaries.
+
+ Args:
+ a: A dict, the expected value.
+ b: A dict, the actual value.
+ msg: An optional str, the associated message.
+
+ Raises:
+ AssertionError: if the dictionaries are not equal.
+ """
+ self.assertIsInstance(a, dict, self._formatMessage(
+ msg,
+ 'First argument is not a dictionary'
+ ))
+ self.assertIsInstance(b, dict, self._formatMessage(
+ msg,
+ 'Second argument is not a dictionary'
+ ))
+
+ def Sorted(list_of_items):
+ try:
+ return sorted(list_of_items) # In 3.3, unordered are possible.
+ except TypeError:
+ return list_of_items
+
+ if a == b:
+ return
+ a_items = Sorted(list(a.items()))
+ b_items = Sorted(list(b.items()))
+
+ unexpected = []
+ missing = []
+ different = []
+
+ safe_repr = unittest.util.safe_repr # pytype: disable=module-attr
+
+ def Repr(dikt):
+ """Deterministic repr for dict."""
+ # Sort the entries based on their repr, not based on their sort order,
+ # which will be non-deterministic across executions, for many types.
+ entries = sorted((safe_repr(k), safe_repr(v)) for k, v in dikt.items())
+ return '{%s}' % (', '.join('%s: %s' % pair for pair in entries))
+
+ message = ['%s != %s%s' % (Repr(a), Repr(b), ' (%s)' % msg if msg else '')]
+
+ # The standard library default output confounds lexical difference with
+ # value difference; treat them separately.
+ for a_key, a_value in a_items:
+ if a_key not in b:
+ missing.append((a_key, a_value))
+ elif a_value != b[a_key]:
+ different.append((a_key, a_value, b[a_key]))
+
+ for b_key, b_value in b_items:
+ if b_key not in a:
+ unexpected.append((b_key, b_value))
+
+ if unexpected:
+ message.append(
+ 'Unexpected, but present entries:\n%s' % ''.join(
+ '%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in unexpected))
+
+ if different:
+ message.append(
+ 'repr() of differing entries:\n%s' % ''.join(
+ '%s: %s != %s\n' % (safe_repr(k), safe_repr(a_value),
+ safe_repr(b_value))
+ for k, a_value, b_value in different))
+
+ if missing:
+ message.append(
+ 'Missing entries:\n%s' % ''.join(
+ ('%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in missing)))
+
+ raise self.failureException('\n'.join(message))
+
+ def assertUrlEqual(self, a, b, msg=None):
+ """Asserts that urls are equal, ignoring ordering of query params."""
+ parsed_a = parse.urlparse(a)
+ parsed_b = parse.urlparse(b)
+ self.assertEqual(parsed_a.scheme, parsed_b.scheme, msg)
+ self.assertEqual(parsed_a.netloc, parsed_b.netloc, msg)
+ self.assertEqual(parsed_a.path, parsed_b.path, msg)
+ self.assertEqual(parsed_a.fragment, parsed_b.fragment, msg)
+ self.assertEqual(sorted(parsed_a.params.split(';')),
+ sorted(parsed_b.params.split(';')), msg)
+ self.assertDictEqual(
+ parse.parse_qs(parsed_a.query, keep_blank_values=True),
+ parse.parse_qs(parsed_b.query, keep_blank_values=True), msg)
+
+ def assertSameStructure(self, a, b, aname='a', bname='b', msg=None):
+ """Asserts that two values contain the same structural content.
+
+ The two arguments should be data trees consisting of trees of dicts and
+ lists. They will be deeply compared by walking into the contents of dicts
+ and lists; other items will be compared using the == operator.
+ If the two structures differ in content, the failure message will indicate
+ the location within the structures where the first difference is found.
+ This may be helpful when comparing large structures.
+
+ Mixed Sequence and Set types are supported. Mixed Mapping types are
+ supported, but the order of the keys will not be considered in the
+ comparison.
+
+ Args:
+ a: The first structure to compare.
+ b: The second structure to compare.
+ aname: Variable name to use for the first structure in assertion messages.
+ bname: Variable name to use for the second structure.
+ msg: Additional text to include in the failure message.
+ """
+
+ # Accumulate all the problems found so we can report all of them at once
+ # rather than just stopping at the first
+ problems = []
+
+ _walk_structure_for_problems(a, b, aname, bname, problems)
+
+ # Avoid spamming the user toooo much
+ if self.maxDiff is not None:
+ max_problems_to_show = self.maxDiff // 80
+ if len(problems) > max_problems_to_show:
+ problems = problems[0:max_problems_to_show-1] + ['...']
+
+ if problems:
+ self.fail('; '.join(problems), msg)
+
+ def assertJsonEqual(self, first, second, msg=None):
+ """Asserts that the JSON objects defined in two strings are equal.
+
+ A summary of the differences will be included in the failure message
+ using assertSameStructure.
+
+ Args:
+ first: A string containing JSON to decode and compare to second.
+ second: A string containing JSON to decode and compare to first.
+ msg: Additional text to include in the failure message.
+ """
+ try:
+ first_structured = json.loads(first)
+ except ValueError as e:
+ raise ValueError(self._formatMessage(
+ msg,
+ 'could not decode first JSON value %s: %s' % (first, e)))
+
+ try:
+ second_structured = json.loads(second)
+ except ValueError as e:
+ raise ValueError(self._formatMessage(
+ msg,
+ 'could not decode second JSON value %s: %s' % (second, e)))
+
+ self.assertSameStructure(first_structured, second_structured,
+ aname='first', bname='second', msg=msg)
+
+ def _getAssertEqualityFunc(self, first, second):
+ # type: (Any, Any) -> Callable[..., None]
+ try:
+ return super(TestCase, self)._getAssertEqualityFunc(first, second)
+ except AttributeError:
+ # This is a workaround if unittest.TestCase.__init__ was never run.
+ # It usually means that somebody created a subclass just for the
+ # assertions and has overridden __init__. "assertTrue" is a safe
+ # value that will not make __init__ raise a ValueError.
+ test_method = getattr(self, '_testMethodName', 'assertTrue')
+ super(TestCase, self).__init__(test_method)
+
+ return super(TestCase, self)._getAssertEqualityFunc(first, second)
+
+ def fail(self, msg=None, prefix=None):
+ """Fail immediately with the given message, optionally prefixed."""
+ return super(TestCase, self).fail(self._formatMessage(prefix, msg))
+
+
+def _sorted_list_difference(expected, actual):
+ # type: (List[_T], List[_T]) -> Tuple[List[_T], List[_T]]
+ """Finds elements in only one or the other of two, sorted input lists.
+
+ Returns a two-element tuple of lists. The first list contains those
+ elements in the "expected" list but not in the "actual" list, and the
+ second contains those elements in the "actual" list but not in the
+ "expected" list. Duplicate elements in either input list are ignored.
+
+ Args:
+ expected: The list we expected.
+ actual: The list we actually got.
+ Returns:
+ (missing, unexpected)
+ missing: items in expected that are not in actual.
+ unexpected: items in actual that are not in expected.
+ """
+ i = j = 0
+ missing = []
+ unexpected = []
+ while True:
+ try:
+ e = expected[i]
+ a = actual[j]
+ if e < a:
+ missing.append(e)
+ i += 1
+ while expected[i] == e:
+ i += 1
+ elif e > a:
+ unexpected.append(a)
+ j += 1
+ while actual[j] == a:
+ j += 1
+ else:
+ i += 1
+ try:
+ while expected[i] == e:
+ i += 1
+ finally:
+ j += 1
+ while actual[j] == a:
+ j += 1
+ except IndexError:
+ missing.extend(expected[i:])
+ unexpected.extend(actual[j:])
+ break
+ return missing, unexpected
+
+
+def _are_both_of_integer_type(a, b):
+ # type: (object, object) -> bool
+ return isinstance(a, int) and isinstance(b, int)
+
+
+def _are_both_of_sequence_type(a, b):
+ # type: (object, object) -> bool
+ return isinstance(a, abc.Sequence) and isinstance(
+ b, abc.Sequence) and not isinstance(
+ a, _TEXT_OR_BINARY_TYPES) and not isinstance(b, _TEXT_OR_BINARY_TYPES)
+
+
+def _are_both_of_set_type(a, b):
+ # type: (object, object) -> bool
+ return isinstance(a, abc.Set) and isinstance(b, abc.Set)
+
+
+def _are_both_of_mapping_type(a, b):
+ # type: (object, object) -> bool
+ return isinstance(a, abc.Mapping) and isinstance(
+ b, abc.Mapping)
+
+
+def _walk_structure_for_problems(a, b, aname, bname, problem_list):
+ """The recursive comparison behind assertSameStructure."""
+ if type(a) != type(b) and not ( # pylint: disable=unidiomatic-typecheck
+ _are_both_of_integer_type(a, b) or _are_both_of_sequence_type(a, b) or
+ _are_both_of_set_type(a, b) or _are_both_of_mapping_type(a, b)):
+ # We do not distinguish between int and long types as 99.99% of Python 2
+ # code should never care. They collapse into a single type in Python 3.
+ problem_list.append('%s is a %r but %s is a %r' %
+ (aname, type(a), bname, type(b)))
+ # If they have different types there's no point continuing
+ return
+
+ if isinstance(a, abc.Set):
+ for k in a:
+ if k not in b:
+ problem_list.append(
+ '%s has %r but %s does not' % (aname, k, bname))
+ for k in b:
+ if k not in a:
+ problem_list.append('%s lacks %r but %s has it' % (aname, k, bname))
+
+ # NOTE: a or b could be a defaultdict, so we must take care that the traversal
+ # doesn't modify the data.
+ elif isinstance(a, abc.Mapping):
+ for k in a:
+ if k in b:
+ _walk_structure_for_problems(
+ a[k], b[k], '%s[%r]' % (aname, k), '%s[%r]' % (bname, k),
+ problem_list)
+ else:
+ problem_list.append(
+ "%s has [%r] with value %r but it's missing in %s" %
+ (aname, k, a[k], bname))
+ for k in b:
+ if k not in a:
+ problem_list.append(
+ '%s lacks [%r] but %s has it with value %r' %
+ (aname, k, bname, b[k]))
+
+ # Strings/bytes are Sequences but we'll just do those with regular !=
+ elif (isinstance(a, abc.Sequence) and
+ not isinstance(a, _TEXT_OR_BINARY_TYPES)):
+ minlen = min(len(a), len(b))
+ for i in range(minlen):
+ _walk_structure_for_problems(
+ a[i], b[i], '%s[%d]' % (aname, i), '%s[%d]' % (bname, i),
+ problem_list)
+ for i in range(minlen, len(a)):
+ problem_list.append('%s has [%i] with value %r but %s does not' %
+ (aname, i, a[i], bname))
+ for i in range(minlen, len(b)):
+ problem_list.append('%s lacks [%i] but %s has it with value %r' %
+ (aname, i, bname, b[i]))
+
+ else:
+ if a != b:
+ problem_list.append('%s is %r but %s is %r' % (aname, a, bname, b))
+
+
+def get_command_string(command):
+ """Returns an escaped string that can be used as a shell command.
+
+ Args:
+ command: List or string representing the command to run.
+ Returns:
+ A string suitable for use as a shell command.
+ """
+ if isinstance(command, str):
+ return command
+ else:
+ if os.name == 'nt':
+ return ' '.join(command)
+ else:
+ # The following is identical to Python 3's shlex.quote function.
+ command_string = ''
+ for word in command:
+ # Single quote word, and replace each ' in word with '"'"'
+ command_string += "'" + word.replace("'", "'\"'\"'") + "' "
+ return command_string[:-1]
+
+
+def get_command_stderr(command, env=None, close_fds=True):
+ """Runs the given shell command and returns a tuple.
+
+ Args:
+ command: List or string representing the command to run.
+ env: Dictionary of environment variable settings. If None, no environment
+ variables will be set for the child process. This is to make tests
+ more hermetic. NOTE: this behavior is different than the standard
+ subprocess module.
+ close_fds: Whether or not to close all open fd's in the child after forking.
+ On Windows, this is ignored and close_fds is always False.
+
+ Returns:
+ Tuple of (exit status, text printed to stdout and stderr by the command).
+ """
+ if env is None: env = {}
+ if os.name == 'nt':
+ # Windows does not support setting close_fds to True while also redirecting
+ # standard handles.
+ close_fds = False
+
+ use_shell = isinstance(command, str)
+ process = subprocess.Popen(
+ command,
+ close_fds=close_fds,
+ env=env,
+ shell=use_shell,
+ stderr=subprocess.STDOUT,
+ stdout=subprocess.PIPE)
+ output = process.communicate()[0]
+ exit_status = process.wait()
+ return (exit_status, output)
+
+
+def _quote_long_string(s):
+ # type: (Union[Text, bytes, bytearray]) -> Text
+ """Quotes a potentially multi-line string to make the start and end obvious.
+
+ Args:
+ s: A string.
+
+ Returns:
+ The quoted string.
+ """
+ if isinstance(s, (bytes, bytearray)):
+ try:
+ s = s.decode('utf-8')
+ except UnicodeDecodeError:
+ s = str(s)
+ return ('8<-----------\n' +
+ s + '\n' +
+ '----------->8\n')
+
+
+def print_python_version():
+ # type: () -> None
+ # Having this in the test output logs by default helps debugging when all
+ # you've got is the log and no other idea of which Python was used.
+ sys.stderr.write('Running tests under Python {0[0]}.{0[1]}.{0[2]}: '
+ '{1}\n'.format(
+ sys.version_info,
+ sys.executable if sys.executable else 'embedded.'))
+
+
+def main(*args, **kwargs):
+ # type: (Text, Any) -> None
+ """Executes a set of Python unit tests.
+
+ Usually this function is called without arguments, so the
+ unittest.TestProgram instance will get created with the default settings,
+ so it will run all test methods of all TestCase classes in the __main__
+ module.
+
+ Args:
+ *args: Positional arguments passed through to unittest.TestProgram.__init__.
+ **kwargs: Keyword arguments passed through to unittest.TestProgram.__init__.
+ """
+ print_python_version()
+ _run_in_app(run_tests, args, kwargs)
+
+
+def _is_in_app_main():
+ # type: () -> bool
+ """Returns True iff app.run is active."""
+ f = sys._getframe().f_back # pylint: disable=protected-access
+ while f:
+ if f.f_code == app.run.__code__:
+ return True
+ f = f.f_back
+ return False
+
+
+class _SavedFlag(object):
+ """Helper class for saving and restoring a flag value."""
+
+ def __init__(self, flag):
+ self.flag = flag
+ self.value = flag.value
+ self.present = flag.present
+
+ def restore_flag(self):
+ self.flag.value = self.value
+ self.flag.present = self.present
+
+
+def _register_sigterm_with_faulthandler():
+ # type: () -> None
+ """Have faulthandler dump stacks on SIGTERM. Useful to diagnose timeouts."""
+ if faulthandler and getattr(faulthandler, 'register', None):
+ # faulthandler.register is not available on Windows.
+ # faulthandler.enable() is already called by app.run.
+ try:
+ faulthandler.register(signal.SIGTERM, chain=True) # pytype: disable=module-attr
+ except Exception as e: # pylint: disable=broad-except
+ sys.stderr.write('faulthandler.register(SIGTERM) failed '
+ '%r; ignoring.\n' % e)
+
+
+def _run_in_app(function, args, kwargs):
+ # type: (Callable[..., None], Sequence[Text], Mapping[Text, Any]) -> None
+ """Executes a set of Python unit tests, ensuring app.run.
+
+ This is a private function, users should call absltest.main().
+
+ _run_in_app calculates argv to be the command-line arguments of this program
+ (without the flags), sets the default of FLAGS.alsologtostderr to True,
+ then it calls function(argv, args, kwargs), making sure that `function'
+ will get called within app.run(). _run_in_app does this by checking whether
+ it is called by app.run(), or by calling app.run() explicitly.
+
+ The reason why app.run has to be ensured is to make sure that
+ flags are parsed and stripped properly, and other initializations done by
+ the app module are also carried out, no matter if absltest.run() is called
+ from within or outside app.run().
+
+ If _run_in_app is called from within app.run(), then it will reparse
+ sys.argv and pass the result without command-line flags into the argv
+ argument of `function'. The reason why this parsing is needed is that
+ __main__.main() calls absltest.main() without passing its argv. So the
+ only way _run_in_app could get to know the argv without the flags is that
+ it reparses sys.argv.
+
+ _run_in_app changes the default of FLAGS.alsologtostderr to True so that the
+ test program's stderr will contain all the log messages unless otherwise
+ specified on the command-line. This overrides any explicit assignment to
+ FLAGS.alsologtostderr by the test program prior to the call to _run_in_app()
+ (e.g. in __main__.main).
+
+ Please note that _run_in_app (and the function it calls) is allowed to make
+ changes to kwargs.
+
+ Args:
+ function: absltest.run_tests or a similar function. It will be called as
+ function(argv, args, kwargs) where argv is a list containing the
+ elements of sys.argv without the command-line flags.
+ args: Positional arguments passed through to unittest.TestProgram.__init__.
+ kwargs: Keyword arguments passed through to unittest.TestProgram.__init__.
+ """
+ if _is_in_app_main():
+ _register_sigterm_with_faulthandler()
+
+ # Save command-line flags so the side effects of FLAGS(sys.argv) can be
+ # undone.
+ flag_objects = (FLAGS[name] for name in FLAGS)
+ saved_flags = dict((f.name, _SavedFlag(f)) for f in flag_objects)
+
+ # Change the default of alsologtostderr from False to True, so the test
+ # programs's stderr will contain all the log messages.
+ # If --alsologtostderr=false is specified in the command-line, or user
+ # has called FLAGS.alsologtostderr = False before, then the value is kept
+ # False.
+ FLAGS.set_default('alsologtostderr', True)
+ # Remove it from saved flags so it doesn't get restored later.
+ del saved_flags['alsologtostderr']
+
+ # The call FLAGS(sys.argv) parses sys.argv, returns the arguments
+ # without the flags, and -- as a side effect -- modifies flag values in
+ # FLAGS. We don't want the side effect, because we don't want to
+ # override flag changes the program did (e.g. in __main__.main)
+ # after the command-line has been parsed. So we have the for loop below
+ # to change back flags to their old values.
+ argv = FLAGS(sys.argv)
+ for saved_flag in saved_flags.values():
+ saved_flag.restore_flag()
+
+ function(argv, args, kwargs)
+ else:
+ # Send logging to stderr. Use --alsologtostderr instead of --logtostderr
+ # in case tests are reading their own logs.
+ FLAGS.set_default('alsologtostderr', True)
+
+ def main_function(argv):
+ _register_sigterm_with_faulthandler()
+ function(argv, args, kwargs)
+
+ app.run(main=main_function)
+
+
+def _is_suspicious_attribute(testCaseClass, name):
+ # type: (Type, Text) -> bool
+ """Returns True if an attribute is a method named like a test method."""
+ if name.startswith('Test') and len(name) > 4 and name[4].isupper():
+ attr = getattr(testCaseClass, name)
+ if inspect.isfunction(attr) or inspect.ismethod(attr):
+ args = inspect.getfullargspec(attr)
+ return (len(args.args) == 1 and args.args[0] == 'self' and
+ args.varargs is None and args.varkw is None and
+ not args.kwonlyargs)
+ return False
+
+
+def skipThisClass(reason):
+ # type: (Text) -> Callable[[_T], _T]
+ """Skip tests in the decorated TestCase, but not any of its subclasses.
+
+ This decorator indicates that this class should skip all its tests, but not
+ any of its subclasses. Useful for if you want to share testMethod or setUp
+ implementations between a number of concrete testcase classes.
+
+ Example usage, showing how you can share some common test methods between
+ subclasses. In this example, only 'BaseTest' will be marked as skipped, and
+ not RealTest or SecondRealTest:
+
+ @absltest.skipThisClass("Shared functionality")
+ class BaseTest(absltest.TestCase):
+ def test_simple_functionality(self):
+ self.assertEqual(self.system_under_test.method(), 1)
+
+ class RealTest(BaseTest):
+ def setUp(self):
+ super().setUp()
+ self.system_under_test = MakeSystem(argument)
+
+ def test_specific_behavior(self):
+ ...
+
+ class SecondRealTest(BaseTest):
+ def setUp(self):
+ super().setUp()
+ self.system_under_test = MakeSystem(other_arguments)
+
+ def test_other_behavior(self):
+ ...
+
+ Args:
+ reason: The reason we have a skip in place. For instance: 'shared test
+ methods' or 'shared assertion methods'.
+
+ Returns:
+ Decorator function that will cause a class to be skipped.
+ """
+ if isinstance(reason, type):
+ raise TypeError('Got {!r}, expected reason as string'.format(reason))
+
+ def _skip_class(test_case_class):
+ if not issubclass(test_case_class, unittest.TestCase):
+ raise TypeError(
+ 'Decorating {!r}, expected TestCase subclass'.format(test_case_class))
+
+ # Only shadow the setUpClass method if it is directly defined. If it is
+ # in the parent class we invoke it via a super() call instead of holding
+ # a reference to it.
+ shadowed_setupclass = test_case_class.__dict__.get('setUpClass', None)
+
+ @classmethod
+ def replacement_setupclass(cls, *args, **kwargs):
+ # Skip this class if it is the one that was decorated with @skipThisClass
+ if cls is test_case_class:
+ raise SkipTest(reason)
+ if shadowed_setupclass:
+ # Pass along `cls` so the MRO chain doesn't break.
+ # The original method is a `classmethod` descriptor, which can't
+ # be directly called, but `__func__` has the underlying function.
+ return shadowed_setupclass.__func__(cls, *args, **kwargs)
+ else:
+ # Because there's no setUpClass() defined directly on test_case_class,
+ # we call super() ourselves to continue execution of the inheritance
+ # chain.
+ return super(test_case_class, cls).setUpClass(*args, **kwargs)
+
+ test_case_class.setUpClass = replacement_setupclass
+ return test_case_class
+
+ return _skip_class
+
+
+class TestLoader(unittest.TestLoader):
+ """A test loader which supports common test features.
+
+ Supported features include:
+ * Banning untested methods with test-like names: methods attached to this
+ testCase with names starting with `Test` are ignored by the test runner,
+ and often represent mistakenly-omitted test cases. This loader will raise
+ a TypeError when attempting to load a TestCase with such methods.
+ * Randomization of test case execution order (optional).
+ """
+
+ _ERROR_MSG = textwrap.dedent("""Method '%s' is named like a test case but
+ is not one. This is often a bug. If you want it to be a test method,
+ name it with 'test' in lowercase. If not, rename the method to not begin
+ with 'Test'.""")
+
+ def __init__(self, *args, **kwds):
+ super(TestLoader, self).__init__(*args, **kwds)
+ seed = _get_default_randomize_ordering_seed()
+ if seed:
+ self._randomize_ordering_seed = seed
+ self._random = random.Random(self._randomize_ordering_seed)
+ else:
+ self._randomize_ordering_seed = None
+ self._random = None
+
+ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name
+ """Validates and returns a (possibly randomized) list of test case names."""
+ for name in dir(testCaseClass):
+ if _is_suspicious_attribute(testCaseClass, name):
+ raise TypeError(TestLoader._ERROR_MSG % name)
+ names = super(TestLoader, self).getTestCaseNames(testCaseClass)
+ if self._randomize_ordering_seed is not None:
+ logging.info(
+ 'Randomizing test order with seed: %d', self._randomize_ordering_seed)
+ logging.info(
+ 'To reproduce this order, re-run with '
+ '--test_randomize_ordering_seed=%d', self._randomize_ordering_seed)
+ self._random.shuffle(names)
+ return names
+
+
+def get_default_xml_output_filename():
+ # type: () -> Optional[Text]
+ if os.environ.get('XML_OUTPUT_FILE'):
+ return os.environ['XML_OUTPUT_FILE']
+ elif os.environ.get('RUNNING_UNDER_TEST_DAEMON'):
+ return os.path.join(os.path.dirname(TEST_TMPDIR.value), 'test_detail.xml')
+ elif os.environ.get('TEST_XMLOUTPUTDIR'):
+ return os.path.join(
+ os.environ['TEST_XMLOUTPUTDIR'],
+ os.path.splitext(os.path.basename(sys.argv[0]))[0] + '.xml')
+
+
+def _setup_filtering(argv):
+ # type: (MutableSequence[Text]) -> None
+ """Implements the bazel test filtering protocol.
+
+ The following environment variable is used in this method:
+
+ TESTBRIDGE_TEST_ONLY: string, if set, is forwarded to the unittest
+ framework to use as a test filter. Its value is split with shlex, then:
+ 1. On Python 3.6 and before, split values are passed as positional
+ arguments on argv.
+ 2. On Python 3.7+, split values are passed to unittest's `-k` flag. Tests
+ are matched by glob patterns or substring. See
+ https://docs.python.org/3/library/unittest.html#cmdoption-unittest-k
+
+ Args:
+ argv: the argv to mutate in-place.
+ """
+ test_filter = os.environ.get('TESTBRIDGE_TEST_ONLY')
+ if argv is None or not test_filter:
+ return
+
+ filters = shlex.split(test_filter)
+ if sys.version_info[:2] >= (3, 7):
+ filters = ['-k=' + test_filter for test_filter in filters]
+
+ argv[1:1] = filters
+
+
+def _setup_test_runner_fail_fast(argv):
+ # type: (MutableSequence[Text]) -> None
+ """Implements the bazel test fail fast protocol.
+
+ The following environment variable is used in this method:
+
+ TESTBRIDGE_TEST_RUNNER_FAIL_FAST=<1|0>
+
+ If set to 1, --failfast is passed to the unittest framework to return upon
+ first failure.
+
+ Args:
+ argv: the argv to mutate in-place.
+ """
+
+ if argv is None:
+ return
+
+ if os.environ.get('TESTBRIDGE_TEST_RUNNER_FAIL_FAST') != '1':
+ return
+
+ argv[1:1] = ['--failfast']
+
+
+def _setup_sharding(custom_loader=None):
+ # type: (Optional[unittest.TestLoader]) -> unittest.TestLoader
+ """Implements the bazel sharding protocol.
+
+ The following environment variables are used in this method:
+
+ TEST_SHARD_STATUS_FILE: string, if set, points to a file. We write a blank
+ file to tell the test runner that this test implements the test sharding
+ protocol.
+
+ TEST_TOTAL_SHARDS: int, if set, sharding is requested.
+
+ TEST_SHARD_INDEX: int, must be set if TEST_TOTAL_SHARDS is set. Specifies
+ the shard index for this instance of the test process. Must satisfy:
+ 0 <= TEST_SHARD_INDEX < TEST_TOTAL_SHARDS.
+
+ Args:
+ custom_loader: A TestLoader to be made sharded.
+
+ Returns:
+ The test loader for shard-filtering or the standard test loader, depending
+ on the sharding environment variables.
+ """
+
+ # It may be useful to write the shard file even if the other sharding
+ # environment variables are not set. Test runners may use this functionality
+ # to query whether a test binary implements the test sharding protocol.
+ if 'TEST_SHARD_STATUS_FILE' in os.environ:
+ try:
+ with open(os.environ['TEST_SHARD_STATUS_FILE'], 'w') as f:
+ f.write('')
+ except IOError:
+ sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.'
+ % os.environ['TEST_SHARD_STATUS_FILE'])
+ sys.exit(1)
+
+ base_loader = custom_loader or TestLoader()
+ if 'TEST_TOTAL_SHARDS' not in os.environ:
+ # Not using sharding, use the expected test loader.
+ return base_loader
+
+ total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
+ shard_index = int(os.environ['TEST_SHARD_INDEX'])
+
+ if shard_index < 0 or shard_index >= total_shards:
+ sys.stderr.write('ERROR: Bad sharding values. index=%d, total=%d\n' %
+ (shard_index, total_shards))
+ sys.exit(1)
+
+ # Replace the original getTestCaseNames with one that returns
+ # the test case names for this shard.
+ delegate_get_names = base_loader.getTestCaseNames
+
+ bucket_iterator = itertools.cycle(range(total_shards))
+
+ def getShardedTestCaseNames(testCaseClass):
+ filtered_names = []
+ # We need to sort the list of tests in order to determine which tests this
+ # shard is responsible for; however, it's important to preserve the order
+ # returned by the base loader, e.g. in the case of randomized test ordering.
+ ordered_names = delegate_get_names(testCaseClass)
+ for testcase in sorted(ordered_names):
+ bucket = next(bucket_iterator)
+ if bucket == shard_index:
+ filtered_names.append(testcase)
+ return [x for x in ordered_names if x in filtered_names]
+
+ base_loader.getTestCaseNames = getShardedTestCaseNames
+ return base_loader
+
+
+# pylint: disable=line-too-long
+def _run_and_get_tests_result(argv, args, kwargs, xml_test_runner_class):
+ # type: (MutableSequence[Text], Sequence[Any], MutableMapping[Text, Any], Type) -> unittest.TestResult
+ # pylint: enable=line-too-long
+ """Same as run_tests, except it returns the result instead of exiting."""
+
+ # The entry from kwargs overrides argv.
+ argv = kwargs.pop('argv', argv)
+
+ # Set up test filtering if requested in environment.
+ _setup_filtering(argv)
+ # Set up --failfast as requested in environment
+ _setup_test_runner_fail_fast(argv)
+
+ # Shard the (default or custom) loader if sharding is turned on.
+ kwargs['testLoader'] = _setup_sharding(kwargs.get('testLoader', None))
+
+ # XML file name is based upon (sorted by priority):
+ # --xml_output_file flag, XML_OUTPUT_FILE variable,
+ # TEST_XMLOUTPUTDIR variable or RUNNING_UNDER_TEST_DAEMON variable.
+ if not FLAGS.xml_output_file:
+ FLAGS.xml_output_file = get_default_xml_output_filename()
+ xml_output_file = FLAGS.xml_output_file
+
+ xml_buffer = None
+ if xml_output_file:
+ xml_output_dir = os.path.dirname(xml_output_file)
+ if xml_output_dir and not os.path.isdir(xml_output_dir):
+ try:
+ os.makedirs(xml_output_dir)
+ except OSError as e:
+ # File exists error can occur with concurrent tests
+ if e.errno != errno.EEXIST:
+ raise
+ # Fail early if we can't write to the XML output file. This is so that we
+ # don't waste people's time running tests that will just fail anyways.
+ with _open(xml_output_file, 'w'):
+ pass
+
+ # We can reuse testRunner if it supports XML output (e. g. by inheriting
+ # from xml_reporter.TextAndXMLTestRunner). Otherwise we need to use
+ # xml_reporter.TextAndXMLTestRunner.
+ if (kwargs.get('testRunner') is not None
+ and not hasattr(kwargs['testRunner'], 'set_default_xml_stream')):
+ sys.stderr.write('WARNING: XML_OUTPUT_FILE or --xml_output_file setting '
+ 'overrides testRunner=%r setting (possibly from --pdb)'
+ % (kwargs['testRunner']))
+ # Passing a class object here allows TestProgram to initialize
+ # instances based on its kwargs and/or parsed command-line args.
+ kwargs['testRunner'] = xml_test_runner_class
+ if kwargs.get('testRunner') is None:
+ kwargs['testRunner'] = xml_test_runner_class
+ # Use an in-memory buffer (not backed by the actual file) to store the XML
+ # report, because some tools modify the file (e.g., create a placeholder
+ # with partial information, in case the test process crashes).
+ xml_buffer = io.StringIO()
+ kwargs['testRunner'].set_default_xml_stream(xml_buffer) # pytype: disable=attribute-error
+
+ # If we've used a seed to randomize test case ordering, we want to record it
+ # as a top-level attribute in the `testsuites` section of the XML output.
+ randomize_ordering_seed = getattr(
+ kwargs['testLoader'], '_randomize_ordering_seed', None)
+ setter = getattr(kwargs['testRunner'], 'set_testsuites_property', None)
+ if randomize_ordering_seed and setter:
+ setter('test_randomize_ordering_seed', randomize_ordering_seed)
+ elif kwargs.get('testRunner') is None:
+ kwargs['testRunner'] = _pretty_print_reporter.TextTestRunner
+
+ if FLAGS.pdb_post_mortem:
+ runner = kwargs['testRunner']
+ # testRunner can be a class or an instance, which must be tested for
+ # differently.
+ # Overriding testRunner isn't uncommon, so only enable the debugging
+ # integration if the runner claims it does; we don't want to accidentally
+ # clobber something on the runner.
+ if ((isinstance(runner, type) and
+ issubclass(runner, _pretty_print_reporter.TextTestRunner)) or
+ isinstance(runner, _pretty_print_reporter.TextTestRunner)):
+ runner.run_for_debugging = True
+
+ # Make sure tmpdir exists.
+ if not os.path.isdir(TEST_TMPDIR.value):
+ try:
+ os.makedirs(TEST_TMPDIR.value)
+ except OSError as e:
+ # Concurrent test might have created the directory.
+ if e.errno != errno.EEXIST:
+ raise
+
+ # Let unittest.TestProgram.__init__ do its own argv parsing, e.g. for '-v',
+ # on argv, which is sys.argv without the command-line flags.
+ kwargs['argv'] = argv
+
+ try:
+ test_program = unittest.TestProgram(*args, **kwargs)
+ return test_program.result
+ finally:
+ if xml_buffer:
+ try:
+ with _open(xml_output_file, 'w') as f:
+ f.write(xml_buffer.getvalue())
+ finally:
+ xml_buffer.close()
+
+
+def run_tests(argv, args, kwargs): # pylint: disable=line-too-long
+ # type: (MutableSequence[Text], Sequence[Any], MutableMapping[Text, Any]) -> None
+ # pylint: enable=line-too-long
+ """Executes a set of Python unit tests.
+
+ Most users should call absltest.main() instead of run_tests.
+
+ Please note that run_tests should be called from app.run.
+ Calling absltest.main() would ensure that.
+
+ Please note that run_tests is allowed to make changes to kwargs.
+
+ Args:
+ argv: sys.argv with the command-line flags removed from the front, i.e. the
+ argv with which app.run() has called __main__.main. It is passed to
+ unittest.TestProgram.__init__(argv=), which does its own flag parsing. It
+ is ignored if kwargs contains an argv entry.
+ args: Positional arguments passed through to unittest.TestProgram.__init__.
+ kwargs: Keyword arguments passed through to unittest.TestProgram.__init__.
+ """
+ result = _run_and_get_tests_result(
+ argv, args, kwargs, xml_reporter.TextAndXMLTestRunner)
+ sys.exit(not result.wasSuccessful())
+
+
+def _rmtree_ignore_errors(path):
+ # type: (Text) -> None
+ if os.path.isfile(path):
+ try:
+ os.unlink(path)
+ except OSError:
+ pass
+ else:
+ shutil.rmtree(path, ignore_errors=True)
+
+
+def _get_first_part(path):
+ # type: (Text) -> Text
+ parts = path.split(os.sep, 1)
+ return parts[0]
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py
new file mode 100644
index 0000000..7fe95fe
--- /dev/null
+++ b/absl/testing/flagsaver.py
@@ -0,0 +1,198 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Decorator and context manager for saving and restoring flag values.
+
+There are many ways to save and restore. Always use the most convenient method
+for a given use case.
+
+Here are examples of each method. They all call do_stuff() while FLAGS.someflag
+is temporarily set to 'foo'.
+
+ from absl.testing import flagsaver
+
+ # Use a decorator which can optionally override flags via arguments.
+ @flagsaver.flagsaver(someflag='foo')
+ def some_func():
+ do_stuff()
+
+ # Use a decorator which can optionally override flags with flagholders.
+ @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23))
+ def some_func():
+ do_stuff()
+
+ # Use a decorator which does not override flags itself.
+ @flagsaver.flagsaver
+ def some_func():
+ FLAGS.someflag = 'foo'
+ do_stuff()
+
+ # Use a context manager which can optionally override flags via arguments.
+ with flagsaver.flagsaver(someflag='foo'):
+ do_stuff()
+
+ # Save and restore the flag values yourself.
+ saved_flag_values = flagsaver.save_flag_values()
+ try:
+ FLAGS.someflag = 'foo'
+ do_stuff()
+ finally:
+ flagsaver.restore_flag_values(saved_flag_values)
+
+We save and restore a shallow copy of each Flag object's __dict__ attribute.
+This preserves all attributes of the flag, such as whether or not it was
+overridden from its default value.
+
+WARNING: Currently a flag that is saved and then deleted cannot be restored. An
+exception will be raised. However if you *add* a flag after saving flag values,
+and then restore flag values, the added flag will be deleted with no errors.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import inspect
+
+from absl import flags
+
+FLAGS = flags.FLAGS
+
+
+def flagsaver(*args, **kwargs):
+ """The main flagsaver interface. See module doc for usage."""
+ if not args:
+ return _FlagOverrider(**kwargs)
+ # args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)`
+ if len(args) == 1 and callable(args[0]):
+ if kwargs:
+ raise ValueError(
+ "It's invalid to specify both positional and keyword parameters.")
+ func = args[0]
+ if inspect.isclass(func):
+ raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
+ return _wrap(func, {})
+ # args can be a list of (FlagHolder, value) pairs.
+ # In which case they augment any specified kwargs.
+ for arg in args:
+ if not isinstance(arg, tuple) or len(arg) != 2:
+ raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
+ holder, value = arg
+ if not isinstance(holder, flags.FlagHolder):
+ raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
+ if holder.name in kwargs:
+ raise ValueError('Cannot set --%s multiple times' % holder.name)
+ kwargs[holder.name] = value
+ return _FlagOverrider(**kwargs)
+
+
+def save_flag_values(flag_values=FLAGS):
+ """Returns copy of flag values as a dict.
+
+ Args:
+ flag_values: FlagValues, the FlagValues instance with which the flag will
+ be saved. This should almost never need to be overridden.
+ Returns:
+ Dictionary mapping keys to values. Keys are flag names, values are
+ corresponding __dict__ members. E.g. {'key': value_dict, ...}.
+ """
+ return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}
+
+
+def restore_flag_values(saved_flag_values, flag_values=FLAGS):
+ """Restores flag values based on the dictionary of flag values.
+
+ Args:
+ saved_flag_values: {'flag_name': value_dict, ...}
+ flag_values: FlagValues, the FlagValues instance from which the flag will
+ be restored. This should almost never need to be overridden.
+ """
+ new_flag_names = list(flag_values)
+ for name in new_flag_names:
+ saved = saved_flag_values.get(name)
+ if saved is None:
+ # If __dict__ was not saved delete "new" flag.
+ delattr(flag_values, name)
+ else:
+ if flag_values[name].value != saved['_value']:
+ flag_values[name].value = saved['_value'] # Ensure C++ value is set.
+ flag_values[name].__dict__ = saved
+
+
+def _wrap(func, overrides):
+ """Creates a wrapper function that saves/restores flag values.
+
+ Args:
+ func: function object - This will be called between saving flags and
+ restoring flags.
+ overrides: {str: object} - Flag names mapped to their values. These flags
+ will be set after saving the original flag state.
+
+ Returns:
+ return value from func()
+ """
+ @functools.wraps(func)
+ def _flagsaver_wrapper(*args, **kwargs):
+ """Wrapper function that saves and restores flags."""
+ with _FlagOverrider(**overrides):
+ return func(*args, **kwargs)
+ return _flagsaver_wrapper
+
+
+class _FlagOverrider(object):
+ """Overrides flags for the duration of the decorated function call.
+
+ It also restores all original values of flags after decorated method
+ completes.
+ """
+
+ def __init__(self, **overrides):
+ self._overrides = overrides
+ self._saved_flag_values = None
+
+ def __call__(self, func):
+ if inspect.isclass(func):
+ raise TypeError('flagsaver cannot be applied to a class.')
+ return _wrap(func, self._overrides)
+
+ def __enter__(self):
+ self._saved_flag_values = save_flag_values(FLAGS)
+ try:
+ FLAGS._set_attributes(**self._overrides)
+ except:
+ # It may fail because of flag validators.
+ restore_flag_values(self._saved_flag_values, FLAGS)
+ raise
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ restore_flag_values(self._saved_flag_values, FLAGS)
+
+
+def _copy_flag_dict(flag):
+ """Returns a copy of the flag object's __dict__.
+
+ It's mostly a shallow copy of the __dict__, except it also does a shallow
+ copy of the validator list.
+
+ Args:
+ flag: flags.Flag, the flag to copy.
+
+ Returns:
+ A copy of the flag object's __dict__.
+ """
+ copy = flag.__dict__.copy()
+ copy['_value'] = flag.value # Ensure correct restore for C++ flags.
+ copy['validators'] = list(flag.validators)
+ return copy
diff --git a/absl/testing/parameterized.py b/absl/testing/parameterized.py
new file mode 100644
index 0000000..ec6a529
--- /dev/null
+++ b/absl/testing/parameterized.py
@@ -0,0 +1,700 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Adds support for parameterized tests to Python's unittest TestCase class.
+
+A parameterized test is a method in a test case that is invoked with different
+argument tuples.
+
+A simple example:
+
+ class AdditionExample(parameterized.TestCase):
+ @parameterized.parameters(
+ (1, 2, 3),
+ (4, 5, 9),
+ (1, 1, 3))
+ def testAddition(self, op1, op2, result):
+ self.assertEqual(result, op1 + op2)
+
+
+Each invocation is a separate test case and properly isolated just
+like a normal test method, with its own setUp/tearDown cycle. In the
+example above, there are three separate testcases, one of which will
+fail due to an assertion error (1 + 1 != 3).
+
+Parameters for individual test cases can be tuples (with positional parameters)
+or dictionaries (with named parameters):
+
+ class AdditionExample(parameterized.TestCase):
+ @parameterized.parameters(
+ {'op1': 1, 'op2': 2, 'result': 3},
+ {'op1': 4, 'op2': 5, 'result': 9},
+ )
+ def testAddition(self, op1, op2, result):
+ self.assertEqual(result, op1 + op2)
+
+If a parameterized test fails, the error message will show the
+original test name and the parameters for that test.
+
+The id method of the test, used internally by the unittest framework, is also
+modified to show the arguments (but note that the name reported by `id()`
+doesn't match the actual test name, see below). To make sure that test names
+stay the same across several invocations, object representations like
+
+ >>> class Foo(object):
+ ... pass
+ >>> repr(Foo())
+ '<__main__.Foo object at 0x23d8610>'
+
+are turned into '<__main__.Foo>'. When selecting a subset of test cases to run
+on the command-line, the test cases contain an index suffix for each argument
+in the order they were passed to `parameters()` (eg. testAddition0,
+testAddition1, etc.) This naming scheme is subject to change; for more reliable
+and stable names, especially in test logs, use `named_parameters()` instead.
+
+Tests using `named_parameters()` are similar to `parameters()`, except only
+tuples or dicts of args are supported. For tuples, the first parameter arg
+has to be a string (or an object that returns an apt name when converted via
+str()). For dicts, a value for the key 'testcase_name' must be present and must
+be a string (or an object that returns an apt name when converted via str()):
+
+ class NamedExample(parameterized.TestCase):
+ @parameterized.named_parameters(
+ ('Normal', 'aa', 'aaa', True),
+ ('EmptyPrefix', '', 'abc', True),
+ ('BothEmpty', '', '', True))
+ def testStartsWith(self, prefix, string, result):
+ self.assertEqual(result, string.startswith(prefix))
+
+ class NamedExample(parameterized.TestCase):
+ @parameterized.named_parameters(
+ {'testcase_name': 'Normal',
+ 'result': True, 'string': 'aaa', 'prefix': 'aa'},
+ {'testcase_name': 'EmptyPrefix',
+ 'result': True, 'string': 'abc', 'prefix': ''},
+ {'testcase_name': 'BothEmpty',
+ 'result': True, 'string': '', 'prefix': ''})
+ def testStartsWith(self, prefix, string, result):
+ self.assertEqual(result, string.startswith(prefix))
+
+Named tests also have the benefit that they can be run individually
+from the command line:
+
+ $ testmodule.py NamedExample.testStartsWithNormal
+ .
+ --------------------------------------------------------------------
+ Ran 1 test in 0.000s
+
+ OK
+
+Parameterized Classes
+=====================
+If invocation arguments are shared across test methods in a single
+TestCase class, instead of decorating all test methods
+individually, the class itself can be decorated:
+
+ @parameterized.parameters(
+ (1, 2, 3),
+ (4, 5, 9))
+ class ArithmeticTest(parameterized.TestCase):
+ def testAdd(self, arg1, arg2, result):
+ self.assertEqual(arg1 + arg2, result)
+
+ def testSubtract(self, arg1, arg2, result):
+ self.assertEqual(result - arg1, arg2)
+
+Inputs from Iterables
+=====================
+If parameters should be shared across several test cases, or are dynamically
+created from other sources, a single non-tuple iterable can be passed into
+the decorator. This iterable will be used to obtain the test cases:
+
+ class AdditionExample(parameterized.TestCase):
+ @parameterized.parameters(
+ c.op1, c.op2, c.result for c in testcases
+ )
+ def testAddition(self, op1, op2, result):
+ self.assertEqual(result, op1 + op2)
+
+
+Single-Argument Test Methods
+============================
+If a test method takes only one argument, the single arguments must not be
+wrapped into a tuple:
+
+ class NegativeNumberExample(parameterized.TestCase):
+ @parameterized.parameters(
+ -1, -3, -4, -5
+ )
+ def testIsNegative(self, arg):
+ self.assertTrue(IsNegative(arg))
+
+
+List/tuple as a Single Argument
+===============================
+If a test method takes a single argument of a list/tuple, it must be wrapped
+inside a tuple:
+
+ class ZeroSumExample(parameterized.TestCase):
+ @parameterized.parameters(
+ ([-1, 0, 1], ),
+ ([-2, 0, 2], ),
+ )
+ def testSumIsZero(self, arg):
+ self.assertEqual(0, sum(arg))
+
+
+Cartesian product of Parameter Values as Parametrized Test Cases
+======================================================
+If required to test method over a cartesian product of parameters,
+`parameterized.product` may be used to facilitate generation of parameters
+test combinations:
+
+ class TestModuloExample(parameterized.TestCase):
+ @parameterized.product(
+ num=[0, 20, 80],
+ modulo=[2, 4],
+ expected=[0]
+ )
+ def testModuloResult(self, num, modulo, expected):
+ self.assertEqual(expected, num % modulo)
+
+This results in 6 test cases being created - one for each combination of the
+parameters. It is also possible to supply sequences of keyword argument dicts
+as elements of the cartesian product:
+
+ @parameterized.product(
+ (dict(num=5, modulo=3, expected=2),
+ dict(num=7, modulo=4, expected=3)),
+ dtype=(int, float)
+ )
+ def testModuloResult(self, num, modulo, expected, dtype):
+ self.assertEqual(expected, dtype(num) % modulo)
+
+This results in 4 test cases being created - for each of the two sets of test
+data (supplied as kwarg dicts) and for each of the two data types (supplied as
+a named parameter). Multiple keyword argument dicts may be supplied if required.
+
+Async Support
+===============================
+If a test needs to call async functions, it can inherit from both
+parameterized.TestCase and another TestCase that supports async calls, such
+as [asynctest](https://github.com/Martiusweb/asynctest):
+
+ import asynctest
+
+ class AsyncExample(parameterized.TestCase, asynctest.TestCase):
+ @parameterized.parameters(
+ ('a', 1),
+ ('b', 2),
+ )
+ async def testSomeAsyncFunction(self, arg, expected):
+ actual = await someAsyncFunction(arg)
+ self.assertEqual(actual, expected)
+"""
+
+from collections import abc
+import functools
+import inspect
+import itertools
+import re
+import types
+import unittest
+
+from absl.testing import absltest
+
+
+_ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>')
+_NAMED = object()
+_ARGUMENT_REPR = object()
+_NAMED_DICT_KEY = 'testcase_name'
+
+
+class NoTestsError(Exception):
+ """Raised when parameterized decorators do not generate any tests."""
+
+
+class DuplicateTestNameError(Exception):
+ """Raised when a parameterized test has the same test name multiple times."""
+
+ def __init__(self, test_class_name, new_test_name, original_test_name):
+ super(DuplicateTestNameError, self).__init__(
+ 'Duplicate parameterized test name in {}: generated test name {!r} '
+ '(generated from {!r}) already exists. Consider using '
+ 'named_parameters() to give your tests unique names and/or renaming '
+ 'the conflicting test method.'.format(
+ test_class_name, new_test_name, original_test_name))
+
+
+def _clean_repr(obj):
+ return _ADDR_RE.sub(r'<\1>', repr(obj))
+
+
+def _non_string_or_bytes_iterable(obj):
+ return (isinstance(obj, abc.Iterable) and not isinstance(obj, str) and
+ not isinstance(obj, bytes))
+
+
+def _format_parameter_list(testcase_params):
+ if isinstance(testcase_params, abc.Mapping):
+ return ', '.join('%s=%s' % (argname, _clean_repr(value))
+ for argname, value in testcase_params.items())
+ elif _non_string_or_bytes_iterable(testcase_params):
+ return ', '.join(map(_clean_repr, testcase_params))
+ else:
+ return _format_parameter_list((testcase_params,))
+
+
+def _async_wrapped(func):
+ @functools.wraps(func)
+ async def wrapper(*args, **kwargs):
+ return await func(*args, **kwargs)
+ return wrapper
+
+
+class _ParameterizedTestIter(object):
+ """Callable and iterable class for producing new test cases."""
+
+ def __init__(self, test_method, testcases, naming_type, original_name=None):
+ """Returns concrete test functions for a test and a list of parameters.
+
+ The naming_type is used to determine the name of the concrete
+ functions as reported by the unittest framework. If naming_type is
+ _FIRST_ARG, the testcases must be tuples, and the first element must
+ have a string representation that is a valid Python identifier.
+
+ Args:
+ test_method: The decorated test method.
+ testcases: (list of tuple/dict) A list of parameter tuples/dicts for
+ individual test invocations.
+ naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR.
+ original_name: The original test method name. When decorated on a test
+ method, None is passed to __init__ and test_method.__name__ is used.
+ Note test_method.__name__ might be different than the original defined
+ test method because of the use of other decorators. A more accurate
+ value is set by TestGeneratorMetaclass.__new__ later.
+ """
+ self._test_method = test_method
+ self.testcases = testcases
+ self._naming_type = naming_type
+ if original_name is None:
+ original_name = test_method.__name__
+ self._original_name = original_name
+ self.__name__ = _ParameterizedTestIter.__name__
+
+ def __call__(self, *args, **kwargs):
+ raise RuntimeError('You appear to be running a parameterized test case '
+ 'without having inherited from parameterized.'
+ 'TestCase. This is bad because none of '
+ 'your test cases are actually being run. You may also '
+ 'be using another decorator before the parameterized '
+ 'one, in which case you should reverse the order.')
+
+ def __iter__(self):
+ test_method = self._test_method
+ naming_type = self._naming_type
+
+ def make_bound_param_test(testcase_params):
+ @functools.wraps(test_method)
+ def bound_param_test(self):
+ if isinstance(testcase_params, abc.Mapping):
+ return test_method(self, **testcase_params)
+ elif _non_string_or_bytes_iterable(testcase_params):
+ return test_method(self, *testcase_params)
+ else:
+ return test_method(self, testcase_params)
+
+ if naming_type is _NAMED:
+ # Signal the metaclass that the name of the test function is unique
+ # and descriptive.
+ bound_param_test.__x_use_name__ = True
+
+ testcase_name = None
+ if isinstance(testcase_params, abc.Mapping):
+ if _NAMED_DICT_KEY not in testcase_params:
+ raise RuntimeError(
+ 'Dict for named tests must contain key "%s"' % _NAMED_DICT_KEY)
+ # Create a new dict to avoid modifying the supplied testcase_params.
+ testcase_name = testcase_params[_NAMED_DICT_KEY]
+ testcase_params = {
+ k: v for k, v in testcase_params.items() if k != _NAMED_DICT_KEY
+ }
+ elif _non_string_or_bytes_iterable(testcase_params):
+ if not isinstance(testcase_params[0], str):
+ raise RuntimeError(
+ 'The first element of named test parameters is the test name '
+ 'suffix and must be a string')
+ testcase_name = testcase_params[0]
+ testcase_params = testcase_params[1:]
+ else:
+ raise RuntimeError(
+ 'Named tests must be passed a dict or non-string iterable.')
+
+ test_method_name = self._original_name
+ # Support PEP-8 underscore style for test naming if used.
+ if (test_method_name.startswith('test_')
+ and testcase_name
+ and not testcase_name.startswith('_')):
+ test_method_name += '_'
+
+ bound_param_test.__name__ = test_method_name + str(testcase_name)
+ elif naming_type is _ARGUMENT_REPR:
+ # If it's a generator, convert it to a tuple and treat them as
+ # parameters.
+ if isinstance(testcase_params, types.GeneratorType):
+ testcase_params = tuple(testcase_params)
+ # The metaclass creates a unique, but non-descriptive method name for
+ # _ARGUMENT_REPR tests using an indexed suffix.
+ # To keep test names descriptive, only the original method name is used.
+ # To make sure test names are unique, we add a unique descriptive suffix
+ # __x_params_repr__ for every test.
+ params_repr = '(%s)' % (_format_parameter_list(testcase_params),)
+ bound_param_test.__x_params_repr__ = params_repr
+ else:
+ raise RuntimeError('%s is not a valid naming type.' % (naming_type,))
+
+ bound_param_test.__doc__ = '%s(%s)' % (
+ bound_param_test.__name__, _format_parameter_list(testcase_params))
+ if test_method.__doc__:
+ bound_param_test.__doc__ += '\n%s' % (test_method.__doc__,)
+ if inspect.iscoroutinefunction(test_method):
+ return _async_wrapped(bound_param_test)
+ return bound_param_test
+
+ return (make_bound_param_test(c) for c in self.testcases)
+
+
+def _modify_class(class_object, testcases, naming_type):
+ assert not getattr(class_object, '_test_params_reprs', None), (
+ 'Cannot add parameters to %s. Either it already has parameterized '
+ 'methods, or its super class is also a parameterized class.' % (
+ class_object,))
+ # NOTE: _test_params_repr is private to parameterized.TestCase and it's
+ # metaclass; do not use it outside of those classes.
+ class_object._test_params_reprs = test_params_reprs = {}
+ for name, obj in class_object.__dict__.copy().items():
+ if (name.startswith(unittest.TestLoader.testMethodPrefix)
+ and isinstance(obj, types.FunctionType)):
+ delattr(class_object, name)
+ methods = {}
+ _update_class_dict_for_param_test_case(
+ class_object.__name__, methods, test_params_reprs, name,
+ _ParameterizedTestIter(obj, testcases, naming_type, name))
+ for meth_name, meth in methods.items():
+ setattr(class_object, meth_name, meth)
+
+
+def _parameter_decorator(naming_type, testcases):
+ """Implementation of the parameterization decorators.
+
+ Args:
+ naming_type: The naming type.
+ testcases: Testcase parameters.
+
+ Raises:
+ NoTestsError: Raised when the decorator generates no tests.
+
+ Returns:
+ A function for modifying the decorated object.
+ """
+ def _apply(obj):
+ if isinstance(obj, type):
+ _modify_class(obj, testcases, naming_type)
+ return obj
+ else:
+ return _ParameterizedTestIter(obj, testcases, naming_type)
+
+ if (len(testcases) == 1 and
+ not isinstance(testcases[0], tuple) and
+ not isinstance(testcases[0], abc.Mapping)):
+ # Support using a single non-tuple parameter as a list of test cases.
+ # Note that the single non-tuple parameter can't be Mapping either, which
+ # means a single dict parameter case.
+ assert _non_string_or_bytes_iterable(testcases[0]), (
+ 'Single parameter argument must be a non-string non-Mapping iterable')
+ testcases = testcases[0]
+
+ if not isinstance(testcases, abc.Sequence):
+ testcases = list(testcases)
+ if not testcases:
+ raise NoTestsError(
+ 'parameterized test decorators did not generate any tests. '
+ 'Make sure you specify non-empty parameters, '
+ 'and do not reuse generators more than once.')
+
+ return _apply
+
+
+def parameters(*testcases):
+ """A decorator for creating parameterized tests.
+
+ See the module docstring for a usage example.
+
+ Args:
+ *testcases: Parameters for the decorated method, either a single
+ iterable, or a list of tuples/dicts/objects (for tests with only one
+ argument).
+
+ Raises:
+ NoTestsError: Raised when the decorator generates no tests.
+
+ Returns:
+ A test generator to be handled by TestGeneratorMetaclass.
+ """
+ return _parameter_decorator(_ARGUMENT_REPR, testcases)
+
+
+def named_parameters(*testcases):
+ """A decorator for creating parameterized tests.
+
+ See the module docstring for a usage example. For every parameter tuple
+ passed, the first element of the tuple should be a string and will be appended
+ to the name of the test method. Each parameter dict passed must have a value
+ for the key "testcase_name", the string representation of that value will be
+ appended to the name of the test method.
+
+ Args:
+ *testcases: Parameters for the decorated method, either a single iterable,
+ or a list of tuples or dicts.
+
+ Raises:
+ NoTestsError: Raised when the decorator generates no tests.
+
+ Returns:
+ A test generator to be handled by TestGeneratorMetaclass.
+ """
+ return _parameter_decorator(_NAMED, testcases)
+
+
+def product(*kwargs_seqs, **testgrid):
+ """A decorator for running tests over cartesian product of parameters values.
+
+ See the module docstring for a usage example. The test will be run for every
+ possible combination of the parameters.
+
+ Args:
+ *kwargs_seqs: Each positional parameter is a sequence of keyword arg dicts;
+ every test case generated will include exactly one kwargs dict from each
+ positional parameter; these will then be merged to form an overall list
+ of arguments for the test case.
+ **testgrid: A mapping of parameter names and their possible values. Possible
+ values should given as either a list or a tuple.
+
+ Raises:
+ NoTestsError: Raised when the decorator generates no tests.
+
+ Returns:
+ A test generator to be handled by TestGeneratorMetaclass.
+ """
+
+ for name, values in testgrid.items():
+ assert isinstance(values, (list, tuple)), (
+ 'Values of {} must be given as list or tuple, found {}'.format(
+ name, type(values)))
+
+ prior_arg_names = set()
+ for kwargs_seq in kwargs_seqs:
+ assert ((isinstance(kwargs_seq, (list, tuple))) and
+ all(isinstance(kwargs, dict) for kwargs in kwargs_seq)), (
+ 'Positional parameters must be a sequence of keyword arg'
+ 'dicts, found {}'
+ .format(kwargs_seq))
+ if kwargs_seq:
+ arg_names = set(kwargs_seq[0])
+ assert all(set(kwargs) == arg_names for kwargs in kwargs_seq), (
+ 'Keyword argument dicts within a single parameter must all have the '
+ 'same keys, found {}'.format(kwargs_seq))
+ assert not (arg_names & prior_arg_names), (
+ 'Keyword argument dict sequences must all have distinct argument '
+ 'names, found duplicate(s) {}'
+ .format(sorted(arg_names & prior_arg_names)))
+ prior_arg_names |= arg_names
+
+ assert not (prior_arg_names & set(testgrid)), (
+ 'Arguments supplied in kwargs dicts in positional parameters must not '
+ 'overlap with arguments supplied as named parameters; found duplicate '
+ 'argument(s) {}'.format(sorted(prior_arg_names & set(testgrid))))
+
+ # Convert testgrid into a sequence of sequences of kwargs dicts and combine
+ # with the positional parameters.
+ # So foo=[1,2], bar=[3,4] --> [[{foo: 1}, {foo: 2}], [{bar: 3, bar: 4}]]
+ testgrid = (tuple({k: v} for v in vs) for k, vs in testgrid.items())
+ testgrid = tuple(kwargs_seqs) + tuple(testgrid)
+
+ # Create all possible combinations of parameters as a cartesian product
+ # of parameter values.
+ testcases = [
+ dict(itertools.chain.from_iterable(case.items()
+ for case in cases))
+ for cases in itertools.product(*testgrid)
+ ]
+ return _parameter_decorator(_ARGUMENT_REPR, testcases)
+
+
+class TestGeneratorMetaclass(type):
+ """Metaclass for adding tests generated by parameterized decorators."""
+
+ def __new__(cls, class_name, bases, dct):
+ # NOTE: _test_params_repr is private to parameterized.TestCase and it's
+ # metaclass; do not use it outside of those classes.
+ test_params_reprs = dct.setdefault('_test_params_reprs', {})
+ for name, obj in dct.copy().items():
+ if (name.startswith(unittest.TestLoader.testMethodPrefix) and
+ _non_string_or_bytes_iterable(obj)):
+ # NOTE: `obj` might not be a _ParameterizedTestIter in two cases:
+ # 1. a class-level iterable named test* that isn't a test, such as
+ # a list of something. Such attributes get deleted from the class.
+ #
+ # 2. If a decorator is applied to the parameterized test, e.g.
+ # @morestuff
+ # @parameterized.parameters(...)
+ # def test_foo(...): ...
+ #
+ # This is OK so long as the underlying parameterized function state
+ # is forwarded (e.g. using functool.wraps() and **without**
+ # accessing explicitly accessing the internal attributes.
+ if isinstance(obj, _ParameterizedTestIter):
+ # Update the original test method name so it's more accurate.
+ # The mismatch might happen when another decorator is used inside
+ # the parameterized decrators, and the inner decorator doesn't
+ # preserve its __name__.
+ obj._original_name = name
+ iterator = iter(obj)
+ dct.pop(name)
+ _update_class_dict_for_param_test_case(
+ class_name, dct, test_params_reprs, name, iterator)
+ # If the base class is a subclass of parameterized.TestCase, inherit its
+ # _test_params_reprs too.
+ for base in bases:
+ # Check if the base has _test_params_reprs first, then check if it's a
+ # subclass of parameterized.TestCase. Otherwise when this is called for
+ # the parameterized.TestCase definition itself, this raises because
+ # itself is not defined yet. This works as long as absltest.TestCase does
+ # not define _test_params_reprs.
+ base_test_params_reprs = getattr(base, '_test_params_reprs', None)
+ if base_test_params_reprs and issubclass(base, TestCase):
+ for test_method, test_method_id in base_test_params_reprs.items():
+ # test_method may both exists in base and this class.
+ # This class's method overrides base class's.
+ # That's why it should only inherit it if it does not exist.
+ test_params_reprs.setdefault(test_method, test_method_id)
+
+ return type.__new__(cls, class_name, bases, dct)
+
+
+def _update_class_dict_for_param_test_case(
+ test_class_name, dct, test_params_reprs, name, iterator):
+ """Adds individual test cases to a dictionary.
+
+ Args:
+ test_class_name: The name of the class tests are added to.
+ dct: The target dictionary.
+ test_params_reprs: The dictionary for mapping names to test IDs.
+ name: The original name of the test case.
+ iterator: The iterator generating the individual test cases.
+
+ Raises:
+ DuplicateTestNameError: Raised when a test name occurs multiple times.
+ RuntimeError: If non-parameterized functions are generated.
+ """
+ for idx, func in enumerate(iterator):
+ assert callable(func), 'Test generators must yield callables, got %r' % (
+ func,)
+ if not (getattr(func, '__x_use_name__', None) or
+ getattr(func, '__x_params_repr__', None)):
+ raise RuntimeError(
+ '{}.{} generated a test function without using the parameterized '
+ 'decorators. Only tests generated using the decorators are '
+ 'supported.'.format(test_class_name, name))
+
+ if getattr(func, '__x_use_name__', False):
+ original_name = func.__name__
+ new_name = original_name
+ else:
+ original_name = name
+ new_name = '%s%d' % (original_name, idx)
+
+ if new_name in dct:
+ raise DuplicateTestNameError(test_class_name, new_name, original_name)
+
+ dct[new_name] = func
+ test_params_reprs[new_name] = getattr(func, '__x_params_repr__', '')
+
+
+class TestCase(absltest.TestCase, metaclass=TestGeneratorMetaclass):
+ """Base class for test cases using the parameters decorator."""
+
+ # visibility: private; do not call outside this class.
+ def _get_params_repr(self):
+ return self._test_params_reprs.get(self._testMethodName, '')
+
+ def __str__(self):
+ params_repr = self._get_params_repr()
+ if params_repr:
+ params_repr = ' ' + params_repr
+ return '{}{} ({})'.format(
+ self._testMethodName, params_repr,
+ unittest.util.strclass(self.__class__))
+
+ def id(self):
+ """Returns the descriptive ID of the test.
+
+ This is used internally by the unittesting framework to get a name
+ for the test to be used in reports.
+
+ Returns:
+ The test id.
+ """
+ base = super(TestCase, self).id()
+ params_repr = self._get_params_repr()
+ if params_repr:
+ # We include the params in the id so that, when reported in the
+ # test.xml file, the value is more informative than just "test_foo0".
+ # Use a space to separate them so that it's copy/paste friendly and
+ # easy to identify the actual test id.
+ return '{} {}'.format(base, params_repr)
+ else:
+ return base
+
+
+# This function is kept CamelCase because it's used as a class's base class.
+def CoopTestCase(other_base_class): # pylint: disable=invalid-name
+ """Returns a new base class with a cooperative metaclass base.
+
+ This enables the TestCase to be used in combination
+ with other base classes that have custom metaclasses, such as
+ mox.MoxTestBase.
+
+ Only works with metaclasses that do not override type.__new__.
+
+ Example:
+
+ from absl.testing import parameterized
+
+ class ExampleTest(parameterized.CoopTestCase(OtherTestCase)):
+ ...
+
+ Args:
+ other_base_class: (class) A test case base class.
+
+ Returns:
+ A new class object.
+ """
+ metaclass = type(
+ 'CoopMetaclass',
+ (other_base_class.__metaclass__,
+ TestGeneratorMetaclass), {})
+ return metaclass(
+ 'CoopTestCase',
+ (other_base_class, TestCase), {})
diff --git a/absl/testing/tests/__init__.py b/absl/testing/tests/__init__.py
new file mode 100644
index 0000000..a3bd1cd
--- /dev/null
+++ b/absl/testing/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/absl/testing/tests/absltest_env.py b/absl/testing/tests/absltest_env.py
new file mode 100644
index 0000000..6c06d62
--- /dev/null
+++ b/absl/testing/tests/absltest_env.py
@@ -0,0 +1,30 @@
+"""Helper library to get environment variables for absltest helper binaries."""
+
+import os
+
+
+_INHERITED_ENV_KEYS = frozenset({
+ # This is needed to correctly use the Python interpreter determined by
+ # bazel.
+ 'PATH',
+ # This is used by the random module on Windows to locate crypto
+ # libraries.
+ 'SYSTEMROOT',
+})
+
+
+def inherited_env():
+ """Returns the environment variables that should be inherited from parent.
+
+ Reason why using an explicit list of environment variables instead of
+ inheriting all from parent: the absltest module itself interprets a list of
+ environment variables set by bazel, e.g. XML_OUTPUT_FILE,
+ TESTBRIDGE_TEST_ONLY. While testing absltest's own behavior, we should
+ remove them when invoking the helper subprocess. Using an explicit list is
+ safer.
+ """
+ env = {}
+ for key in _INHERITED_ENV_KEYS:
+ if key in os.environ:
+ env[key] = os.environ[key]
+ return env
diff --git a/absl/testing/tests/absltest_fail_fast_test.py b/absl/testing/tests/absltest_fail_fast_test.py
new file mode 100644
index 0000000..dc967f9
--- /dev/null
+++ b/absl/testing/tests/absltest_fail_fast_test.py
@@ -0,0 +1,109 @@
+# Copyright 2020 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for test fail fast protocol."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import subprocess
+
+from absl import logging
+from absl.testing import _bazelize_command
+from absl.testing import absltest
+from absl.testing import parameterized
+from absl.testing.tests import absltest_env
+
+
+@parameterized.named_parameters(
+ ('use_app_run', True),
+ ('no_argv', False),
+)
+class TestFailFastTest(parameterized.TestCase):
+ """Integration tests: Runs a test binary with fail fast.
+
+ This is done by setting the fail fast environment variable
+ """
+
+ def setUp(self):
+ super().setUp()
+ self._test_name = 'absl/testing/tests/absltest_fail_fast_test_helper'
+
+ def _run_fail_fast(self, fail_fast, use_app_run):
+ """Runs the py_test binary in a subprocess.
+
+ Args:
+ fail_fast: string, the fail fast value.
+ use_app_run: bool, whether the test helper should call
+ `absltest.main(argv=)` inside `app.run`.
+
+ Returns:
+ (stdout, exit_code) tuple of (string, int).
+ """
+ env = absltest_env.inherited_env()
+ if fail_fast is not None:
+ env['TESTBRIDGE_TEST_RUNNER_FAIL_FAST'] = fail_fast
+ env['USE_APP_RUN'] = '1' if use_app_run else '0'
+
+ proc = subprocess.Popen(
+ args=[_bazelize_command.get_executable_path(self._test_name)],
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ universal_newlines=True)
+ stdout = proc.communicate()[0]
+
+ logging.info('output: %s', stdout)
+ return stdout, proc.wait()
+
+ def test_no_fail_fast(self, use_app_run):
+ out, exit_code = self._run_fail_fast(None, use_app_run)
+ self.assertEqual(1, exit_code)
+ self.assertIn('class A test A', out)
+ self.assertIn('class A test B', out)
+ self.assertIn('class A test C', out)
+ self.assertIn('class A test D', out)
+ self.assertIn('class A test E', out)
+
+ def test_empty_fail_fast(self, use_app_run):
+ out, exit_code = self._run_fail_fast('', use_app_run)
+ self.assertEqual(1, exit_code)
+ self.assertIn('class A test A', out)
+ self.assertIn('class A test B', out)
+ self.assertIn('class A test C', out)
+ self.assertIn('class A test D', out)
+ self.assertIn('class A test E', out)
+
+ def test_fail_fast_1(self, use_app_run):
+ out, exit_code = self._run_fail_fast('1', use_app_run)
+ self.assertEqual(1, exit_code)
+ self.assertIn('class A test A', out)
+ self.assertIn('class A test B', out)
+ self.assertIn('class A test C', out)
+ self.assertNotIn('class A test D', out)
+ self.assertNotIn('class A test E', out)
+
+ def test_fail_fast_0(self, use_app_run):
+ out, exit_code = self._run_fail_fast('0', use_app_run)
+ self.assertEqual(1, exit_code)
+ self.assertIn('class A test A', out)
+ self.assertIn('class A test B', out)
+ self.assertIn('class A test C', out)
+ self.assertIn('class A test D', out)
+ self.assertIn('class A test E', out)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/absltest_fail_fast_test_helper.py b/absl/testing/tests/absltest_fail_fast_test_helper.py
new file mode 100644
index 0000000..339a569
--- /dev/null
+++ b/absl/testing/tests/absltest_fail_fast_test_helper.py
@@ -0,0 +1,56 @@
+# Copyright 2020 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A helper test program for absltest_fail_fast_test."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+from absl import app
+from absl.testing import absltest
+
+
+class ClassA(absltest.TestCase):
+ """Helper test case A for absltest_fail_fast_test."""
+
+ def testA(self):
+ sys.stderr.write('\nclass A test A\n')
+
+ def testB(self):
+ sys.stderr.write('\nclass A test B\n')
+
+ def testC(self):
+ sys.stderr.write('\nclass A test C\n')
+ self.fail('Force failure')
+
+ def testD(self):
+ sys.stderr.write('\nclass A test D\n')
+
+ def testE(self):
+ sys.stderr.write('\nclass A test E\n')
+
+
+def main(argv):
+ absltest.main(argv=argv)
+
+
+if __name__ == '__main__':
+ if os.environ['USE_APP_RUN'] == '1':
+ app.run(main)
+ else:
+ absltest.main()
diff --git a/absl/testing/tests/absltest_filtering_test.py b/absl/testing/tests/absltest_filtering_test.py
new file mode 100644
index 0000000..30a81f6
--- /dev/null
+++ b/absl/testing/tests/absltest_filtering_test.py
@@ -0,0 +1,192 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for test filtering protocol."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import subprocess
+import sys
+
+from absl import logging
+from absl.testing import _bazelize_command
+from absl.testing import absltest
+from absl.testing import parameterized
+from absl.testing.tests import absltest_env
+
+
+@parameterized.named_parameters(
+ ('as_env_variable_use_app_run', True, True),
+ ('as_env_variable_no_argv', True, False),
+ ('as_commandline_args_use_app_run', False, True),
+ ('as_commandline_args_no_argv', False, False),
+)
+class TestFilteringTest(absltest.TestCase):
+ """Integration tests: Runs a test binary with filtering.
+
+ This is done by either setting the filtering environment variable, or passing
+ the filters as command line arguments.
+ """
+
+ def setUp(self):
+ super().setUp()
+ self._test_name = 'absl/testing/tests/absltest_filtering_test_helper'
+
+ def _run_filtered(self, test_filter, use_env_variable, use_app_run):
+ """Runs the py_test binary in a subprocess.
+
+ Args:
+ test_filter: string, the filter argument to use.
+ use_env_variable: bool, pass the test filter as environment variable if
+ True, otherwise pass as command line arguments.
+ use_app_run: bool, whether the test helper should call
+ `absltest.main(argv=)` inside `app.run`.
+
+ Returns:
+ (stdout, exit_code) tuple of (string, int).
+ """
+ env = absltest_env.inherited_env()
+ env['USE_APP_RUN'] = '1' if use_app_run else '0'
+ additional_args = []
+ if test_filter is not None:
+ if use_env_variable:
+ env['TESTBRIDGE_TEST_ONLY'] = test_filter
+ elif test_filter:
+ if sys.version_info[:2] >= (3, 7):
+ # The -k flags are passed as positional arguments to absl.flags.
+ additional_args.append('--')
+ additional_args.extend(['-k=' + f for f in test_filter.split(' ')])
+ else:
+ additional_args.extend(test_filter.split(' '))
+
+ proc = subprocess.Popen(
+ args=([_bazelize_command.get_executable_path(self._test_name)] +
+ additional_args),
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ universal_newlines=True)
+ stdout = proc.communicate()[0]
+
+ logging.info('output: %s', stdout)
+ return stdout, proc.wait()
+
+ def test_no_filter(self, use_env_variable, use_app_run):
+ out, exit_code = self._run_filtered(None, use_env_variable, use_app_run)
+ self.assertEqual(1, exit_code)
+ self.assertIn('class B test E', out)
+
+ def test_empty_filter(self, use_env_variable, use_app_run):
+ out, exit_code = self._run_filtered('', use_env_variable, use_app_run)
+ self.assertEqual(1, exit_code)
+ self.assertIn('class B test E', out)
+
+ def test_class_filter(self, use_env_variable, use_app_run):
+ out, exit_code = self._run_filtered('ClassA', use_env_variable, use_app_run)
+ self.assertEqual(0, exit_code)
+ self.assertNotIn('class B', out)
+
+ def test_method_filter(self, use_env_variable, use_app_run):
+ out, exit_code = self._run_filtered('ClassB.testA', use_env_variable,
+ use_app_run)
+ self.assertEqual(0, exit_code)
+ self.assertNotIn('class A', out)
+ self.assertNotIn('class B test B', out)
+
+ out, exit_code = self._run_filtered('ClassB.testE', use_env_variable,
+ use_app_run)
+ self.assertEqual(1, exit_code)
+ self.assertNotIn('class A', out)
+
+ def test_multiple_class_and_method_filter(self, use_env_variable,
+ use_app_run):
+ out, exit_code = self._run_filtered(
+ 'ClassA.testA ClassA.testB ClassB.testC', use_env_variable, use_app_run)
+ self.assertEqual(0, exit_code)
+ self.assertIn('class A test A', out)
+ self.assertIn('class A test B', out)
+ self.assertNotIn('class A test C', out)
+ self.assertIn('class B test C', out)
+ self.assertNotIn('class B test A', out)
+
+ @absltest.skipIf(
+ sys.version_info[:2] < (3, 7),
+ 'Only Python 3.7+ does glob and substring matching.')
+ def test_substring(self, use_env_variable, use_app_run):
+ out, exit_code = self._run_filtered(
+ 'testA', use_env_variable, use_app_run)
+ self.assertEqual(0, exit_code)
+ self.assertIn('Ran 2 tests', out)
+ self.assertIn('ClassA.testA', out)
+ self.assertIn('ClassB.testA', out)
+
+ @absltest.skipIf(
+ sys.version_info[:2] < (3, 7),
+ 'Only Python 3.7+ does glob and substring matching.')
+ def test_glob_pattern(self, use_env_variable, use_app_run):
+ out, exit_code = self._run_filtered(
+ '__main__.Class*.testA', use_env_variable, use_app_run)
+ self.assertEqual(0, exit_code)
+ self.assertIn('Ran 2 tests', out)
+ self.assertIn('ClassA.testA', out)
+ self.assertIn('ClassB.testA', out)
+
+ @absltest.skipIf(
+ sys.version_info[:2] >= (3, 7),
+ "Python 3.7+ uses unittest's -k flag and doesn't fail if no tests match.")
+ def test_not_found_filters_py36(self, use_env_variable, use_app_run):
+ out, exit_code = self._run_filtered('NotExistedClass.not_existed_method',
+ use_env_variable, use_app_run)
+ self.assertEqual(1, exit_code)
+ self.assertIn("has no attribute 'NotExistedClass'", out)
+
+ @absltest.skipIf(
+ sys.version_info[:2] < (3, 7),
+ 'Python 3.6 passes the filter as positional arguments and fails if no '
+ 'tests match.'
+ )
+ def test_not_found_filters_py37(self, use_env_variable, use_app_run):
+ out, exit_code = self._run_filtered('NotExistedClass.not_existed_method',
+ use_env_variable, use_app_run)
+ self.assertEqual(0, exit_code)
+ self.assertIn('Ran 0 tests', out)
+
+ @absltest.skipIf(
+ sys.version_info[:2] < (3, 7),
+ 'Python 3.6 passes the filter as positional arguments and matches by name'
+ )
+ def test_parameterized_unnamed(self, use_env_variable, use_app_run):
+ out, exit_code = self._run_filtered('ParameterizedTest.test_unnamed',
+ use_env_variable, use_app_run)
+ self.assertEqual(0, exit_code)
+ self.assertIn('Ran 2 tests', out)
+ self.assertIn('parameterized unnamed 1', out)
+ self.assertIn('parameterized unnamed 2', out)
+
+ @absltest.skipIf(
+ sys.version_info[:2] < (3, 7),
+ 'Python 3.6 passes the filter as positional arguments and matches by name'
+ )
+ def test_parameterized_named(self, use_env_variable, use_app_run):
+ out, exit_code = self._run_filtered('ParameterizedTest.test_named',
+ use_env_variable, use_app_run)
+ self.assertEqual(0, exit_code)
+ self.assertIn('Ran 2 tests', out)
+ self.assertIn('parameterized named 1', out)
+ self.assertIn('parameterized named 2', out)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/absltest_filtering_test_helper.py b/absl/testing/tests/absltest_filtering_test_helper.py
new file mode 100644
index 0000000..2b741ed
--- /dev/null
+++ b/absl/testing/tests/absltest_filtering_test_helper.py
@@ -0,0 +1,85 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A helper test program for absltest_filtering_test."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+from absl import app
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class ClassA(absltest.TestCase):
+ """Helper test case A for absltest_filtering_test."""
+
+ def testA(self):
+ sys.stderr.write('\nclass A test A\n')
+
+ def testB(self):
+ sys.stderr.write('\nclass A test B\n')
+
+ def testC(self):
+ sys.stderr.write('\nclass A test C\n')
+
+
+class ClassB(absltest.TestCase):
+ """Helper test case B for absltest_filtering_test."""
+
+ def testA(self):
+ sys.stderr.write('\nclass B test A\n')
+
+ def testB(self):
+ sys.stderr.write('\nclass B test B\n')
+
+ def testC(self):
+ sys.stderr.write('\nclass B test C\n')
+
+ def testD(self):
+ sys.stderr.write('\nclass B test D\n')
+
+ def testE(self):
+ sys.stderr.write('\nclass B test E\n')
+ self.fail('Force failure')
+
+
+class ParameterizedTest(parameterized.TestCase):
+ """Helper parameterized test case for absltest_filtering_test."""
+
+ @parameterized.parameters([1, 2])
+ def test_unnamed(self, value):
+ sys.stderr.write('\nparameterized unnamed %s' % value)
+
+ @parameterized.named_parameters(
+ ('test1', 1),
+ ('test2', 2),
+ )
+ def test_named(self, value):
+ sys.stderr.write('\nparameterized named %s' % value)
+
+
+def main(argv):
+ absltest.main(argv=argv)
+
+
+if __name__ == '__main__':
+ if os.environ['USE_APP_RUN'] == '1':
+ app.run(main)
+ else:
+ absltest.main()
diff --git a/absl/testing/tests/absltest_py3_test.py b/absl/testing/tests/absltest_py3_test.py
new file mode 100644
index 0000000..7c5f500
--- /dev/null
+++ b/absl/testing/tests/absltest_py3_test.py
@@ -0,0 +1,44 @@
+# Copyright 2020 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Python3-only Tests for absltest."""
+
+from absl.testing import absltest
+
+
+class GetTestCaseNamesPEP3102Test(absltest.TestCase):
+ """This test verifies absltest.TestLoader.GetTestCasesNames PEP3102 support.
+
+ The test is Python3 only, as keyword only arguments are considered
+ syntax error in Python2.
+
+ The rest of getTestCaseNames functionality is covered
+ by absltest_test.TestLoaderTest.
+ """
+
+ class Valid(absltest.TestCase):
+
+ def testKeywordOnly(self, *, arg):
+ pass
+
+ def setUp(self):
+ self.loader = absltest.TestLoader()
+ super(GetTestCaseNamesPEP3102Test, self).setUp()
+
+ def test_PEP3102_get_test_case_names(self):
+ self.assertCountEqual(
+ self.loader.getTestCaseNames(GetTestCaseNamesPEP3102Test.Valid),
+ ["testKeywordOnly"])
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/absl/testing/tests/absltest_randomization_test.py b/absl/testing/tests/absltest_randomization_test.py
new file mode 100644
index 0000000..75a3868
--- /dev/null
+++ b/absl/testing/tests/absltest_randomization_test.py
@@ -0,0 +1,154 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for test randomization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import subprocess
+
+from absl import flags
+from absl.testing import _bazelize_command
+from absl.testing import absltest
+from absl.testing import parameterized
+from absl.testing.tests import absltest_env
+
+FLAGS = flags.FLAGS
+
+
+class TestOrderRandomizationTest(parameterized.TestCase):
+ """Integration tests: Runs a py_test binary with randomization.
+
+ This is done by setting flags and environment variables.
+ """
+
+ def setUp(self):
+ super(TestOrderRandomizationTest, self).setUp()
+ self._test_name = 'absl/testing/tests/absltest_randomization_testcase'
+
+ def _run_test(self, extra_argv, extra_env):
+ """Runs the py_test binary in a subprocess, with the given args or env.
+
+ Args:
+ extra_argv: extra args to pass to the test
+ extra_env: extra env vars to set when running the test
+
+ Returns:
+ (stdout, test_cases, exit_code) tuple of (str, list of strs, int).
+ """
+ env = absltest_env.inherited_env()
+ # If *this* test is being run with this flag, we don't want to
+ # automatically set it for all tests we run.
+ env.pop('TEST_RANDOMIZE_ORDERING_SEED', '')
+ if extra_env is not None:
+ env.update(extra_env)
+
+ command = (
+ [_bazelize_command.get_executable_path(self._test_name)] + extra_argv)
+ proc = subprocess.Popen(
+ args=command,
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ universal_newlines=True)
+
+ stdout, _ = proc.communicate()
+
+ test_lines = [l for l in stdout.splitlines() if l.startswith('class ')]
+ return stdout, test_lines, proc.wait()
+
+ def test_no_args(self):
+ output, tests, exit_code = self._run_test([], None)
+ self.assertEqual(0, exit_code, msg='command output: ' + output)
+ self.assertNotIn('Randomizing test order with seed:', output)
+ cases = ['class A test ' + t for t in ('A', 'B', 'C')]
+ self.assertEqual(cases, tests)
+
+ @parameterized.parameters(
+ {
+ 'argv': ['--test_randomize_ordering_seed=random'],
+ 'env': None,
+ },
+ {
+ 'argv': [],
+ 'env': {
+ 'TEST_RANDOMIZE_ORDERING_SEED': 'random',
+ },
+ },)
+ def test_simple_randomization(self, argv, env):
+ output, tests, exit_code = self._run_test(argv, env)
+ self.assertEqual(0, exit_code, msg='command output: ' + output)
+ self.assertIn('Randomizing test order with seed: ', output)
+ cases = ['class A test ' + t for t in ('A', 'B', 'C')]
+ # This may come back in any order; we just know it'll be the same
+ # set of elements.
+ self.assertSameElements(cases, tests)
+
+ @parameterized.parameters(
+ {
+ 'argv': ['--test_randomize_ordering_seed=1'],
+ 'env': None,
+ },
+ {
+ 'argv': [],
+ 'env': {
+ 'TEST_RANDOMIZE_ORDERING_SEED': '1'
+ },
+ },
+ {
+ 'argv': [],
+ 'env': {
+ 'LATE_SET_TEST_RANDOMIZE_ORDERING_SEED': '1'
+ },
+ },
+ )
+ def test_fixed_seed(self, argv, env):
+ output, tests, exit_code = self._run_test(argv, env)
+ self.assertEqual(0, exit_code, msg='command output: ' + output)
+ self.assertIn('Randomizing test order with seed: 1', output)
+ # Even though we know the seed, we need to shuffle the tests here, since
+ # this behaves differently in Python2 vs Python3.
+ shuffled_cases = ['A', 'B', 'C']
+ random.Random(1).shuffle(shuffled_cases)
+ cases = ['class A test ' + t for t in shuffled_cases]
+ # We know what order this will come back for the random seed we've
+ # specified.
+ self.assertEqual(cases, tests)
+
+ @parameterized.parameters(
+ {
+ 'argv': ['--test_randomize_ordering_seed=0'],
+ 'env': {
+ 'TEST_RANDOMIZE_ORDERING_SEED': 'random'
+ },
+ },
+ {
+ 'argv': [],
+ 'env': {
+ 'TEST_RANDOMIZE_ORDERING_SEED': '0'
+ },
+ },)
+ def test_disabling_randomization(self, argv, env):
+ output, tests, exit_code = self._run_test(argv, env)
+ self.assertEqual(0, exit_code, msg='command output: ' + output)
+ self.assertNotIn('Randomizing test order with seed:', output)
+ cases = ['class A test ' + t for t in ('A', 'B', 'C')]
+ self.assertEqual(cases, tests)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/absltest_randomization_testcase.py b/absl/testing/tests/absltest_randomization_testcase.py
new file mode 100644
index 0000000..18b20ff
--- /dev/null
+++ b/absl/testing/tests/absltest_randomization_testcase.py
@@ -0,0 +1,47 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Stub tests, only for use in absltest_randomization_test.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+from absl.testing import absltest
+
+
+# This stanza exercises setting $TEST_RANDOMIZE_ORDERING_SEED *after* importing
+# the absltest library.
+if os.environ.get('LATE_SET_TEST_RANDOMIZE_ORDERING_SEED', ''):
+ os.environ['TEST_RANDOMIZE_ORDERING_SEED'] = os.environ[
+ 'LATE_SET_TEST_RANDOMIZE_ORDERING_SEED']
+
+
+class ClassA(absltest.TestCase):
+
+ def test_a(self):
+ sys.stderr.write('\nclass A test A\n')
+
+ def test_b(self):
+ sys.stderr.write('\nclass A test B\n')
+
+ def test_c(self):
+ sys.stderr.write('\nclass A test C\n')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/absltest_sharding_test.py b/absl/testing/tests/absltest_sharding_test.py
new file mode 100644
index 0000000..6411971
--- /dev/null
+++ b/absl/testing/tests/absltest_sharding_test.py
@@ -0,0 +1,165 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for test sharding protocol."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import subprocess
+
+from absl.testing import _bazelize_command
+from absl.testing import absltest
+from absl.testing.tests import absltest_env
+
+
+NUM_TEST_METHODS = 8 # Hard-coded, based on absltest_sharding_test_helper.py
+
+
+class TestShardingTest(absltest.TestCase):
+ """Integration tests: Runs a test binary with sharding.
+
+ This is done by setting the sharding environment variables.
+ """
+
+ def setUp(self):
+ super().setUp()
+ self._test_name = 'absl/testing/tests/absltest_sharding_test_helper'
+ self._shard_file = None
+
+ def tearDown(self):
+ super().tearDown()
+ if self._shard_file is not None and os.path.exists(self._shard_file):
+ os.unlink(self._shard_file)
+
+ def _run_sharded(self,
+ total_shards,
+ shard_index,
+ shard_file=None,
+ additional_env=None):
+ """Runs the py_test binary in a subprocess.
+
+ Args:
+ total_shards: int, the total number of shards.
+ shard_index: int, the shard index.
+ shard_file: string, if not 'None', the path to the shard file.
+ This method asserts it is properly created.
+ additional_env: Additional environment variables to be set for the py_test
+ binary.
+
+ Returns:
+ (stdout, exit_code) tuple of (string, int).
+ """
+ env = absltest_env.inherited_env()
+ if additional_env:
+ env.update(additional_env)
+ env.update({
+ 'TEST_TOTAL_SHARDS': str(total_shards),
+ 'TEST_SHARD_INDEX': str(shard_index)
+ })
+ if shard_file:
+ self._shard_file = shard_file
+ env['TEST_SHARD_STATUS_FILE'] = shard_file
+ if os.path.exists(shard_file):
+ os.unlink(shard_file)
+
+ proc = subprocess.Popen(
+ args=[_bazelize_command.get_executable_path(self._test_name)],
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ universal_newlines=True)
+ stdout = proc.communicate()[0]
+
+ if shard_file:
+ self.assertTrue(os.path.exists(shard_file))
+
+ return (stdout, proc.wait())
+
+ def _assert_sharding_correctness(self, total_shards):
+ """Assert the primary correctness and performance of sharding.
+
+ 1. Completeness (all methods are run)
+ 2. Partition (each method run at most once)
+ 3. Balance (for performance)
+
+ Args:
+ total_shards: int, total number of shards.
+ """
+
+ outerr_by_shard = [] # A list of lists of strings
+ combined_outerr = [] # A list of strings
+ exit_code_by_shard = [] # A list of ints
+
+ for i in range(total_shards):
+ (out, exit_code) = self._run_sharded(total_shards, i)
+ method_list = [x for x in out.split('\n') if x.startswith('class')]
+ outerr_by_shard.append(method_list)
+ combined_outerr.extend(method_list)
+ exit_code_by_shard.append(exit_code)
+
+ self.assertLen([x for x in exit_code_by_shard if x != 0], 1,
+ 'Expected exactly one failure')
+
+ # Test completeness and partition properties.
+ self.assertLen(combined_outerr, NUM_TEST_METHODS,
+ 'Partition requirement not met')
+ self.assertLen(set(combined_outerr), NUM_TEST_METHODS,
+ 'Completeness requirement not met')
+
+ # Test balance:
+ for i in range(len(outerr_by_shard)):
+ self.assertGreaterEqual(len(outerr_by_shard[i]),
+ (NUM_TEST_METHODS / total_shards) - 1,
+ 'Shard %d of %d out of balance' %
+ (i, len(outerr_by_shard)))
+
+ def test_shard_file(self):
+ self._run_sharded(3, 1, os.path.join(
+ absltest.TEST_TMPDIR.value, 'shard_file'))
+
+ def test_zero_shards(self):
+ out, exit_code = self._run_sharded(0, 0)
+ self.assertEqual(1, exit_code)
+ self.assertGreaterEqual(out.find('Bad sharding values. index=0, total=0'),
+ 0, 'Bad output: %s' % (out))
+
+ def test_with_four_shards(self):
+ self._assert_sharding_correctness(4)
+
+ def test_with_one_shard(self):
+ self._assert_sharding_correctness(1)
+
+ def test_with_ten_shards(self):
+ self._assert_sharding_correctness(10)
+
+ def test_sharding_with_randomization(self):
+ # If we're both sharding *and* randomizing, we need to confirm that we
+ # randomize within the shard; we use two seeds to confirm we're seeing the
+ # same tests (sharding is consistent) in a different order.
+ tests_seen = []
+ for seed in ('7', '17'):
+ out, exit_code = self._run_sharded(
+ 2, 0, additional_env={'TEST_RANDOMIZE_ORDERING_SEED': seed})
+ self.assertEqual(0, exit_code)
+ tests_seen.append([x for x in out.splitlines() if x.startswith('class')])
+ first_tests, second_tests = tests_seen # pylint: disable=unbalanced-tuple-unpacking
+ self.assertEqual(set(first_tests), set(second_tests))
+ self.assertNotEqual(first_tests, second_tests)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/absltest_sharding_test_helper.py b/absl/testing/tests/absltest_sharding_test_helper.py
new file mode 100644
index 0000000..7b2f20e
--- /dev/null
+++ b/absl/testing/tests/absltest_sharding_test_helper.py
@@ -0,0 +1,60 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A helper test program for absltest_sharding_test."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from absl.testing import absltest
+
+
+class ClassA(absltest.TestCase):
+ """Helper test case A for absltest_sharding_test."""
+
+ def testA(self):
+ sys.stderr.write('\nclass A test A\n')
+
+ def testB(self):
+ sys.stderr.write('\nclass A test B\n')
+
+ def testC(self):
+ sys.stderr.write('\nclass A test C\n')
+
+
+class ClassB(absltest.TestCase):
+ """Helper test case B for absltest_sharding_test."""
+
+ def testA(self):
+ sys.stderr.write('\nclass B test A\n')
+
+ def testB(self):
+ sys.stderr.write('\nclass B test B\n')
+
+ def testC(self):
+ sys.stderr.write('\nclass B test C\n')
+
+ def testD(self):
+ sys.stderr.write('\nclass B test D\n')
+
+ def testE(self):
+ sys.stderr.write('\nclass B test E\n')
+ self.fail('Force failure')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py
new file mode 100644
index 0000000..48eeca8
--- /dev/null
+++ b/absl/testing/tests/absltest_test.py
@@ -0,0 +1,2374 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for absltest."""
+
+import collections
+import contextlib
+import io
+import os
+import pathlib
+import re
+import stat
+import string
+import subprocess
+import tempfile
+import unittest
+
+from absl.testing import _bazelize_command
+from absl.testing import absltest
+from absl.testing import parameterized
+from absl.testing.tests import absltest_env
+
+
+class HelperMixin(object):
+
+ def _get_helper_exec_path(self):
+ helper = 'absl/testing/tests/absltest_test_helper'
+ return _bazelize_command.get_executable_path(helper)
+
+ def run_helper(self, test_id, args, env_overrides, expect_success):
+ env = absltest_env.inherited_env()
+ for key, value in env_overrides.items():
+ if value is None:
+ if key in env:
+ del env[key]
+ else:
+ env[key] = value
+
+ command = [self._get_helper_exec_path(),
+ '--test_id={}'.format(test_id)] + args
+ process = subprocess.Popen(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env,
+ universal_newlines=True)
+ stdout, stderr = process.communicate()
+ if expect_success:
+ self.assertEqual(
+ 0, process.returncode,
+ 'Expected success, but failed with '
+ 'stdout:\n{}\nstderr:\n{}\n'.format(stdout, stderr))
+ else:
+ self.assertEqual(
+ 1, process.returncode,
+ 'Expected failure, but succeeded with '
+ 'stdout:\n{}\nstderr:\n{}\n'.format(stdout, stderr))
+ return stdout, stderr
+
+
+class TestCaseTest(absltest.TestCase, HelperMixin):
+ longMessage = True
+
+ def run_helper(self, test_id, args, env_overrides, expect_success):
+ return super(TestCaseTest, self).run_helper(test_id, args + ['HelperTest'],
+ env_overrides, expect_success)
+
+ def test_flags_no_env_var_no_flags(self):
+ self.run_helper(
+ 1,
+ [],
+ {'TEST_RANDOM_SEED': None,
+ 'TEST_SRCDIR': None,
+ 'TEST_TMPDIR': None,
+ },
+ expect_success=True)
+
+ def test_flags_env_var_no_flags(self):
+ tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ srcdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ self.run_helper(
+ 2,
+ [],
+ {'TEST_RANDOM_SEED': '321',
+ 'TEST_SRCDIR': srcdir,
+ 'TEST_TMPDIR': tmpdir,
+ 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR': srcdir,
+ 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR': tmpdir,
+ },
+ expect_success=True)
+
+ def test_flags_no_env_var_flags(self):
+ tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ srcdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ self.run_helper(
+ 3,
+ ['--test_random_seed=123', '--test_srcdir={}'.format(srcdir),
+ '--test_tmpdir={}'.format(tmpdir)],
+ {'TEST_RANDOM_SEED': None,
+ 'TEST_SRCDIR': None,
+ 'TEST_TMPDIR': None,
+ 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR': srcdir,
+ 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR': tmpdir,
+ },
+ expect_success=True)
+
+ def test_flags_env_var_flags(self):
+ tmpdir_from_flag = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ srcdir_from_flag = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ tmpdir_from_env_var = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ srcdir_from_env_var = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ self.run_helper(
+ 4,
+ ['--test_random_seed=221', '--test_srcdir={}'.format(srcdir_from_flag),
+ '--test_tmpdir={}'.format(tmpdir_from_flag)],
+ {'TEST_RANDOM_SEED': '123',
+ 'TEST_SRCDIR': srcdir_from_env_var,
+ 'TEST_TMPDIR': tmpdir_from_env_var,
+ 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR': srcdir_from_flag,
+ 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR': tmpdir_from_flag,
+ },
+ expect_success=True)
+
+ def test_xml_output_file_from_xml_output_file_env(self):
+ xml_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ xml_output_file_env = os.path.join(xml_dir, 'xml_output_file.xml')
+ random_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ self.run_helper(
+ 6,
+ [],
+ {'XML_OUTPUT_FILE': xml_output_file_env,
+ 'RUNNING_UNDER_TEST_DAEMON': '1',
+ 'TEST_XMLOUTPUTDIR': random_dir,
+ 'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': xml_output_file_env,
+ },
+ expect_success=True)
+
+ def test_xml_output_file_from_daemon(self):
+ tmpdir = os.path.join(tempfile.mkdtemp(
+ dir=absltest.TEST_TMPDIR.value), 'sub_dir')
+ random_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ self.run_helper(
+ 6,
+ ['--test_tmpdir', tmpdir],
+ {'XML_OUTPUT_FILE': None,
+ 'RUNNING_UNDER_TEST_DAEMON': '1',
+ 'TEST_XMLOUTPUTDIR': random_dir,
+ 'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': os.path.join(
+ os.path.dirname(tmpdir), 'test_detail.xml'),
+ },
+ expect_success=True)
+
+ def test_xml_output_file_from_test_xmloutputdir_env(self):
+ xml_output_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ expected_xml_file = 'absltest_test_helper.xml'
+ self.run_helper(
+ 6,
+ [],
+ {'XML_OUTPUT_FILE': None,
+ 'RUNNING_UNDER_TEST_DAEMON': None,
+ 'TEST_XMLOUTPUTDIR': xml_output_dir,
+ 'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': os.path.join(
+ xml_output_dir, expected_xml_file),
+ },
+ expect_success=True)
+
+ def test_xml_output_file_from_flag(self):
+ random_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ flag_file = os.path.join(
+ tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value), 'output.xml')
+ self.run_helper(
+ 6,
+ ['--xml_output_file', flag_file],
+ {'XML_OUTPUT_FILE': os.path.join(random_dir, 'output.xml'),
+ 'RUNNING_UNDER_TEST_DAEMON': '1',
+ 'TEST_XMLOUTPUTDIR': random_dir,
+ 'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': flag_file,
+ },
+ expect_success=True)
+
+ def test_assert_in(self):
+ animals = {'monkey': 'banana', 'cow': 'grass', 'seal': 'fish'}
+
+ self.assertIn('a', 'abc')
+ self.assertIn(2, [1, 2, 3])
+ self.assertIn('monkey', animals)
+
+ self.assertNotIn('d', 'abc')
+ self.assertNotIn(0, [1, 2, 3])
+ self.assertNotIn('otter', animals)
+
+ self.assertRaises(AssertionError, self.assertIn, 'x', 'abc')
+ self.assertRaises(AssertionError, self.assertIn, 4, [1, 2, 3])
+ self.assertRaises(AssertionError, self.assertIn, 'elephant', animals)
+
+ self.assertRaises(AssertionError, self.assertNotIn, 'c', 'abc')
+ self.assertRaises(AssertionError, self.assertNotIn, 1, [1, 2, 3])
+ self.assertRaises(AssertionError, self.assertNotIn, 'cow', animals)
+
+ @absltest.expectedFailure
+ def test_expected_failure(self):
+ self.assertEqual(1, 2) # the expected failure
+
+ @absltest.expectedFailureIf(True, 'always true')
+ def test_expected_failure_if(self):
+ self.assertEqual(1, 2) # the expected failure
+
+ def test_expected_failure_success(self):
+ _, stderr = self.run_helper(5, ['--', '-v'], {}, expect_success=False)
+ self.assertRegex(stderr, r'FAILED \(.*unexpected successes=1\)')
+
+ def test_assert_equal(self):
+ self.assertListEqual([], [])
+ self.assertTupleEqual((), ())
+ self.assertSequenceEqual([], ())
+
+ a = [0, 'a', []]
+ b = []
+ self.assertRaises(absltest.TestCase.failureException,
+ self.assertListEqual, a, b)
+ self.assertRaises(absltest.TestCase.failureException,
+ self.assertListEqual, tuple(a), tuple(b))
+ self.assertRaises(absltest.TestCase.failureException,
+ self.assertSequenceEqual, a, tuple(b))
+
+ b.extend(a)
+ self.assertListEqual(a, b)
+ self.assertTupleEqual(tuple(a), tuple(b))
+ self.assertSequenceEqual(a, tuple(b))
+ self.assertSequenceEqual(tuple(a), b)
+
+ self.assertRaises(AssertionError, self.assertListEqual, a, tuple(b))
+ self.assertRaises(AssertionError, self.assertTupleEqual, tuple(a), b)
+ self.assertRaises(AssertionError, self.assertListEqual, None, b)
+ self.assertRaises(AssertionError, self.assertTupleEqual, None, tuple(b))
+ self.assertRaises(AssertionError, self.assertSequenceEqual, None, tuple(b))
+ self.assertRaises(AssertionError, self.assertListEqual, 1, 1)
+ self.assertRaises(AssertionError, self.assertTupleEqual, 1, 1)
+ self.assertRaises(AssertionError, self.assertSequenceEqual, 1, 1)
+
+ self.assertSameElements([1, 2, 3], [3, 2, 1])
+ self.assertSameElements([1, 2] + [3] * 100, [1] * 100 + [2, 3])
+ self.assertSameElements(['foo', 'bar', 'baz'], ['bar', 'baz', 'foo'])
+ self.assertRaises(AssertionError, self.assertSameElements, [10], [10, 11])
+ self.assertRaises(AssertionError, self.assertSameElements, [10, 11], [10])
+
+ # Test that sequences of unhashable objects can be tested for sameness:
+ self.assertSameElements([[1, 2], [3, 4]], [[3, 4], [1, 2]])
+ self.assertRaises(AssertionError, self.assertSameElements, [[1]], [[2]])
+
+ def test_assert_items_equal_hotfix(self):
+ """Confirm that http://bugs.python.org/issue14832 - b/10038517 is gone."""
+ for assert_items_method in (self.assertItemsEqual, self.assertCountEqual):
+ with self.assertRaises(self.failureException) as error_context:
+ assert_items_method([4], [2])
+ error_message = str(error_context.exception)
+ # Confirm that the bug is either no longer present in Python or that our
+ # assertItemsEqual patching version of the method in absltest.TestCase
+ # doesn't get used.
+ self.assertIn('First has 1, Second has 0: 4', error_message)
+ self.assertIn('First has 0, Second has 1: 2', error_message)
+
+ def test_assert_dict_equal(self):
+ self.assertDictEqual({}, {})
+
+ c = {'x': 1}
+ d = {}
+ self.assertRaises(absltest.TestCase.failureException,
+ self.assertDictEqual, c, d)
+
+ d.update(c)
+ self.assertDictEqual(c, d)
+
+ d['x'] = 0
+ self.assertRaises(absltest.TestCase.failureException,
+ self.assertDictEqual, c, d, 'These are unequal')
+
+ self.assertRaises(AssertionError, self.assertDictEqual, None, d)
+ self.assertRaises(AssertionError, self.assertDictEqual, [], d)
+ self.assertRaises(AssertionError, self.assertDictEqual, 1, 1)
+
+ try:
+ # Ensure we use equality as the sole measure of elements, not type, since
+ # that is consistent with dict equality.
+ self.assertDictEqual({1: 1.0, 2: 2}, {1: 1, 2: 3})
+ except AssertionError as e:
+ self.assertMultiLineEqual('{1: 1.0, 2: 2} != {1: 1, 2: 3}\n'
+ 'repr() of differing entries:\n2: 2 != 3\n',
+ str(e))
+
+ try:
+ self.assertDictEqual({}, {'x': 1})
+ except AssertionError as e:
+ self.assertMultiLineEqual("{} != {'x': 1}\n"
+ "Unexpected, but present entries:\n'x': 1\n",
+ str(e))
+ else:
+ self.fail('Expecting AssertionError')
+
+ try:
+ self.assertDictEqual({}, {'x': 1}, 'a message')
+ except AssertionError as e:
+ self.assertIn('a message', str(e))
+ else:
+ self.fail('Expecting AssertionError')
+
+ expected = {'a': 1, 'b': 2, 'c': 3}
+ seen = {'a': 2, 'c': 3, 'd': 4}
+ try:
+ self.assertDictEqual(expected, seen)
+ except AssertionError as e:
+ self.assertMultiLineEqual("""\
+{'a': 1, 'b': 2, 'c': 3} != {'a': 2, 'c': 3, 'd': 4}
+Unexpected, but present entries:
+'d': 4
+
+repr() of differing entries:
+'a': 1 != 2
+
+Missing entries:
+'b': 2
+""", str(e))
+ else:
+ self.fail('Expecting AssertionError')
+
+ self.assertRaises(AssertionError, self.assertDictEqual, (1, 2), {})
+ self.assertRaises(AssertionError, self.assertDictEqual, {}, (1, 2))
+
+ # Ensure deterministic output of keys in dictionaries whose sort order
+ # doesn't match the lexical ordering of repr -- this is most Python objects,
+ # which are keyed by memory address.
+ class Obj(object):
+
+ def __init__(self, name):
+ self.name = name
+
+ def __repr__(self):
+ return self.name
+
+ try:
+ self.assertDictEqual(
+ {'a': Obj('A'), Obj('b'): Obj('B'), Obj('c'): Obj('C')},
+ {'a': Obj('A'), Obj('d'): Obj('D'), Obj('e'): Obj('E')})
+ except AssertionError as e:
+ # Do as best we can not to be misleading when objects have the same repr
+ # but aren't equal.
+ err_str = str(e)
+ self.assertStartsWith(err_str,
+ "{'a': A, b: B, c: C} != {'a': A, d: D, e: E}\n")
+ self.assertRegex(
+ err_str, r'(?ms).*^Unexpected, but present entries:\s+'
+ r'^(d: D$\s+^e: E|e: E$\s+^d: D)$')
+ self.assertRegex(
+ err_str, r'(?ms).*^repr\(\) of differing entries:\s+'
+ r'^.a.: A != A$', err_str)
+ self.assertRegex(
+ err_str, r'(?ms).*^Missing entries:\s+'
+ r'^(b: B$\s+^c: C|c: C$\s+^b: B)$')
+ else:
+ self.fail('Expecting AssertionError')
+
+ # Confirm that safe_repr, not repr, is being used.
+ class RaisesOnRepr(object):
+
+ def __repr__(self):
+ return 1/0 # Intentionally broken __repr__ implementation.
+
+ try:
+ self.assertDictEqual(
+ {RaisesOnRepr(): RaisesOnRepr()},
+ {RaisesOnRepr(): RaisesOnRepr()}
+ )
+ self.fail('Expected dicts not to match')
+ except AssertionError as e:
+ # Depending on the testing environment, the object may get a __main__
+ # prefix or a absltest_test prefix, so strip that for comparison.
+ error_msg = re.sub(
+ r'( at 0x[^>]+)|__main__\.|absltest_test\.', '', str(e))
+ self.assertRegex(error_msg, """(?m)\
+{<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*>} != \
+{<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*>}
+Unexpected, but present entries:
+<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*>
+
+Missing entries:
+<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*>
+""")
+
+ # Confirm that safe_repr, not repr, is being used.
+ class RaisesOnLt(object):
+
+ def __lt__(self, unused_other):
+ raise TypeError('Object is unordered.')
+
+ def __repr__(self):
+ return '<RaisesOnLt object>'
+
+ try:
+ self.assertDictEqual(
+ {RaisesOnLt(): RaisesOnLt()},
+ {RaisesOnLt(): RaisesOnLt()})
+ except AssertionError as e:
+ self.assertIn('Unexpected, but present entries:\n<RaisesOnLt', str(e))
+ self.assertIn('Missing entries:\n<RaisesOnLt', str(e))
+
+ def test_assert_set_equal(self):
+ set1 = set()
+ set2 = set()
+ self.assertSetEqual(set1, set2)
+
+ self.assertRaises(AssertionError, self.assertSetEqual, None, set2)
+ self.assertRaises(AssertionError, self.assertSetEqual, [], set2)
+ self.assertRaises(AssertionError, self.assertSetEqual, set1, None)
+ self.assertRaises(AssertionError, self.assertSetEqual, set1, [])
+
+ set1 = set(['a'])
+ set2 = set()
+ self.assertRaises(AssertionError, self.assertSetEqual, set1, set2)
+
+ set1 = set(['a'])
+ set2 = set(['a'])
+ self.assertSetEqual(set1, set2)
+
+ set1 = set(['a'])
+ set2 = set(['a', 'b'])
+ self.assertRaises(AssertionError, self.assertSetEqual, set1, set2)
+
+ set1 = set(['a'])
+ set2 = frozenset(['a', 'b'])
+ self.assertRaises(AssertionError, self.assertSetEqual, set1, set2)
+
+ set1 = set(['a', 'b'])
+ set2 = frozenset(['a', 'b'])
+ self.assertSetEqual(set1, set2)
+
+ set1 = set()
+ set2 = 'foo'
+ self.assertRaises(AssertionError, self.assertSetEqual, set1, set2)
+ self.assertRaises(AssertionError, self.assertSetEqual, set2, set1)
+
+ # make sure any string formatting is tuple-safe
+ set1 = set([(0, 1), (2, 3)])
+ set2 = set([(4, 5)])
+ self.assertRaises(AssertionError, self.assertSetEqual, set1, set2)
+
+ def test_assert_dict_contains_subset(self):
+ self.assertDictContainsSubset({}, {})
+
+ self.assertDictContainsSubset({}, {'a': 1})
+
+ self.assertDictContainsSubset({'a': 1}, {'a': 1})
+
+ self.assertDictContainsSubset({'a': 1}, {'a': 1, 'b': 2})
+
+ self.assertDictContainsSubset({'a': 1, 'b': 2}, {'a': 1, 'b': 2})
+
+ self.assertRaises(absltest.TestCase.failureException,
+ self.assertDictContainsSubset, {'a': 2}, {'a': 1},
+ '.*Mismatched values:.*')
+
+ self.assertRaises(absltest.TestCase.failureException,
+ self.assertDictContainsSubset, {'c': 1}, {'a': 1},
+ '.*Missing:.*')
+
+ self.assertRaises(absltest.TestCase.failureException,
+ self.assertDictContainsSubset, {'a': 1, 'c': 1}, {'a': 1},
+ '.*Missing:.*')
+
+ self.assertRaises(absltest.TestCase.failureException,
+ self.assertDictContainsSubset, {'a': 1, 'c': 1}, {'a': 1},
+ '.*Missing:.*Mismatched values:.*')
+
+ def test_assert_sequence_almost_equal(self):
+ actual = (1.1, 1.2, 1.4)
+
+ # Test across sequence types.
+ self.assertSequenceAlmostEqual((1.1, 1.2, 1.4), actual)
+ self.assertSequenceAlmostEqual([1.1, 1.2, 1.4], actual)
+
+ # Test sequence size mismatch.
+ with self.assertRaises(AssertionError):
+ self.assertSequenceAlmostEqual([1.1, 1.2], actual)
+ with self.assertRaises(AssertionError):
+ self.assertSequenceAlmostEqual([1.1, 1.2, 1.4, 1.5], actual)
+
+ # Test delta.
+ with self.assertRaises(AssertionError):
+ self.assertSequenceAlmostEqual((1.15, 1.15, 1.4), actual)
+ self.assertSequenceAlmostEqual((1.15, 1.15, 1.4), actual, delta=0.1)
+
+ # Test places.
+ with self.assertRaises(AssertionError):
+ self.assertSequenceAlmostEqual((1.1001, 1.2001, 1.3999), actual)
+ self.assertSequenceAlmostEqual((1.1001, 1.2001, 1.3999), actual, places=3)
+
+ def test_assert_contains_subset(self):
+ # sets, lists, tuples, dicts all ok. Types of set and subset do not have to
+ # match.
+ actual = ('a', 'b', 'c')
+ self.assertContainsSubset({'a', 'b'}, actual)
+ self.assertContainsSubset(('b', 'c'), actual)
+ self.assertContainsSubset({'b': 1, 'c': 2}, list(actual))
+ self.assertContainsSubset(['c', 'a'], set(actual))
+ self.assertContainsSubset([], set())
+ self.assertContainsSubset([], {'a': 1})
+
+ self.assertRaises(AssertionError, self.assertContainsSubset, ('d',), actual)
+ self.assertRaises(AssertionError, self.assertContainsSubset, ['d'],
+ set(actual))
+ self.assertRaises(AssertionError, self.assertContainsSubset, {'a': 1}, [])
+
+ with self.assertRaisesRegex(AssertionError, 'Missing elements'):
+ self.assertContainsSubset({1, 2, 3}, {1, 2})
+
+ with self.assertRaisesRegex(
+ AssertionError,
+ re.compile('Missing elements .* Custom message', re.DOTALL)):
+ self.assertContainsSubset({1, 2}, {1}, 'Custom message')
+
+ def test_assert_no_common_elements(self):
+ actual = ('a', 'b', 'c')
+ self.assertNoCommonElements((), actual)
+ self.assertNoCommonElements(('d', 'e'), actual)
+ self.assertNoCommonElements({'d', 'e'}, actual)
+
+ with self.assertRaisesRegex(
+ AssertionError,
+ re.compile('Common elements .* Custom message', re.DOTALL)):
+ self.assertNoCommonElements({1, 2}, {1}, 'Custom message')
+
+ with self.assertRaises(AssertionError):
+ self.assertNoCommonElements(['a'], actual)
+
+ with self.assertRaises(AssertionError):
+ self.assertNoCommonElements({'a', 'b', 'c'}, actual)
+
+ with self.assertRaises(AssertionError):
+ self.assertNoCommonElements({'b', 'c'}, set(actual))
+
+ def test_assert_almost_equal(self):
+ self.assertAlmostEqual(1.00000001, 1.0)
+ self.assertNotAlmostEqual(1.0000001, 1.0)
+
+ def test_assert_almost_equals_with_delta(self):
+ self.assertAlmostEqual(3.14, 3, delta=0.2)
+ self.assertAlmostEqual(2.81, 3.14, delta=1)
+ self.assertAlmostEqual(-1, 1, delta=3)
+ self.assertRaises(AssertionError, self.assertAlmostEqual,
+ 3.14, 2.81, delta=0.1)
+ self.assertRaises(AssertionError, self.assertAlmostEqual,
+ 1, 2, delta=0.5)
+ self.assertNotAlmostEqual(3.14, 2.81, delta=0.1)
+
+ def test_assert_starts_with(self):
+ self.assertStartsWith('foobar', 'foo')
+ self.assertStartsWith('foobar', 'foobar')
+ msg = 'This is a useful message'
+ whole_msg = "'foobar' does not start with 'bar' : This is a useful message"
+ self.assertRaisesWithLiteralMatch(AssertionError, whole_msg,
+ self.assertStartsWith,
+ 'foobar', 'bar', msg)
+ self.assertRaises(AssertionError, self.assertStartsWith, 'foobar', 'blah')
+
+ def test_assert_not_starts_with(self):
+ self.assertNotStartsWith('foobar', 'bar')
+ self.assertNotStartsWith('foobar', 'blah')
+ msg = 'This is a useful message'
+ whole_msg = "'foobar' does start with 'foo' : This is a useful message"
+ self.assertRaisesWithLiteralMatch(AssertionError, whole_msg,
+ self.assertNotStartsWith,
+ 'foobar', 'foo', msg)
+ self.assertRaises(AssertionError, self.assertNotStartsWith, 'foobar',
+ 'foobar')
+
+ def test_assert_ends_with(self):
+ self.assertEndsWith('foobar', 'bar')
+ self.assertEndsWith('foobar', 'foobar')
+ msg = 'This is a useful message'
+ whole_msg = "'foobar' does not end with 'foo' : This is a useful message"
+ self.assertRaisesWithLiteralMatch(AssertionError, whole_msg,
+ self.assertEndsWith,
+ 'foobar', 'foo', msg)
+ self.assertRaises(AssertionError, self.assertEndsWith, 'foobar', 'blah')
+
+ def test_assert_not_ends_with(self):
+ self.assertNotEndsWith('foobar', 'foo')
+ self.assertNotEndsWith('foobar', 'blah')
+ msg = 'This is a useful message'
+ whole_msg = "'foobar' does end with 'bar' : This is a useful message"
+ self.assertRaisesWithLiteralMatch(AssertionError, whole_msg,
+ self.assertNotEndsWith,
+ 'foobar', 'bar', msg)
+ self.assertRaises(AssertionError, self.assertNotEndsWith, 'foobar',
+ 'foobar')
+
+ def test_assert_regex_backports(self):
+ self.assertRegex('regex', 'regex')
+ self.assertNotRegex('not-regex', 'no-match')
+ with self.assertRaisesRegex(ValueError, 'pattern'):
+ raise ValueError('pattern')
+
+ def test_assert_regex_match_matches(self):
+ self.assertRegexMatch('str', ['str'])
+
+ def test_assert_regex_match_matches_substring(self):
+ self.assertRegexMatch('pre-str-post', ['str'])
+
+ def test_assert_regex_match_multiple_regex_matches(self):
+ self.assertRegexMatch('str', ['rts', 'str'])
+
+ def test_assert_regex_match_empty_list_fails(self):
+ expected_re = re.compile(r'No regexes specified\.', re.MULTILINE)
+
+ with self.assertRaisesRegex(AssertionError, expected_re):
+ self.assertRegexMatch('str', regexes=[])
+
+ def test_assert_regex_match_bad_arguments(self):
+ with self.assertRaisesRegex(AssertionError,
+ 'regexes is string or bytes;.*'):
+ self.assertRegexMatch('1.*2', '1 2')
+
+ def test_assert_regex_match_unicode_vs_bytes(self):
+ """Ensure proper utf-8 encoding or decoding happens automatically."""
+ self.assertRegexMatch(u'str', [b'str'])
+ self.assertRegexMatch(b'str', [u'str'])
+
+ def test_assert_regex_match_unicode(self):
+ self.assertRegexMatch(u'foo str', [u'str'])
+
+ def test_assert_regex_match_bytes(self):
+ self.assertRegexMatch(b'foo str', [b'str'])
+
+ def test_assert_regex_match_all_the_same_type(self):
+ with self.assertRaisesRegex(AssertionError, 'regexes .* same type'):
+ self.assertRegexMatch('foo str', [b'str', u'foo'])
+
+ def test_assert_command_fails_stderr(self):
+ tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ self.assertCommandFails(
+ ['cat', os.path.join(tmpdir, 'file.txt')],
+ ['No such file or directory'],
+ env=_env_for_command_tests())
+
+ def test_assert_command_fails_with_list_of_string(self):
+ self.assertCommandFails(
+ ['false'], [''], env=_env_for_command_tests())
+
+ def test_assert_command_fails_with_list_of_unicode_string(self):
+ self.assertCommandFails(
+ [u'false'], [''], env=_env_for_command_tests())
+
+ def test_assert_command_fails_with_unicode_string(self):
+ self.assertCommandFails(
+ u'false', [u''], env=_env_for_command_tests())
+
+ def test_assert_command_fails_with_unicode_string_bytes_regex(self):
+ self.assertCommandFails(
+ u'false', [b''], env=_env_for_command_tests())
+
+ def test_assert_command_fails_with_message(self):
+ msg = 'This is a useful message'
+ expected_re = re.compile('The following command succeeded while expected to'
+ ' fail:.* This is a useful message', re.DOTALL)
+
+ with self.assertRaisesRegex(AssertionError, expected_re):
+ self.assertCommandFails(
+ [u'true'], [''], msg=msg, env=_env_for_command_tests())
+
+ def test_assert_command_succeeds_stderr(self):
+ expected_re = re.compile('No such file or directory')
+ tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+
+ with self.assertRaisesRegex(AssertionError, expected_re):
+ self.assertCommandSucceeds(
+ ['cat', os.path.join(tmpdir, 'file.txt')],
+ env=_env_for_command_tests())
+
+ def test_assert_command_succeeds_with_matching_unicode_regexes(self):
+ self.assertCommandSucceeds(
+ ['echo', 'SUCCESS'], regexes=[u'SUCCESS'],
+ env=_env_for_command_tests())
+
+ def test_assert_command_succeeds_with_matching_bytes_regexes(self):
+ self.assertCommandSucceeds(
+ ['echo', 'SUCCESS'], regexes=[b'SUCCESS'],
+ env=_env_for_command_tests())
+
+ def test_assert_command_succeeds_with_non_matching_regexes(self):
+ expected_re = re.compile('Running command.* This is a useful message',
+ re.DOTALL)
+ msg = 'This is a useful message'
+
+ with self.assertRaisesRegex(AssertionError, expected_re):
+ self.assertCommandSucceeds(
+ ['echo', 'FAIL'], regexes=['SUCCESS'], msg=msg,
+ env=_env_for_command_tests())
+
+ def test_assert_command_succeeds_with_list_of_string(self):
+ self.assertCommandSucceeds(
+ ['true'], env=_env_for_command_tests())
+
+ def test_assert_command_succeeds_with_list_of_unicode_string(self):
+ self.assertCommandSucceeds(
+ [u'true'], env=_env_for_command_tests())
+
+ def test_assert_command_succeeds_with_unicode_string(self):
+ self.assertCommandSucceeds(
+ u'true', env=_env_for_command_tests())
+
+ def test_inequality(self):
+ # Try ints
+ self.assertGreater(2, 1)
+ self.assertGreaterEqual(2, 1)
+ self.assertGreaterEqual(1, 1)
+ self.assertLess(1, 2)
+ self.assertLessEqual(1, 2)
+ self.assertLessEqual(1, 1)
+ self.assertRaises(AssertionError, self.assertGreater, 1, 2)
+ self.assertRaises(AssertionError, self.assertGreater, 1, 1)
+ self.assertRaises(AssertionError, self.assertGreaterEqual, 1, 2)
+ self.assertRaises(AssertionError, self.assertLess, 2, 1)
+ self.assertRaises(AssertionError, self.assertLess, 1, 1)
+ self.assertRaises(AssertionError, self.assertLessEqual, 2, 1)
+
+ # Try Floats
+ self.assertGreater(1.1, 1.0)
+ self.assertGreaterEqual(1.1, 1.0)
+ self.assertGreaterEqual(1.0, 1.0)
+ self.assertLess(1.0, 1.1)
+ self.assertLessEqual(1.0, 1.1)
+ self.assertLessEqual(1.0, 1.0)
+ self.assertRaises(AssertionError, self.assertGreater, 1.0, 1.1)
+ self.assertRaises(AssertionError, self.assertGreater, 1.0, 1.0)
+ self.assertRaises(AssertionError, self.assertGreaterEqual, 1.0, 1.1)
+ self.assertRaises(AssertionError, self.assertLess, 1.1, 1.0)
+ self.assertRaises(AssertionError, self.assertLess, 1.0, 1.0)
+ self.assertRaises(AssertionError, self.assertLessEqual, 1.1, 1.0)
+
+ # Try Strings
+ self.assertGreater('bug', 'ant')
+ self.assertGreaterEqual('bug', 'ant')
+ self.assertGreaterEqual('ant', 'ant')
+ self.assertLess('ant', 'bug')
+ self.assertLessEqual('ant', 'bug')
+ self.assertLessEqual('ant', 'ant')
+ self.assertRaises(AssertionError, self.assertGreater, 'ant', 'bug')
+ self.assertRaises(AssertionError, self.assertGreater, 'ant', 'ant')
+ self.assertRaises(AssertionError, self.assertGreaterEqual, 'ant', 'bug')
+ self.assertRaises(AssertionError, self.assertLess, 'bug', 'ant')
+ self.assertRaises(AssertionError, self.assertLess, 'ant', 'ant')
+ self.assertRaises(AssertionError, self.assertLessEqual, 'bug', 'ant')
+
+ # Try Unicode
+ self.assertGreater(u'bug', u'ant')
+ self.assertGreaterEqual(u'bug', u'ant')
+ self.assertGreaterEqual(u'ant', u'ant')
+ self.assertLess(u'ant', u'bug')
+ self.assertLessEqual(u'ant', u'bug')
+ self.assertLessEqual(u'ant', u'ant')
+ self.assertRaises(AssertionError, self.assertGreater, u'ant', u'bug')
+ self.assertRaises(AssertionError, self.assertGreater, u'ant', u'ant')
+ self.assertRaises(AssertionError, self.assertGreaterEqual, u'ant', u'bug')
+ self.assertRaises(AssertionError, self.assertLess, u'bug', u'ant')
+ self.assertRaises(AssertionError, self.assertLess, u'ant', u'ant')
+ self.assertRaises(AssertionError, self.assertLessEqual, u'bug', u'ant')
+
+ # Try Mixed String/Unicode
+ self.assertGreater('bug', u'ant')
+ self.assertGreater(u'bug', 'ant')
+ self.assertGreaterEqual('bug', u'ant')
+ self.assertGreaterEqual(u'bug', 'ant')
+ self.assertGreaterEqual('ant', u'ant')
+ self.assertGreaterEqual(u'ant', 'ant')
+ self.assertLess('ant', u'bug')
+ self.assertLess(u'ant', 'bug')
+ self.assertLessEqual('ant', u'bug')
+ self.assertLessEqual(u'ant', 'bug')
+ self.assertLessEqual('ant', u'ant')
+ self.assertLessEqual(u'ant', 'ant')
+ self.assertRaises(AssertionError, self.assertGreater, 'ant', u'bug')
+ self.assertRaises(AssertionError, self.assertGreater, u'ant', 'bug')
+ self.assertRaises(AssertionError, self.assertGreater, 'ant', u'ant')
+ self.assertRaises(AssertionError, self.assertGreater, u'ant', 'ant')
+ self.assertRaises(AssertionError, self.assertGreaterEqual, 'ant', u'bug')
+ self.assertRaises(AssertionError, self.assertGreaterEqual, u'ant', 'bug')
+ self.assertRaises(AssertionError, self.assertLess, 'bug', u'ant')
+ self.assertRaises(AssertionError, self.assertLess, u'bug', 'ant')
+ self.assertRaises(AssertionError, self.assertLess, 'ant', u'ant')
+ self.assertRaises(AssertionError, self.assertLess, u'ant', 'ant')
+ self.assertRaises(AssertionError, self.assertLessEqual, 'bug', u'ant')
+ self.assertRaises(AssertionError, self.assertLessEqual, u'bug', 'ant')
+
+ def test_assert_multi_line_equal(self):
+ sample_text = """\
+http://www.python.org/doc/2.3/lib/module-unittest.html
+test case
+ A test case is the smallest unit of testing. [...]
+"""
+ revised_sample_text = """\
+http://www.python.org/doc/2.4.1/lib/module-unittest.html
+test case
+ A test case is the smallest unit of testing. [...] You may provide your
+ own implementation that does not subclass from TestCase, of course.
+"""
+ sample_text_error = """
+- http://www.python.org/doc/2.3/lib/module-unittest.html
+? ^
++ http://www.python.org/doc/2.4.1/lib/module-unittest.html
+? ^^^
+ test case
+- A test case is the smallest unit of testing. [...]
++ A test case is the smallest unit of testing. [...] You may provide your
+? +++++++++++++++++++++
++ own implementation that does not subclass from TestCase, of course.
+"""
+ self.assertRaisesWithLiteralMatch(AssertionError, sample_text_error,
+ self.assertMultiLineEqual,
+ sample_text,
+ revised_sample_text)
+
+ self.assertRaises(AssertionError, self.assertMultiLineEqual, (1, 2), 'str')
+ self.assertRaises(AssertionError, self.assertMultiLineEqual, 'str', (1, 2))
+
+ def test_assert_multi_line_equal_adds_newlines_if_needed(self):
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ '\n'
+ ' line1\n'
+ '- line2\n'
+ '? ^\n'
+ '+ line3\n'
+ '? ^\n',
+ self.assertMultiLineEqual,
+ 'line1\n'
+ 'line2',
+ 'line1\n'
+ 'line3')
+
+ def test_assert_multi_line_equal_shows_missing_newlines(self):
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ '\n'
+ ' line1\n'
+ '- line2\n'
+ '? -\n'
+ '+ line2\n',
+ self.assertMultiLineEqual,
+ 'line1\n'
+ 'line2\n',
+ 'line1\n'
+ 'line2')
+
+ def test_assert_multi_line_equal_shows_extra_newlines(self):
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ '\n'
+ ' line1\n'
+ '- line2\n'
+ '+ line2\n'
+ '? +\n',
+ self.assertMultiLineEqual,
+ 'line1\n'
+ 'line2',
+ 'line1\n'
+ 'line2\n')
+
+ def test_assert_multi_line_equal_line_limit_limits(self):
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ '\n'
+ ' line1\n'
+ '(... and 4 more delta lines omitted for brevity.)\n',
+ self.assertMultiLineEqual,
+ 'line1\n'
+ 'line2\n',
+ 'line1\n'
+ 'line3\n',
+ line_limit=1)
+
+ def test_assert_multi_line_equal_line_limit_limits_with_message(self):
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ 'Prefix:\n'
+ ' line1\n'
+ '(... and 4 more delta lines omitted for brevity.)\n',
+ self.assertMultiLineEqual,
+ 'line1\n'
+ 'line2\n',
+ 'line1\n'
+ 'line3\n',
+ 'Prefix',
+ line_limit=1)
+
+ def test_assert_is_none(self):
+ self.assertIsNone(None)
+ self.assertRaises(AssertionError, self.assertIsNone, False)
+ self.assertIsNotNone('Google')
+ self.assertRaises(AssertionError, self.assertIsNotNone, None)
+ self.assertRaises(AssertionError, self.assertIsNone, (1, 2))
+
+ def test_assert_is(self):
+ self.assertIs(object, object)
+ self.assertRaises(AssertionError, self.assertIsNot, object, object)
+ self.assertIsNot(True, False)
+ self.assertRaises(AssertionError, self.assertIs, True, False)
+
+ def test_assert_between(self):
+ self.assertBetween(3.14, 3.1, 3.141)
+ self.assertBetween(4, 4, 1e10000)
+ self.assertBetween(9.5, 9.4, 9.5)
+ self.assertBetween(-1e10, -1e10000, 0)
+ self.assertRaises(AssertionError, self.assertBetween, 9.4, 9.3, 9.3999)
+ self.assertRaises(AssertionError, self.assertBetween, -1e10000, -1e10, 0)
+
+ def test_assert_raises_with_predicate_match_no_raise(self):
+ with self.assertRaisesRegex(AssertionError, '^Exception not raised$'):
+ self.assertRaisesWithPredicateMatch(Exception,
+ lambda e: True,
+ lambda: 1) # don't raise
+
+ with self.assertRaisesRegex(AssertionError, '^Exception not raised$'):
+ with self.assertRaisesWithPredicateMatch(Exception, lambda e: True):
+ pass # don't raise
+
+ def test_assert_raises_with_predicate_match_raises_wrong_exception(self):
+ def _raise_value_error():
+ raise ValueError
+
+ with self.assertRaises(ValueError):
+ self.assertRaisesWithPredicateMatch(IOError,
+ lambda e: True,
+ _raise_value_error)
+
+ with self.assertRaises(ValueError):
+ with self.assertRaisesWithPredicateMatch(IOError, lambda e: True):
+ raise ValueError
+
+ def test_assert_raises_with_predicate_match_predicate_fails(self):
+ def _raise_value_error():
+ raise ValueError
+ with self.assertRaisesRegex(AssertionError, ' does not match predicate '):
+ self.assertRaisesWithPredicateMatch(ValueError,
+ lambda e: False,
+ _raise_value_error)
+
+ with self.assertRaisesRegex(AssertionError, ' does not match predicate '):
+ with self.assertRaisesWithPredicateMatch(ValueError, lambda e: False):
+ raise ValueError
+
+ def test_assert_raises_with_predicate_match_predicate_passes(self):
+ def _raise_value_error():
+ raise ValueError
+
+ self.assertRaisesWithPredicateMatch(ValueError,
+ lambda e: True,
+ _raise_value_error)
+
+ with self.assertRaisesWithPredicateMatch(ValueError, lambda e: True):
+ raise ValueError
+
+ def test_assert_contains_in_order(self):
+ # Valids
+ self.assertContainsInOrder(
+ ['fox', 'dog'], 'The quick brown fox jumped over the lazy dog.')
+ self.assertContainsInOrder(
+ ['quick', 'fox', 'dog'],
+ 'The quick brown fox jumped over the lazy dog.')
+ self.assertContainsInOrder(
+ ['The', 'fox', 'dog.'], 'The quick brown fox jumped over the lazy dog.')
+ self.assertContainsInOrder(
+ ['fox'], 'The quick brown fox jumped over the lazy dog.')
+ self.assertContainsInOrder(
+ 'fox', 'The quick brown fox jumped over the lazy dog.')
+ self.assertContainsInOrder(
+ ['fox', 'dog'], 'fox dog fox')
+ self.assertContainsInOrder(
+ [], 'The quick brown fox jumped over the lazy dog.')
+ self.assertContainsInOrder(
+ [], '')
+
+ # Invalids
+ msg = 'This is a useful message'
+ whole_msg = ("Did not find 'fox' after 'dog' in 'The quick brown fox"
+ " jumped over the lazy dog' : This is a useful message")
+ self.assertRaisesWithLiteralMatch(
+ AssertionError, whole_msg, self.assertContainsInOrder,
+ ['dog', 'fox'], 'The quick brown fox jumped over the lazy dog', msg=msg)
+ self.assertRaises(
+ AssertionError, self.assertContainsInOrder,
+ ['The', 'dog', 'fox'], 'The quick brown fox jumped over the lazy dog')
+ self.assertRaises(
+ AssertionError, self.assertContainsInOrder, ['dog'], '')
+
+ def test_assert_contains_subsequence_for_numbers(self):
+ self.assertContainsSubsequence([1, 2, 3], [1])
+ self.assertContainsSubsequence([1, 2, 3], [1, 2])
+ self.assertContainsSubsequence([1, 2, 3], [1, 3])
+
+ with self.assertRaises(AssertionError):
+ self.assertContainsSubsequence([1, 2, 3], [4])
+ msg = 'This is a useful message'
+ whole_msg = ('[3, 1] not a subsequence of [1, 2, 3]. '
+ 'First non-matching element: 1 : This is a useful message')
+ self.assertRaisesWithLiteralMatch(AssertionError, whole_msg,
+ self.assertContainsSubsequence,
+ [1, 2, 3], [3, 1], msg=msg)
+
+ def test_assert_contains_subsequence_for_strings(self):
+ self.assertContainsSubsequence(['foo', 'bar', 'blorp'], ['foo', 'blorp'])
+ with self.assertRaises(AssertionError):
+ self.assertContainsSubsequence(
+ ['foo', 'bar', 'blorp'], ['blorp', 'foo'])
+
+ def test_assert_contains_subsequence_with_empty_subsequence(self):
+ self.assertContainsSubsequence([1, 2, 3], [])
+ self.assertContainsSubsequence(['foo', 'bar', 'blorp'], [])
+ self.assertContainsSubsequence([], [])
+
+ def test_assert_contains_subsequence_with_empty_container(self):
+ with self.assertRaises(AssertionError):
+ self.assertContainsSubsequence([], [1])
+ with self.assertRaises(AssertionError):
+ self.assertContainsSubsequence([], ['foo'])
+
+ def test_assert_contains_exact_subsequence_for_numbers(self):
+ self.assertContainsExactSubsequence([1, 2, 3], [1])
+ self.assertContainsExactSubsequence([1, 2, 3], [1, 2])
+ self.assertContainsExactSubsequence([1, 2, 3], [2, 3])
+
+ with self.assertRaises(AssertionError):
+ self.assertContainsExactSubsequence([1, 2, 3], [4])
+ msg = 'This is a useful message'
+ whole_msg = ('[1, 2, 4] not an exact subsequence of [1, 2, 3, 4]. '
+ 'Longest matching prefix: [1, 2] : This is a useful message')
+ self.assertRaisesWithLiteralMatch(AssertionError, whole_msg,
+ self.assertContainsExactSubsequence,
+ [1, 2, 3, 4], [1, 2, 4], msg=msg)
+
+ def test_assert_contains_exact_subsequence_for_strings(self):
+ self.assertContainsExactSubsequence(
+ ['foo', 'bar', 'blorp'], ['foo', 'bar'])
+ with self.assertRaises(AssertionError):
+ self.assertContainsExactSubsequence(
+ ['foo', 'bar', 'blorp'], ['blorp', 'foo'])
+
+ def test_assert_contains_exact_subsequence_with_empty_subsequence(self):
+ self.assertContainsExactSubsequence([1, 2, 3], [])
+ self.assertContainsExactSubsequence(['foo', 'bar', 'blorp'], [])
+ self.assertContainsExactSubsequence([], [])
+
+ def test_assert_contains_exact_subsequence_with_empty_container(self):
+ with self.assertRaises(AssertionError):
+ self.assertContainsExactSubsequence([], [3])
+ with self.assertRaises(AssertionError):
+ self.assertContainsExactSubsequence([], ['foo', 'bar'])
+ self.assertContainsExactSubsequence([], [])
+
+ def test_assert_totally_ordered(self):
+ # Valid.
+ self.assertTotallyOrdered()
+ self.assertTotallyOrdered([1])
+ self.assertTotallyOrdered([1], [2])
+ self.assertTotallyOrdered([1, 1, 1])
+ self.assertTotallyOrdered([(1, 1)], [(1, 2)], [(2, 1)])
+
+ # From the docstring.
+ class A(object):
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __hash__(self):
+ return hash(self.x)
+
+ def __repr__(self):
+ return 'A(%r, %r)' % (self.x, self.y)
+
+ def __eq__(self, other):
+ try:
+ return self.x == other.x
+ except AttributeError:
+ return NotImplemented
+
+ def __ne__(self, other):
+ try:
+ return self.x != other.x
+ except AttributeError:
+ return NotImplemented
+
+ def __lt__(self, other):
+ try:
+ return self.x < other.x
+ except AttributeError:
+ return NotImplemented
+
+ def __le__(self, other):
+ try:
+ return self.x <= other.x
+ except AttributeError:
+ return NotImplemented
+
+ def __gt__(self, other):
+ try:
+ return self.x > other.x
+ except AttributeError:
+ return NotImplemented
+
+ def __ge__(self, other):
+ try:
+ return self.x >= other.x
+ except AttributeError:
+ return NotImplemented
+
+ class B(A):
+ """Like A, but not hashable."""
+ __hash__ = None
+
+ self.assertTotallyOrdered(
+ [A(1, 'a')],
+ [A(2, 'b')], # 2 is after 1.
+ [
+ A(3, 'c'),
+ B(3, 'd'),
+ B(3, 'e') # The second argument is irrelevant.
+ ],
+ [A(4, 'z')])
+
+ # Invalid.
+ msg = 'This is a useful message'
+ whole_msg = '2 not less than 1 : This is a useful message'
+ self.assertRaisesWithLiteralMatch(AssertionError, whole_msg,
+ self.assertTotallyOrdered, [2], [1],
+ msg=msg)
+ self.assertRaises(AssertionError, self.assertTotallyOrdered, [2], [1])
+ self.assertRaises(AssertionError, self.assertTotallyOrdered, [2], [1], [3])
+ self.assertRaises(AssertionError, self.assertTotallyOrdered, [1, 2])
+
+ def test_short_description_without_docstring(self):
+ self.assertEquals(
+ self.shortDescription(),
+ 'TestCaseTest.test_short_description_without_docstring')
+
+ def test_short_description_with_one_line_docstring(self):
+ """Tests shortDescription() for a method with a docstring."""
+ self.assertEquals(
+ self.shortDescription(),
+ 'TestCaseTest.test_short_description_with_one_line_docstring\n'
+ 'Tests shortDescription() for a method with a docstring.')
+
+ def test_short_description_with_multi_line_docstring(self):
+ """Tests shortDescription() for a method with a longer docstring.
+
+ This method ensures that only the first line of a docstring is
+ returned used in the short description, no matter how long the
+ whole thing is.
+ """
+ self.assertEquals(
+ self.shortDescription(),
+ 'TestCaseTest.test_short_description_with_multi_line_docstring\n'
+ 'Tests shortDescription() for a method with a longer docstring.')
+
+ def test_assert_url_equal_same(self):
+ self.assertUrlEqual('http://a', 'http://a')
+ self.assertUrlEqual('http://a/path/test', 'http://a/path/test')
+ self.assertUrlEqual('#fragment', '#fragment')
+ self.assertUrlEqual('http://a/?q=1', 'http://a/?q=1')
+ self.assertUrlEqual('http://a/?q=1&v=5', 'http://a/?v=5&q=1')
+ self.assertUrlEqual('/logs?v=1&a=2&t=labels&f=path%3A%22foo%22',
+ '/logs?a=2&f=path%3A%22foo%22&v=1&t=labels')
+ self.assertUrlEqual('http://a/path;p1', 'http://a/path;p1')
+ self.assertUrlEqual('http://a/path;p2;p3;p1', 'http://a/path;p1;p2;p3')
+ self.assertUrlEqual('sip:alice@atlanta.com;maddr=239.255.255.1;ttl=15',
+ 'sip:alice@atlanta.com;ttl=15;maddr=239.255.255.1')
+ self.assertUrlEqual('http://nyan/cat?p=1&b=', 'http://nyan/cat?b=&p=1')
+
+ def test_assert_url_equal_different(self):
+ msg = 'This is a useful message'
+ whole_msg = 'This is a useful message:\n- a\n+ b\n'
+ self.assertRaisesWithLiteralMatch(AssertionError, whole_msg,
+ self.assertUrlEqual,
+ 'http://a', 'http://b', msg=msg)
+ self.assertRaises(AssertionError, self.assertUrlEqual,
+ 'http://a/x', 'http://a:8080/x')
+ self.assertRaises(AssertionError, self.assertUrlEqual,
+ 'http://a/x', 'http://a/y')
+ self.assertRaises(AssertionError, self.assertUrlEqual,
+ 'http://a/?q=2', 'http://a/?q=1')
+ self.assertRaises(AssertionError, self.assertUrlEqual,
+ 'http://a/?q=1&v=5', 'http://a/?v=2&q=1')
+ self.assertRaises(AssertionError, self.assertUrlEqual,
+ 'http://a', 'sip://b')
+ self.assertRaises(AssertionError, self.assertUrlEqual,
+ 'http://a#g', 'sip://a#f')
+ self.assertRaises(AssertionError, self.assertUrlEqual,
+ 'http://a/path;p1;p3;p1', 'http://a/path;p1;p2;p3')
+ self.assertRaises(AssertionError, self.assertUrlEqual,
+ 'http://nyan/cat?p=1&b=', 'http://nyan/cat?p=1')
+
+ def test_same_structure_same(self):
+ self.assertSameStructure(0, 0)
+ self.assertSameStructure(1, 1)
+ self.assertSameStructure('', '')
+ self.assertSameStructure('hello', 'hello', msg='This Should not fail')
+ self.assertSameStructure(set(), set())
+ self.assertSameStructure(set([1, 2]), set([1, 2]))
+ self.assertSameStructure(set(), frozenset())
+ self.assertSameStructure(set([1, 2]), frozenset([1, 2]))
+ self.assertSameStructure([], [])
+ self.assertSameStructure(['a'], ['a'])
+ self.assertSameStructure([], ())
+ self.assertSameStructure(['a'], ('a',))
+ self.assertSameStructure({}, {})
+ self.assertSameStructure({'one': 1}, {'one': 1})
+ self.assertSameStructure(collections.defaultdict(None, {'one': 1}),
+ {'one': 1})
+ self.assertSameStructure(collections.OrderedDict({'one': 1}),
+ collections.defaultdict(None, {'one': 1}))
+
+ def test_same_structure_different(self):
+ # Different type
+ with self.assertRaisesRegex(
+ AssertionError,
+ r"a is a <(type|class) 'int'> but b is a <(type|class) 'str'>"):
+ self.assertSameStructure(0, 'hello')
+ with self.assertRaisesRegex(
+ AssertionError,
+ r"a is a <(type|class) 'int'> but b is a <(type|class) 'list'>"):
+ self.assertSameStructure(0, [])
+ with self.assertRaisesRegex(
+ AssertionError,
+ r"a is a <(type|class) 'int'> but b is a <(type|class) 'float'>"):
+ self.assertSameStructure(2, 2.0)
+
+ with self.assertRaisesRegex(
+ AssertionError,
+ r"a is a <(type|class) 'list'> but b is a <(type|class) 'dict'>"):
+ self.assertSameStructure([], {})
+
+ with self.assertRaisesRegex(
+ AssertionError,
+ r"a is a <(type|class) 'list'> but b is a <(type|class) 'set'>"):
+ self.assertSameStructure([], set())
+
+ with self.assertRaisesRegex(
+ AssertionError,
+ r"a is a <(type|class) 'dict'> but b is a <(type|class) 'set'>"):
+ self.assertSameStructure({}, set())
+
+ # Different scalar values
+ self.assertRaisesWithLiteralMatch(
+ AssertionError, 'a is 0 but b is 1',
+ self.assertSameStructure, 0, 1)
+ self.assertRaisesWithLiteralMatch(
+ AssertionError, "a is 'hello' but b is 'goodbye' : This was expected",
+ self.assertSameStructure, 'hello', 'goodbye', msg='This was expected')
+
+ # Different sets
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ r'AA has 2 but BB does not',
+ self.assertSameStructure,
+ set([1, 2]),
+ set([1]),
+ aname='AA',
+ bname='BB')
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ r'AA lacks 2 but BB has it',
+ self.assertSameStructure,
+ set([1]),
+ set([1, 2]),
+ aname='AA',
+ bname='BB')
+
+ # Different lists
+ self.assertRaisesWithLiteralMatch(
+ AssertionError, "a has [2] with value 'z' but b does not",
+ self.assertSameStructure, ['x', 'y', 'z'], ['x', 'y'])
+ self.assertRaisesWithLiteralMatch(
+ AssertionError, "a lacks [2] but b has it with value 'z'",
+ self.assertSameStructure, ['x', 'y'], ['x', 'y', 'z'])
+ self.assertRaisesWithLiteralMatch(
+ AssertionError, "a[2] is 'z' but b[2] is 'Z'",
+ self.assertSameStructure, ['x', 'y', 'z'], ['x', 'y', 'Z'])
+
+ # Different dicts
+ self.assertRaisesWithLiteralMatch(
+ AssertionError, "a has ['two'] with value 2 but it's missing in b",
+ self.assertSameStructure, {'one': 1, 'two': 2}, {'one': 1})
+ self.assertRaisesWithLiteralMatch(
+ AssertionError, "a lacks ['two'] but b has it with value 2",
+ self.assertSameStructure, {'one': 1}, {'one': 1, 'two': 2})
+ self.assertRaisesWithLiteralMatch(
+ AssertionError, "a['two'] is 2 but b['two'] is 3",
+ self.assertSameStructure, {'one': 1, 'two': 2}, {'one': 1, 'two': 3})
+
+ # String and byte types should not be considered equivalent to other
+ # sequences
+ self.assertRaisesRegex(
+ AssertionError,
+ r"a is a <(type|class) 'list'> but b is a <(type|class) 'str'>",
+ self.assertSameStructure, [], '')
+ self.assertRaisesRegex(
+ AssertionError,
+ r"a is a <(type|class) 'str'> but b is a <(type|class) 'tuple'>",
+ self.assertSameStructure, '', ())
+ self.assertRaisesRegex(
+ AssertionError,
+ r"a is a <(type|class) 'list'> but b is a <(type|class) 'str'>",
+ self.assertSameStructure, ['a', 'b', 'c'], 'abc')
+ self.assertRaisesRegex(
+ AssertionError,
+ r"a is a <(type|class) 'str'> but b is a <(type|class) 'tuple'>",
+ self.assertSameStructure, 'abc', ('a', 'b', 'c'))
+
+ # Deep key generation
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ "a[0][0]['x']['y']['z'][0] is 1 but b[0][0]['x']['y']['z'][0] is 2",
+ self.assertSameStructure,
+ [[{'x': {'y': {'z': [1]}}}]], [[{'x': {'y': {'z': [2]}}}]])
+
+ # Multiple problems
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ 'a[0] is 1 but b[0] is 3; a[1] is 2 but b[1] is 4',
+ self.assertSameStructure, [1, 2], [3, 4])
+ with self.assertRaisesRegex(
+ AssertionError,
+ re.compile(r"^a\[0] is 'a' but b\[0] is 'A'; .*"
+ r"a\[18] is 's' but b\[18] is 'S'; \.\.\.$")):
+ self.assertSameStructure(
+ list(string.ascii_lowercase), list(string.ascii_uppercase))
+
+ # Verify same behavior with self.maxDiff = None
+ self.maxDiff = None
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ 'a[0] is 1 but b[0] is 3; a[1] is 2 but b[1] is 4',
+ self.assertSameStructure, [1, 2], [3, 4])
+
+ def test_same_structure_mapping_unchanged(self):
+ default_a = collections.defaultdict(lambda: 'BAD MODIFICATION', {})
+ dict_b = {'one': 'z'}
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ r"a lacks ['one'] but b has it with value 'z'",
+ self.assertSameStructure, default_a, dict_b)
+ self.assertEmpty(default_a)
+
+ dict_a = {'one': 'z'}
+ default_b = collections.defaultdict(lambda: 'BAD MODIFICATION', {})
+ self.assertRaisesWithLiteralMatch(
+ AssertionError,
+ r"a has ['one'] with value 'z' but it's missing in b",
+ self.assertSameStructure, dict_a, default_b)
+ self.assertEmpty(default_b)
+
+ def test_assert_json_equal_same(self):
+ self.assertJsonEqual('{"success": true}', '{"success": true}')
+ self.assertJsonEqual('{"success": true}', '{"success":true}')
+ self.assertJsonEqual('true', 'true')
+ self.assertJsonEqual('null', 'null')
+ self.assertJsonEqual('false', 'false')
+ self.assertJsonEqual('34', '34')
+ self.assertJsonEqual('[1, 2, 3]', '[1,2,3]', msg='please PASS')
+ self.assertJsonEqual('{"sequence": [1, 2, 3], "float": 23.42}',
+ '{"float": 23.42, "sequence": [1,2,3]}')
+ self.assertJsonEqual('{"nest": {"spam": "eggs"}, "float": 23.42}',
+ '{"float": 23.42, "nest": {"spam":"eggs"}}')
+
+ def test_assert_json_equal_different(self):
+ with self.assertRaises(AssertionError):
+ self.assertJsonEqual('{"success": true}', '{"success": false}')
+ with self.assertRaises(AssertionError):
+ self.assertJsonEqual('{"success": false}', '{"Success": false}')
+ with self.assertRaises(AssertionError):
+ self.assertJsonEqual('false', 'true')
+ with self.assertRaises(AssertionError) as error_context:
+ self.assertJsonEqual('null', '0', msg='I demand FAILURE')
+ self.assertIn('I demand FAILURE', error_context.exception.args[0])
+ self.assertIn('None', error_context.exception.args[0])
+ with self.assertRaises(AssertionError):
+ self.assertJsonEqual('[1, 0, 3]', '[1,2,3]')
+ with self.assertRaises(AssertionError):
+ self.assertJsonEqual('{"sequence": [1, 2, 3], "float": 23.42}',
+ '{"float": 23.42, "sequence": [1,0,3]}')
+ with self.assertRaises(AssertionError):
+ self.assertJsonEqual('{"nest": {"spam": "eggs"}, "float": 23.42}',
+ '{"float": 23.42, "nest": {"Spam":"beans"}}')
+
+ def test_assert_json_equal_bad_json(self):
+ with self.assertRaises(ValueError) as error_context:
+ self.assertJsonEqual("alhg'2;#", '{"a": true}')
+ self.assertIn('first', error_context.exception.args[0])
+ self.assertIn('alhg', error_context.exception.args[0])
+
+ with self.assertRaises(ValueError) as error_context:
+ self.assertJsonEqual('{"a": true}', "alhg'2;#")
+ self.assertIn('second', error_context.exception.args[0])
+ self.assertIn('alhg', error_context.exception.args[0])
+
+ with self.assertRaises(ValueError) as error_context:
+ self.assertJsonEqual('', '')
+
+
+class GetCommandStderrTestCase(absltest.TestCase):
+
+ def setUp(self):
+ super(GetCommandStderrTestCase, self).setUp()
+ self.original_environ = os.environ.copy()
+
+ def tearDown(self):
+ super(GetCommandStderrTestCase, self).tearDown()
+ os.environ = self.original_environ
+
+ def test_return_status(self):
+ tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ returncode = (
+ absltest.get_command_stderr(
+ ['cat', os.path.join(tmpdir, 'file.txt')],
+ env=_env_for_command_tests())[0])
+ self.assertEqual(1, returncode)
+
+ def test_stderr(self):
+ tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ stderr = (
+ absltest.get_command_stderr(
+ ['cat', os.path.join(tmpdir, 'file.txt')],
+ env=_env_for_command_tests())[1])
+ stderr = stderr.decode('utf-8')
+ self.assertRegex(stderr, 'No such file or directory')
+
+
+@contextlib.contextmanager
+def cm_for_test(obj):
+ try:
+ obj.cm_state = 'yielded'
+ yield 'value'
+ finally:
+ obj.cm_state = 'exited'
+
+
+class EnterContextTest(absltest.TestCase):
+
+ def setUp(self):
+ self.cm_state = 'unset'
+ self.cm_value = 'unset'
+
+ def assert_cm_exited():
+ self.assertEqual(self.cm_state, 'exited')
+
+ # Because cleanup functions are run in reverse order, we have to add
+ # our assert-cleanup before the exit stack registers its own cleanup.
+ # This ensures we see state after the stack cleanup runs.
+ self.addCleanup(assert_cm_exited)
+
+ super(EnterContextTest, self).setUp()
+ self.cm_value = self.enter_context(cm_for_test(self))
+
+ def test_enter_context(self):
+ self.assertEqual(self.cm_value, 'value')
+ self.assertEqual(self.cm_state, 'yielded')
+
+
+@absltest.skipIf(not hasattr(absltest.TestCase, 'addClassCleanup'),
+ 'Python 3.8 required for class-level enter_context')
+class EnterContextClassmethodTest(absltest.TestCase):
+
+ cm_state = 'unset'
+ cm_value = 'unset'
+
+ @classmethod
+ def setUpClass(cls):
+
+ def assert_cm_exited():
+ assert cls.cm_state == 'exited'
+
+ # Because cleanup functions are run in reverse order, we have to add
+ # our assert-cleanup before the exit stack registers its own cleanup.
+ # This ensures we see state after the stack cleanup runs.
+ cls.addClassCleanup(assert_cm_exited)
+
+ super(EnterContextClassmethodTest, cls).setUpClass()
+ cls.cm_value = cls.enter_context(cm_for_test(cls))
+
+ def test_enter_context(self):
+ self.assertEqual(self.cm_value, 'value')
+ self.assertEqual(self.cm_state, 'yielded')
+
+
+class EqualityAssertionTest(absltest.TestCase):
+ """This test verifies that absltest.failIfEqual actually tests __ne__.
+
+ If a user class implements __eq__, unittest.failUnlessEqual will call it
+ via first == second. However, failIfEqual also calls
+ first == second. This means that while the caller may believe
+ their __ne__ method is being tested, it is not.
+ """
+
+ class NeverEqual(object):
+ """Objects of this class behave like NaNs."""
+
+ def __eq__(self, unused_other):
+ return False
+
+ def __ne__(self, unused_other):
+ return False
+
+ class AllSame(object):
+ """All objects of this class compare as equal."""
+
+ def __eq__(self, unused_other):
+ return True
+
+ def __ne__(self, unused_other):
+ return False
+
+ class EqualityTestsWithEq(object):
+ """Performs all equality and inequality tests with __eq__."""
+
+ def __init__(self, value):
+ self._value = value
+
+ def __eq__(self, other):
+ return self._value == other._value
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ class EqualityTestsWithNe(object):
+ """Performs all equality and inequality tests with __ne__."""
+
+ def __init__(self, value):
+ self._value = value
+
+ def __eq__(self, other):
+ return not self.__ne__(other)
+
+ def __ne__(self, other):
+ return self._value != other._value
+
+ class EqualityTestsWithCmp(object):
+
+ def __init__(self, value):
+ self._value = value
+
+ def __cmp__(self, other):
+ return cmp(self._value, other._value)
+
+ class EqualityTestsWithLtEq(object):
+
+ def __init__(self, value):
+ self._value = value
+
+ def __eq__(self, other):
+ return self._value == other._value
+
+ def __lt__(self, other):
+ return self._value < other._value
+
+ def test_all_comparisons_fail(self):
+ i1 = self.NeverEqual()
+ i2 = self.NeverEqual()
+ self.assertFalse(i1 == i2)
+ self.assertFalse(i1 != i2)
+
+ # Compare two distinct objects
+ self.assertFalse(i1 is i2)
+ self.assertRaises(AssertionError, self.assertEqual, i1, i2)
+ self.assertRaises(AssertionError, self.assertEquals, i1, i2)
+ self.assertRaises(AssertionError, self.failUnlessEqual, i1, i2)
+ self.assertRaises(AssertionError, self.assertNotEqual, i1, i2)
+ self.assertRaises(AssertionError, self.assertNotEquals, i1, i2)
+ self.assertRaises(AssertionError, self.failIfEqual, i1, i2)
+ # A NeverEqual object should not compare equal to itself either.
+ i2 = i1
+ self.assertTrue(i1 is i2)
+ self.assertFalse(i1 == i2)
+ self.assertFalse(i1 != i2)
+ self.assertRaises(AssertionError, self.assertEqual, i1, i2)
+ self.assertRaises(AssertionError, self.assertEquals, i1, i2)
+ self.assertRaises(AssertionError, self.failUnlessEqual, i1, i2)
+ self.assertRaises(AssertionError, self.assertNotEqual, i1, i2)
+ self.assertRaises(AssertionError, self.assertNotEquals, i1, i2)
+ self.assertRaises(AssertionError, self.failIfEqual, i1, i2)
+
+ def test_all_comparisons_succeed(self):
+ a = self.AllSame()
+ b = self.AllSame()
+ self.assertFalse(a is b)
+ self.assertTrue(a == b)
+ self.assertFalse(a != b)
+ self.assertEqual(a, b)
+ self.assertEquals(a, b)
+ self.failUnlessEqual(a, b)
+ self.assertRaises(AssertionError, self.assertNotEqual, a, b)
+ self.assertRaises(AssertionError, self.assertNotEquals, a, b)
+ self.assertRaises(AssertionError, self.failIfEqual, a, b)
+
+ def _perform_apple_apple_orange_checks(self, same_a, same_b, different):
+ """Perform consistency checks with two apples and an orange.
+
+ The two apples should always compare as being the same (and inequality
+ checks should fail). The orange should always compare as being different
+ to each of the apples.
+
+ Args:
+ same_a: the first apple
+ same_b: the second apple
+ different: the orange
+ """
+ self.assertTrue(same_a == same_b)
+ self.assertFalse(same_a != same_b)
+ self.assertEqual(same_a, same_b)
+ self.assertEquals(same_a, same_b)
+ self.failUnlessEqual(same_a, same_b)
+
+ self.assertFalse(same_a == different)
+ self.assertTrue(same_a != different)
+ self.assertNotEqual(same_a, different)
+ self.assertNotEquals(same_a, different)
+ self.failIfEqual(same_a, different)
+
+ self.assertFalse(same_b == different)
+ self.assertTrue(same_b != different)
+ self.assertNotEqual(same_b, different)
+ self.assertNotEquals(same_b, different)
+ self.failIfEqual(same_b, different)
+
+ def test_comparison_with_eq(self):
+ same_a = self.EqualityTestsWithEq(42)
+ same_b = self.EqualityTestsWithEq(42)
+ different = self.EqualityTestsWithEq(1769)
+ self._perform_apple_apple_orange_checks(same_a, same_b, different)
+
+ def test_comparison_with_ne(self):
+ same_a = self.EqualityTestsWithNe(42)
+ same_b = self.EqualityTestsWithNe(42)
+ different = self.EqualityTestsWithNe(1769)
+ self._perform_apple_apple_orange_checks(same_a, same_b, different)
+
+ def test_comparison_with_cmp_or_lt_eq(self):
+ same_a = self.EqualityTestsWithLtEq(42)
+ same_b = self.EqualityTestsWithLtEq(42)
+ different = self.EqualityTestsWithLtEq(1769)
+ self._perform_apple_apple_orange_checks(same_a, same_b, different)
+
+
+class AssertSequenceStartsWithTest(absltest.TestCase):
+
+ def setUp(self):
+ self.a = [5, 'foo', {'c': 'd'}, None]
+
+ def test_empty_sequence_starts_with_empty_prefix(self):
+ self.assertSequenceStartsWith([], ())
+
+ def test_sequence_prefix_is_an_empty_list(self):
+ self.assertSequenceStartsWith([[]], ([], 'foo'))
+
+ def test_raise_if_empty_prefix_with_non_empty_whole(self):
+ with self.assertRaisesRegex(
+ AssertionError, 'Prefix length is 0 but whole length is %d: %s' % (len(
+ self.a), r"\[5, 'foo', \{'c': 'd'\}, None\]")):
+ self.assertSequenceStartsWith([], self.a)
+
+ def test_single_element_prefix(self):
+ self.assertSequenceStartsWith([5], self.a)
+
+ def test_two_element_prefix(self):
+ self.assertSequenceStartsWith((5, 'foo'), self.a)
+
+ def test_prefix_is_full_sequence(self):
+ self.assertSequenceStartsWith([5, 'foo', {'c': 'd'}, None], self.a)
+
+ def test_string_prefix(self):
+ self.assertSequenceStartsWith('abc', 'abc123')
+
+ def test_convert_non_sequence_prefix_to_sequence_and_try_again(self):
+ self.assertSequenceStartsWith(5, self.a)
+
+ def test_whole_not_asequence(self):
+ msg = (r'For whole: len\(5\) is not supported, it appears to be type: '
+ '<(type|class) \'int\'>')
+ with self.assertRaisesRegex(AssertionError, msg):
+ self.assertSequenceStartsWith(self.a, 5)
+
+ def test_raise_if_sequence_does_not_start_with_prefix(self):
+ msg = (r"prefix: \['foo', \{'c': 'd'\}\] not found at start of whole: "
+ r"\[5, 'foo', \{'c': 'd'\}, None\].")
+ with self.assertRaisesRegex(AssertionError, msg):
+ self.assertSequenceStartsWith(['foo', {'c': 'd'}], self.a)
+
+ def test_raise_if_types_ar_not_supported(self):
+ with self.assertRaisesRegex(TypeError, 'unhashable type'):
+ self.assertSequenceStartsWith({'a': 1, 2: 'b'},
+ {'a': 1, 2: 'b', 'c': '3'})
+
+
+class TestAssertEmpty(absltest.TestCase):
+ longMessage = True
+
+ def test_raises_if_not_asized_object(self):
+ msg = "Expected a Sized object, got: 'int'"
+ with self.assertRaisesRegex(AssertionError, msg):
+ self.assertEmpty(1)
+
+ def test_calls_len_not_bool(self):
+
+ class BadList(list):
+
+ def __bool__(self):
+ return False
+
+ __nonzero__ = __bool__
+
+ bad_list = BadList()
+ self.assertEmpty(bad_list)
+ self.assertFalse(bad_list)
+
+ def test_passes_when_empty(self):
+ empty_containers = [
+ list(),
+ tuple(),
+ dict(),
+ set(),
+ frozenset(),
+ b'',
+ u'',
+ bytearray(),
+ ]
+ for container in empty_containers:
+ self.assertEmpty(container)
+
+ def test_raises_with_not_empty_containers(self):
+ not_empty_containers = [
+ [1],
+ (1,),
+ {'foo': 'bar'},
+ {1},
+ frozenset([1]),
+ b'a',
+ u'a',
+ bytearray(b'a'),
+ ]
+ regexp = r'.* has length of 1\.$'
+ for container in not_empty_containers:
+ with self.assertRaisesRegex(AssertionError, regexp):
+ self.assertEmpty(container)
+
+ def test_user_message_added_to_default(self):
+ msg = 'This is a useful message'
+ whole_msg = re.escape('[1] has length of 1. : This is a useful message')
+ with self.assertRaisesRegex(AssertionError, whole_msg):
+ self.assertEmpty([1], msg=msg)
+
+
+class TestAssertNotEmpty(absltest.TestCase):
+ longMessage = True
+
+ def test_raises_if_not_asized_object(self):
+ msg = "Expected a Sized object, got: 'int'"
+ with self.assertRaisesRegex(AssertionError, msg):
+ self.assertNotEmpty(1)
+
+ def test_calls_len_not_bool(self):
+
+ class BadList(list):
+
+ def __bool__(self):
+ return False
+
+ __nonzero__ = __bool__
+
+ bad_list = BadList([1])
+ self.assertNotEmpty(bad_list)
+ self.assertFalse(bad_list)
+
+ def test_passes_when_not_empty(self):
+ not_empty_containers = [
+ [1],
+ (1,),
+ {'foo': 'bar'},
+ {1},
+ frozenset([1]),
+ b'a',
+ u'a',
+ bytearray(b'a'),
+ ]
+ for container in not_empty_containers:
+ self.assertNotEmpty(container)
+
+ def test_raises_with_empty_containers(self):
+ empty_containers = [
+ list(),
+ tuple(),
+ dict(),
+ set(),
+ frozenset(),
+ b'',
+ u'',
+ bytearray(),
+ ]
+ regexp = r'.* has length of 0\.$'
+ for container in empty_containers:
+ with self.assertRaisesRegex(AssertionError, regexp):
+ self.assertNotEmpty(container)
+
+ def test_user_message_added_to_default(self):
+ msg = 'This is a useful message'
+ whole_msg = re.escape('[] has length of 0. : This is a useful message')
+ with self.assertRaisesRegex(AssertionError, whole_msg):
+ self.assertNotEmpty([], msg=msg)
+
+
+class TestAssertLen(absltest.TestCase):
+ longMessage = True
+
+ def test_raises_if_not_asized_object(self):
+ msg = "Expected a Sized object, got: 'int'"
+ with self.assertRaisesRegex(AssertionError, msg):
+ self.assertLen(1, 1)
+
+ def test_passes_when_expected_len(self):
+ containers = [
+ [[1], 1],
+ [(1, 2), 2],
+ [{'a': 1, 'b': 2, 'c': 3}, 3],
+ [{1, 2, 3, 4}, 4],
+ [frozenset([1]), 1],
+ [b'abc', 3],
+ [u'def', 3],
+ [bytearray(b'ghij'), 4],
+ ]
+ for container, expected_len in containers:
+ self.assertLen(container, expected_len)
+
+ def test_raises_when_unexpected_len(self):
+ containers = [
+ [1],
+ (1, 2),
+ {'a': 1, 'b': 2, 'c': 3},
+ {1, 2, 3, 4},
+ frozenset([1]),
+ b'abc',
+ u'def',
+ bytearray(b'ghij'),
+ ]
+ for container in containers:
+ regexp = r'.* has length of %d, expected 100\.$' % len(container)
+ with self.assertRaisesRegex(AssertionError, regexp):
+ self.assertLen(container, 100)
+
+ def test_user_message_added_to_default(self):
+ msg = 'This is a useful message'
+ whole_msg = (
+ r'\[1\] has length of 1, expected 100. : This is a useful message')
+ with self.assertRaisesRegex(AssertionError, whole_msg):
+ self.assertLen([1], 100, msg)
+
+
+class TestLoaderTest(absltest.TestCase):
+ """Tests that the TestLoader bans methods named TestFoo."""
+
+ # pylint: disable=invalid-name
+ class Valid(absltest.TestCase):
+ """Test case containing a variety of valid names."""
+
+ test_property = 1
+ TestProperty = 2
+
+ @staticmethod
+ def TestStaticMethod():
+ pass
+
+ @staticmethod
+ def TestStaticMethodWithArg(foo):
+ pass
+
+ @classmethod
+ def TestClassMethod(cls):
+ pass
+
+ def Test(self):
+ pass
+
+ def TestingHelper(self):
+ pass
+
+ def testMethod(self):
+ pass
+
+ def TestHelperWithParams(self, a, b):
+ pass
+
+ def TestHelperWithVarargs(self, *args, **kwargs):
+ pass
+
+ def TestHelperWithDefaults(self, a=5):
+ pass
+
+ class Invalid(absltest.TestCase):
+ """Test case containing a suspicious method."""
+
+ def testMethod(self):
+ pass
+
+ def TestSuspiciousMethod(self):
+ pass
+ # pylint: enable=invalid-name
+
+ def setUp(self):
+ self.loader = absltest.TestLoader()
+
+ def test_valid(self):
+ suite = self.loader.loadTestsFromTestCase(TestLoaderTest.Valid)
+ self.assertEquals(1, suite.countTestCases())
+
+ def testInvalid(self):
+ with self.assertRaisesRegex(TypeError, 'TestSuspiciousMethod'):
+ self.loader.loadTestsFromTestCase(TestLoaderTest.Invalid)
+
+
+class InitNotNecessaryForAssertsTest(absltest.TestCase):
+ """TestCase assertions should work even if __init__ wasn't correctly called.
+
+ This is a workaround, see comment in
+ absltest.TestCase._getAssertEqualityFunc. We know that not calling
+ __init__ of a superclass is a bad thing, but people keep doing them,
+ and this (even if a little bit dirty) saves them from shooting
+ themselves in the foot.
+ """
+
+ def test_subclass(self):
+
+ class Subclass(absltest.TestCase):
+
+ def __init__(self): # pylint: disable=super-init-not-called
+ pass
+
+ Subclass().assertEquals({}, {})
+
+ def test_multiple_inheritance(self):
+
+ class Foo(object):
+
+ def __init__(self, *args, **kwargs):
+ pass
+
+ class Subclass(Foo, absltest.TestCase):
+ pass
+
+ Subclass().assertEquals({}, {})
+
+
+class GetCommandStringTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ ([], '', ''),
+ ([''], "''", ''),
+ (['command', 'arg-0'], "'command' 'arg-0'", 'command arg-0'),
+ ([u'command', u'arg-0'], "'command' 'arg-0'", u'command arg-0'),
+ (["foo'bar"], "'foo'\"'\"'bar'", "foo'bar"),
+ (['foo"bar'], "'foo\"bar'", 'foo"bar'),
+ ('command arg-0', 'command arg-0', 'command arg-0'),
+ (u'command arg-0', 'command arg-0', 'command arg-0'))
+ def test_get_command_string(
+ self, command, expected_non_windows, expected_windows):
+ expected = expected_windows if os.name == 'nt' else expected_non_windows
+ self.assertEqual(expected, absltest.get_command_string(command))
+
+
+class TempFileTest(absltest.TestCase, HelperMixin):
+
+ def assert_dir_exists(self, temp_dir):
+ path = temp_dir.full_path
+ self.assertTrue(os.path.exists(path), 'Dir {} does not exist'.format(path))
+ self.assertTrue(os.path.isdir(path),
+ 'Path {} exists, but is not a directory'.format(path))
+
+ def assert_file_exists(self, temp_file, expected_content=b''):
+ path = temp_file.full_path
+ self.assertTrue(os.path.exists(path), 'File {} does not exist'.format(path))
+ self.assertTrue(os.path.isfile(path),
+ 'Path {} exists, but is not a file'.format(path))
+
+ mode = 'rb' if isinstance(expected_content, bytes) else 'rt'
+ with io.open(path, mode) as fp:
+ actual = fp.read()
+ self.assertEqual(expected_content, actual)
+
+ def run_tempfile_helper(self, cleanup, expected_paths):
+ tmpdir = self.create_tempdir('helper-test-temp-dir')
+ env = {
+ 'ABSLTEST_TEST_HELPER_TEMPFILE_CLEANUP': cleanup,
+ 'TEST_TMPDIR': tmpdir.full_path,
+ }
+ stdout, stderr = self.run_helper(0, ['TempFileHelperTest'], env,
+ expect_success=False)
+ output = ('\n=== Helper output ===\n'
+ '----- stdout -----\n{}\n'
+ '----- end stdout -----\n'
+ '----- stderr -----\n{}\n'
+ '----- end stderr -----\n'
+ '===== end helper output =====').format(stdout, stderr)
+ self.assertIn('test_failure', stderr, output)
+
+ # Adjust paths to match on Windows
+ expected_paths = {path.replace('/', os.sep) for path in expected_paths}
+
+ actual = {
+ os.path.relpath(f, tmpdir.full_path)
+ for f in _listdir_recursive(tmpdir.full_path)
+ if f != tmpdir.full_path
+ }
+ self.assertEqual(expected_paths, actual, output)
+
+ def test_create_file_pre_existing_readonly(self):
+ first = self.create_tempfile('foo', content='first')
+ os.chmod(first.full_path, 0o444)
+ second = self.create_tempfile('foo', content='second')
+ self.assertEqual('second', first.read_text())
+ self.assertEqual('second', second.read_text())
+
+ def test_create_file_fails_cleanup(self):
+ path = self.create_tempfile().full_path
+ # Removing the write bit from the file makes it undeletable on Windows.
+ os.chmod(path, 0)
+ # Removing the write bit from the whole directory makes all contained files
+ # undeletable on unix. We also need it to be exec so that os.path.isfile
+ # returns true, and we reach the buggy branch.
+ os.chmod(os.path.dirname(path), stat.S_IEXEC)
+ # The test should pass, even though that file cannot be deleted in teardown.
+
+ def test_temp_file_path_like(self):
+ tempdir = self.create_tempdir('foo')
+ self.assertIsInstance(tempdir, os.PathLike)
+
+ tempfile_ = tempdir.create_file('bar')
+ self.assertIsInstance(tempfile_, os.PathLike)
+
+ self.assertEqual(tempfile_.read_text(), pathlib.Path(tempfile_).read_text())
+
+ def test_unnamed(self):
+ td = self.create_tempdir()
+ self.assert_dir_exists(td)
+
+ tdf = td.create_file()
+ self.assert_file_exists(tdf)
+
+ tdd = td.mkdir()
+ self.assert_dir_exists(tdd)
+
+ tf = self.create_tempfile()
+ self.assert_file_exists(tf)
+
+ def test_named(self):
+ td = self.create_tempdir('d')
+ self.assert_dir_exists(td)
+
+ tdf = td.create_file('df')
+ self.assert_file_exists(tdf)
+
+ tdd = td.mkdir('dd')
+ self.assert_dir_exists(tdd)
+
+ tf = self.create_tempfile('f')
+ self.assert_file_exists(tf)
+
+ def test_nested_paths(self):
+ td = self.create_tempdir('d1/d2')
+ self.assert_dir_exists(td)
+
+ tdf = td.create_file('df1/df2')
+ self.assert_file_exists(tdf)
+
+ tdd = td.mkdir('dd1/dd2')
+ self.assert_dir_exists(tdd)
+
+ tf = self.create_tempfile('f1/f2')
+ self.assert_file_exists(tf)
+
+ def test_tempdir_create_file(self):
+ td = self.create_tempdir()
+ td.create_file(content='text')
+
+ def test_tempfile_text(self):
+ tf = self.create_tempfile(content='text')
+ self.assert_file_exists(tf, 'text')
+ self.assertEqual('text', tf.read_text())
+
+ with tf.open_text() as fp:
+ self.assertEqual('text', fp.read())
+
+ with tf.open_text('w') as fp:
+ fp.write(u'text-from-open-write')
+ self.assertEqual('text-from-open-write', tf.read_text())
+
+ tf.write_text('text-from-write-text')
+ self.assertEqual('text-from-write-text', tf.read_text())
+
+ def test_tempfile_bytes(self):
+ tf = self.create_tempfile(content=b'\x00\x01\x02')
+ self.assert_file_exists(tf, b'\x00\x01\x02')
+ self.assertEqual(b'\x00\x01\x02', tf.read_bytes())
+
+ with tf.open_bytes() as fp:
+ self.assertEqual(b'\x00\x01\x02', fp.read())
+
+ with tf.open_bytes('wb') as fp:
+ fp.write(b'\x03')
+ self.assertEqual(b'\x03', tf.read_bytes())
+
+ tf.write_bytes(b'\x04')
+ self.assertEqual(b'\x04', tf.read_bytes())
+
+ def test_tempdir_same_name(self):
+ """Make sure the same directory name can be used."""
+ td1 = self.create_tempdir('foo')
+ td2 = self.create_tempdir('foo')
+ self.assert_dir_exists(td1)
+ self.assert_dir_exists(td2)
+
+ def test_tempfile_cleanup_success(self):
+ expected = {
+ 'TempFileHelperTest',
+ 'TempFileHelperTest/test_failure',
+ 'TempFileHelperTest/test_failure/failure',
+ 'TempFileHelperTest/test_success',
+ }
+ self.run_tempfile_helper('SUCCESS', expected)
+
+ def test_tempfile_cleanup_always(self):
+ expected = {
+ 'TempFileHelperTest',
+ 'TempFileHelperTest/test_failure',
+ 'TempFileHelperTest/test_success',
+ }
+ self.run_tempfile_helper('ALWAYS', expected)
+
+ def test_tempfile_cleanup_off(self):
+ expected = {
+ 'TempFileHelperTest',
+ 'TempFileHelperTest/test_failure',
+ 'TempFileHelperTest/test_failure/failure',
+ 'TempFileHelperTest/test_success',
+ 'TempFileHelperTest/test_success/success',
+ }
+ self.run_tempfile_helper('OFF', expected)
+
+
+class SkipClassTest(absltest.TestCase):
+
+ def test_incorrect_decorator_call(self):
+ with self.assertRaises(TypeError):
+
+ @absltest.skipThisClass # pylint: disable=unused-variable
+ class Test(absltest.TestCase):
+ pass
+
+ def test_incorrect_decorator_subclass(self):
+ with self.assertRaises(TypeError):
+
+ @absltest.skipThisClass('reason')
+ def test_method(): # pylint: disable=unused-variable
+ pass
+
+ def test_correct_decorator_class(self):
+
+ @absltest.skipThisClass('reason')
+ class Test(absltest.TestCase):
+ pass
+
+ with self.assertRaises(absltest.SkipTest):
+ Test.setUpClass()
+
+ def test_correct_decorator_subclass(self):
+
+ @absltest.skipThisClass('reason')
+ class Test(absltest.TestCase):
+ pass
+
+ class Subclass(Test):
+ pass
+
+ with self.subTest('Base class should be skipped'):
+ with self.assertRaises(absltest.SkipTest):
+ Test.setUpClass()
+
+ with self.subTest('Subclass should not be skipped'):
+ Subclass.setUpClass() # should not raise.
+
+ def test_setup(self):
+
+ @absltest.skipThisClass('reason')
+ class Test(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(Test, cls).setUpClass()
+ cls.foo = 1
+
+ class Subclass(Test):
+ pass
+
+ Subclass.setUpClass()
+ self.assertEqual(Subclass.foo, 1)
+
+ def test_setup_chain(self):
+
+ @absltest.skipThisClass('reason')
+ class BaseTest(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(BaseTest, cls).setUpClass()
+ cls.foo = 1
+
+ @absltest.skipThisClass('reason')
+ class SecondBaseTest(BaseTest):
+
+ @classmethod
+ def setUpClass(cls):
+ super(SecondBaseTest, cls).setUpClass()
+ cls.bar = 2
+
+ class Subclass(SecondBaseTest):
+ pass
+
+ Subclass.setUpClass()
+ self.assertEqual(Subclass.foo, 1)
+ self.assertEqual(Subclass.bar, 2)
+
+ def test_setup_args(self):
+
+ @absltest.skipThisClass('reason')
+ class Test(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls, foo, bar=None):
+ super(Test, cls).setUpClass()
+ cls.foo = foo
+ cls.bar = bar
+
+ class Subclass(Test):
+
+ @classmethod
+ def setUpClass(cls):
+ super(Subclass, cls).setUpClass('foo', bar='baz')
+
+ Subclass.setUpClass()
+ self.assertEqual(Subclass.foo, 'foo')
+ self.assertEqual(Subclass.bar, 'baz')
+
+ def test_setup_multiple_inheritance(self):
+
+ # Test that skipping this class doesn't break the MRO chain and stop
+ # RequiredBase.setUpClass from running.
+ @absltest.skipThisClass('reason')
+ class Left(absltest.TestCase):
+ pass
+
+ class RequiredBase(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(RequiredBase, cls).setUpClass()
+ cls.foo = 'foo'
+
+ class Right(RequiredBase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(Right, cls).setUpClass()
+
+ # Test will fail unless Left.setUpClass() follows mro properly
+ # Right.setUpClass()
+ class Subclass(Left, Right):
+
+ @classmethod
+ def setUpClass(cls):
+ super(Subclass, cls).setUpClass()
+
+ class Test(Subclass):
+ pass
+
+ Test.setUpClass()
+ self.assertEqual(Test.foo, 'foo')
+
+ def test_skip_class(self):
+
+ @absltest.skipThisClass('reason')
+ class BaseTest(absltest.TestCase):
+
+ def test_foo(self):
+ _ = 1 / 0
+
+ class Test(BaseTest):
+
+ def test_foo(self):
+ self.assertEqual(1, 1)
+
+ with self.subTest('base class'):
+ ts = unittest.makeSuite(BaseTest)
+ self.assertEqual(1, ts.countTestCases())
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertTrue(res.wasSuccessful())
+ self.assertLen(res.skipped, 1)
+ self.assertEqual(0, res.testsRun)
+ self.assertEmpty(res.failures)
+ self.assertEmpty(res.errors)
+
+ with self.subTest('real test'):
+ ts = unittest.makeSuite(Test)
+ self.assertEqual(1, ts.countTestCases())
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertTrue(res.wasSuccessful())
+ self.assertEqual(1, res.testsRun)
+ self.assertEmpty(res.skipped)
+ self.assertEmpty(res.failures)
+ self.assertEmpty(res.errors)
+
+ def test_skip_class_unittest(self):
+
+ @absltest.skipThisClass('reason')
+ class Test(unittest.TestCase): # note: unittest not absltest
+
+ def test_foo(self):
+ _ = 1 / 0
+
+ ts = unittest.makeSuite(Test)
+ self.assertEqual(1, ts.countTestCases())
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertTrue(res.wasSuccessful())
+ self.assertLen(res.skipped, 1)
+ self.assertEqual(0, res.testsRun)
+ self.assertEmpty(res.failures)
+ self.assertEmpty(res.errors)
+
+
+def _listdir_recursive(path):
+ for dirname, _, filenames in os.walk(path):
+ yield dirname
+ for filename in filenames:
+ yield os.path.join(dirname, filename)
+
+
+def _env_for_command_tests():
+ if os.name == 'nt' and 'PATH' in os.environ:
+ # get_command_stderr and assertCommandXXX don't inherit environment
+ # variables by default. This makes sure msys commands can be found on
+ # Windows.
+ return {'PATH': os.environ['PATH']}
+ else:
+ return None
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/absltest_test_helper.py b/absl/testing/tests/absltest_test_helper.py
new file mode 100644
index 0000000..c6b2465
--- /dev/null
+++ b/absl/testing/tests/absltest_test_helper.py
@@ -0,0 +1,106 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper binary for absltest_test.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+import unittest
+
+from absl import flags
+from absl.testing import absltest
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_integer('test_id', 0, 'Which test to run.')
+
+
+class HelperTest(absltest.TestCase):
+
+ def test_flags(self):
+ if FLAGS.test_id == 1:
+ self.assertEqual(FLAGS.test_random_seed, 301)
+ if os.name == 'nt':
+ # On Windows, it's always in the temp dir, which doesn't start with '/'.
+ expected_prefix = tempfile.gettempdir()
+ else:
+ expected_prefix = '/'
+ self.assertTrue(
+ absltest.TEST_TMPDIR.value.startswith(expected_prefix),
+ '--test_tmpdir={} does not start with {}'.format(
+ absltest.TEST_TMPDIR.value, expected_prefix))
+ self.assertTrue(os.access(absltest.TEST_TMPDIR.value, os.W_OK))
+ elif FLAGS.test_id == 2:
+ self.assertEqual(FLAGS.test_random_seed, 321)
+ self.assertEqual(
+ absltest.TEST_SRCDIR.value,
+ os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR'])
+ self.assertEqual(
+ absltest.TEST_TMPDIR.value,
+ os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR'])
+ elif FLAGS.test_id == 3:
+ self.assertEqual(FLAGS.test_random_seed, 123)
+ self.assertEqual(
+ absltest.TEST_SRCDIR.value,
+ os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR'])
+ self.assertEqual(
+ absltest.TEST_TMPDIR.value,
+ os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR'])
+ elif FLAGS.test_id == 4:
+ self.assertEqual(FLAGS.test_random_seed, 221)
+ self.assertEqual(
+ absltest.TEST_SRCDIR.value,
+ os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR'])
+ self.assertEqual(
+ absltest.TEST_TMPDIR.value,
+ os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR'])
+ else:
+ raise unittest.SkipTest(
+ 'Not asked to run: --test_id={}'.format(FLAGS.test_id))
+
+ @unittest.expectedFailure
+ def test_expected_failure(self):
+ if FLAGS.test_id == 5:
+ self.assertEqual(1, 1) # Expected failure, got success.
+ else:
+ self.assertEqual(1, 2) # The expected failure.
+
+ def test_xml_env_vars(self):
+ if FLAGS.test_id == 6:
+ self.assertEqual(
+ FLAGS.xml_output_file,
+ os.environ['ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE'])
+ else:
+ raise unittest.SkipTest(
+ 'Not asked to run: --test_id={}'.format(FLAGS.test_id))
+
+
+class TempFileHelperTest(absltest.TestCase):
+ tempfile_cleanup = absltest.TempFileCleanup[os.environ.get(
+ 'ABSLTEST_TEST_HELPER_TEMPFILE_CLEANUP', 'SUCCESS')]
+
+ def test_failure(self):
+ self.create_tempfile('failure')
+ self.fail('expected failure')
+
+ def test_success(self):
+ self.create_tempfile('success')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/flagsaver_test.py b/absl/testing/tests/flagsaver_test.py
new file mode 100644
index 0000000..3439a32
--- /dev/null
+++ b/absl/testing/tests/flagsaver_test.py
@@ -0,0 +1,467 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for flagsaver."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import flags
+from absl.testing import absltest
+from absl.testing import flagsaver
+
+flags.DEFINE_string('flagsaver_test_flag0', 'unchanged0', 'flag to test with')
+flags.DEFINE_string('flagsaver_test_flag1', 'unchanged1', 'flag to test with')
+
+flags.DEFINE_string('flagsaver_test_validated_flag', None, 'flag to test with')
+flags.register_validator('flagsaver_test_validated_flag', lambda x: not x)
+
+flags.DEFINE_string('flagsaver_test_validated_flag1', None, 'flag to test with')
+flags.DEFINE_string('flagsaver_test_validated_flag2', None, 'flag to test with')
+
+INT_FLAG = flags.DEFINE_integer(
+ 'flagsaver_test_int_flag', default=1, help='help')
+STR_FLAG = flags.DEFINE_string(
+ 'flagsaver_test_str_flag', default='str default', help='help')
+
+
+@flags.multi_flags_validator(
+ ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
+def validate_test_flags(flag_dict):
+ return (flag_dict['flagsaver_test_validated_flag1'] ==
+ flag_dict['flagsaver_test_validated_flag2'])
+
+
+FLAGS = flags.FLAGS
+
+
+@flags.validator('flagsaver_test_flag0')
+def check_no_upper_case(value):
+ return value == value.lower()
+
+
+class _TestError(Exception):
+ """Exception class for use in these tests."""
+
+
+class FlagSaverTest(absltest.TestCase):
+
+ def test_context_manager_without_parameters(self):
+ with flagsaver.flagsaver():
+ FLAGS.flagsaver_test_flag0 = 'new value'
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_context_manager_with_overrides(self):
+ with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+ FLAGS.flagsaver_test_flag1 = 'another value'
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
+
+ def test_context_manager_with_flagholders(self):
+ with flagsaver.flagsaver((INT_FLAG, 3), (STR_FLAG, 'new value')):
+ self.assertEqual('new value', STR_FLAG.value)
+ self.assertEqual(3, INT_FLAG.value)
+ FLAGS.flagsaver_test_flag1 = 'another value'
+ self.assertEqual(INT_FLAG.value, INT_FLAG.default)
+ self.assertEqual(STR_FLAG.value, STR_FLAG.default)
+ self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
+
+ def test_context_manager_with_overrides_and_flagholders(self):
+ with flagsaver.flagsaver((INT_FLAG, 3), flagsaver_test_flag0='new value'):
+ self.assertEqual(STR_FLAG.default, STR_FLAG.value)
+ self.assertEqual(3, INT_FLAG.value)
+ FLAGS.flagsaver_test_flag0 = 'new value'
+ self.assertEqual(INT_FLAG.value, INT_FLAG.default)
+ self.assertEqual(STR_FLAG.value, STR_FLAG.default)
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_context_manager_with_cross_validated_overrides_set_together(self):
+ # When the flags are set in the same flagsaver call their validators will
+ # be triggered only once the setting is done.
+ with flagsaver.flagsaver(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='new_value'):
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
+ self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
+
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_context_manager_with_cross_validated_overrides_set_badly(self):
+
+ # Different values should violate the validator.
+ with self.assertRaisesRegex(flags.IllegalFlagValueError,
+ 'Flag validation failed'):
+ with flagsaver.flagsaver(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='other_value'):
+ pass
+
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_context_manager_with_cross_validated_overrides_set_separately(self):
+
+ # Setting just one flag will trip the validator as well.
+ with self.assertRaisesRegex(flags.IllegalFlagValueError,
+ 'Flag validation failed'):
+ with flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value'):
+ pass
+
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_context_manager_with_exception(self):
+ with self.assertRaises(_TestError):
+ with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+ FLAGS.flagsaver_test_flag1 = 'another value'
+ raise _TestError('oops')
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
+
+ def test_context_manager_with_validation_exception(self):
+ with self.assertRaises(flags.IllegalFlagValueError):
+ with flagsaver.flagsaver(
+ flagsaver_test_flag0='new value',
+ flagsaver_test_validated_flag='new value'):
+ pass
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag)
+
+ def test_decorator_without_call(self):
+
+ @flagsaver.flagsaver
+ def mutate_flags(value):
+ """Test function that mutates a flag."""
+ # The undecorated method mutates --flagsaver_test_flag0 to the given value
+ # and then returns the value of that flag. If the @flagsaver.flagsaver
+ # decorator works as designed, then this mutation will be reverted after
+ # this method returns.
+ FLAGS.flagsaver_test_flag0 = value
+ return FLAGS.flagsaver_test_flag0
+
+ # mutate_flags returns the flag value before it gets restored by
+ # the flagsaver decorator. So we check that flag value was
+ # actually changed in the method's scope.
+ self.assertEqual('new value', mutate_flags('new value'))
+ # But... notice that the flag is now unchanged0.
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_decorator_without_parameters(self):
+
+ @flagsaver.flagsaver()
+ def mutate_flags(value):
+ FLAGS.flagsaver_test_flag0 = value
+ return FLAGS.flagsaver_test_flag0
+
+ self.assertEqual('new value', mutate_flags('new value'))
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_decorator_with_overrides(self):
+
+ @flagsaver.flagsaver(flagsaver_test_flag0='new value')
+ def mutate_flags():
+ """Test function expecting new value."""
+ # If the @flagsaver.decorator decorator works as designed,
+ # then the value of the flag should be changed in the scope of
+ # the method but the change will be reverted after this method
+ # returns.
+ return FLAGS.flagsaver_test_flag0
+
+ # mutate_flags returns the flag value before it gets restored by
+ # the flagsaver decorator. So we check that flag value was
+ # actually changed in the method's scope.
+ self.assertEqual('new value', mutate_flags())
+ # But... notice that the flag is now unchanged0.
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_decorator_with_cross_validated_overrides_set_together(self):
+
+ # When the flags are set in the same flagsaver call their validators will
+ # be triggered only once the setting is done.
+ @flagsaver.flagsaver(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='new_value')
+ def mutate_flags_together():
+ return (FLAGS.flagsaver_test_validated_flag1,
+ FLAGS.flagsaver_test_validated_flag2)
+
+ self.assertEqual(('new_value', 'new_value'), mutate_flags_together())
+
+ # The flags have not changed outside the context of the function.
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_decorator_with_cross_validated_overrides_set_badly(self):
+
+ # Different values should violate the validator.
+ @flagsaver.flagsaver(
+ flagsaver_test_validated_flag1='new_value',
+ flagsaver_test_validated_flag2='other_value')
+ def mutate_flags_together_badly():
+ return (FLAGS.flagsaver_test_validated_flag1,
+ FLAGS.flagsaver_test_validated_flag2)
+
+ with self.assertRaisesRegex(flags.IllegalFlagValueError,
+ 'Flag validation failed'):
+ mutate_flags_together_badly()
+
+ # The flags have not changed outside the context of the exception.
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_decorator_with_cross_validated_overrides_set_separately(self):
+
+ # Setting the flags sequentially and not together will trip the validator,
+ # because it will be called at the end of each flagsaver call.
+ @flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value')
+ @flagsaver.flagsaver(flagsaver_test_validated_flag2='new_value')
+ def mutate_flags_separately():
+ return (FLAGS.flagsaver_test_validated_flag1,
+ FLAGS.flagsaver_test_validated_flag2)
+
+ with self.assertRaisesRegex(flags.IllegalFlagValueError,
+ 'Flag validation failed'):
+ mutate_flags_separately()
+
+ # The flags have not changed outside the context of the exception.
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
+ self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
+
+ def test_save_flag_value(self):
+ # First save the flag values.
+ saved_flag_values = flagsaver.save_flag_values()
+
+ # Now mutate the flag's value field and check that it changed.
+ FLAGS.flagsaver_test_flag0 = 'new value'
+ self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
+
+ # Now restore the flag to its original value.
+ flagsaver.restore_flag_values(saved_flag_values)
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_save_flag_default(self):
+ # First save the flag.
+ saved_flag_values = flagsaver.save_flag_values()
+
+ # Now mutate the flag's default field and check that it changed.
+ FLAGS.set_default('flagsaver_test_flag0', 'new_default')
+ self.assertEqual('new_default', FLAGS['flagsaver_test_flag0'].default)
+
+ # Now restore the flag's default field.
+ flagsaver.restore_flag_values(saved_flag_values)
+ self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].default)
+
+ def test_restore_after_parse(self):
+ # First save the flag.
+ saved_flag_values = flagsaver.save_flag_values()
+
+ # Sanity check (would fail if called with --flagsaver_test_flag0).
+ self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present)
+ # Now populate the flag and check that it changed.
+ FLAGS['flagsaver_test_flag0'].parse('new value')
+ self.assertEqual('new value', FLAGS['flagsaver_test_flag0'].value)
+ self.assertEqual(1, FLAGS['flagsaver_test_flag0'].present)
+
+ # Now restore the flag to its original value.
+ flagsaver.restore_flag_values(saved_flag_values)
+ self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].value)
+ self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present)
+
+ def test_decorator_with_exception(self):
+
+ @flagsaver.flagsaver
+ def raise_exception():
+ FLAGS.flagsaver_test_flag0 = 'new value'
+ # Simulate a failed test.
+ raise _TestError('something happened')
+
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ self.assertRaises(_TestError, raise_exception)
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
+ def test_validator_list_is_restored(self):
+
+ self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1)
+ original_validators = list(FLAGS['flagsaver_test_flag0'].validators)
+
+ @flagsaver.flagsaver
+ def modify_validators():
+
+ def no_space(value):
+ return ' ' not in value
+
+ flags.register_validator('flagsaver_test_flag0', no_space)
+ self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2)
+
+ modify_validators()
+ self.assertEqual(original_validators,
+ FLAGS['flagsaver_test_flag0'].validators)
+
+
+class FlagSaverDecoratorUsageTest(absltest.TestCase):
+
+ @flagsaver.flagsaver
+ def test_mutate1(self):
+ # Even though other test cases change the flag, it should be
+ # restored to 'unchanged0' if the flagsaver is working.
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ FLAGS.flagsaver_test_flag0 = 'changed0'
+
+ @flagsaver.flagsaver
+ def test_mutate2(self):
+ # Even though other test cases change the flag, it should be
+ # restored to 'unchanged0' if the flagsaver is working.
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ FLAGS.flagsaver_test_flag0 = 'changed0'
+
+ @flagsaver.flagsaver
+ def test_mutate3(self):
+ # Even though other test cases change the flag, it should be
+ # restored to 'unchanged0' if the flagsaver is working.
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ FLAGS.flagsaver_test_flag0 = 'changed0'
+
+ @flagsaver.flagsaver
+ def test_mutate4(self):
+ # Even though other test cases change the flag, it should be
+ # restored to 'unchanged0' if the flagsaver is working.
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ FLAGS.flagsaver_test_flag0 = 'changed0'
+
+
+class FlagSaverSetUpTearDownUsageTest(absltest.TestCase):
+
+ def setUp(self):
+ self.saved_flag_values = flagsaver.save_flag_values()
+
+ def tearDown(self):
+ flagsaver.restore_flag_values(self.saved_flag_values)
+
+ def test_mutate1(self):
+ # Even though other test cases change the flag, it should be
+ # restored to 'unchanged0' if the flagsaver is working.
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ FLAGS.flagsaver_test_flag0 = 'changed0'
+
+ def test_mutate2(self):
+ # Even though other test cases change the flag, it should be
+ # restored to 'unchanged0' if the flagsaver is working.
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ FLAGS.flagsaver_test_flag0 = 'changed0'
+
+ def test_mutate3(self):
+ # Even though other test cases change the flag, it should be
+ # restored to 'unchanged0' if the flagsaver is working.
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ FLAGS.flagsaver_test_flag0 = 'changed0'
+
+ def test_mutate4(self):
+ # Even though other test cases change the flag, it should be
+ # restored to 'unchanged0' if the flagsaver is working.
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+ FLAGS.flagsaver_test_flag0 = 'changed0'
+
+
+class FlagSaverBadUsageTest(absltest.TestCase):
+ """Tests that certain kinds of improper usages raise errors."""
+
+ def test_flag_saver_on_class(self):
+ with self.assertRaises(TypeError):
+
+ # WRONG. Don't do this.
+ # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
+ @flagsaver.flagsaver
+ class FooTest(absltest.TestCase):
+
+ def test_tautology(self):
+ pass
+
+ del FooTest
+
+ def test_flag_saver_call_on_class(self):
+ with self.assertRaises(TypeError):
+
+ # WRONG. Don't do this.
+ # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
+ @flagsaver.flagsaver()
+ class FooTest(absltest.TestCase):
+
+ def test_tautology(self):
+ pass
+
+ del FooTest
+
+ def test_flag_saver_with_overrides_on_class(self):
+ with self.assertRaises(TypeError):
+
+ # WRONG. Don't do this.
+ # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
+ @flagsaver.flagsaver(foo='bar')
+ class FooTest(absltest.TestCase):
+
+ def test_tautology(self):
+ pass
+
+ del FooTest
+
+ def test_multiple_positional_parameters(self):
+ with self.assertRaises(ValueError):
+ func_a = lambda: None
+ func_b = lambda: None
+ flagsaver.flagsaver(func_a, func_b)
+
+ def test_both_positional_and_keyword_parameters(self):
+ with self.assertRaises(ValueError):
+ func_a = lambda: None
+ flagsaver.flagsaver(func_a, flagsaver_test_flag0='new value')
+
+ def test_duplicate_holder_parameters(self):
+ with self.assertRaises(ValueError):
+ flagsaver.flagsaver((INT_FLAG, 45), (INT_FLAG, 45))
+
+ def test_duplicate_holder_and_kw_parameter(self):
+ with self.assertRaises(ValueError):
+ flagsaver.flagsaver((INT_FLAG, 45), **{INT_FLAG.name: 45})
+
+ def test_both_positional_and_holder_parameters(self):
+ with self.assertRaises(ValueError):
+ func_a = lambda: None
+ flagsaver.flagsaver(func_a, (INT_FLAG, 45))
+
+ def test_holder_parameters_wrong_shape(self):
+ with self.assertRaises(ValueError):
+ flagsaver.flagsaver(INT_FLAG)
+
+ def test_holder_parameters_tuple_too_long(self):
+ with self.assertRaises(ValueError):
+ # Even if it is a bool flag, it should be a tuple
+ flagsaver.flagsaver((INT_FLAG, 4, 5))
+
+ def test_holder_parameters_tuple_wrong_type(self):
+ with self.assertRaises(ValueError):
+ # Even if it is a bool flag, it should be a tuple
+ flagsaver.flagsaver((4, INT_FLAG))
+
+ def test_both_wrong_positional_parameters(self):
+ with self.assertRaises(ValueError):
+ func_a = lambda: None
+ flagsaver.flagsaver(func_a, STR_FLAG, '45')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/parameterized_test.py b/absl/testing/tests/parameterized_test.py
new file mode 100644
index 0000000..8acbd93
--- /dev/null
+++ b/absl/testing/tests/parameterized_test.py
@@ -0,0 +1,1077 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for absl.testing.parameterized."""
+
+from collections import abc
+import sys
+import unittest
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class MyOwnClass(object):
+ pass
+
+
+def dummy_decorator(method):
+
+ def decorated(*args, **kwargs):
+ return method(*args, **kwargs)
+
+ return decorated
+
+
+def dict_decorator(key, value):
+ """Sample implementation of a chained decorator.
+
+ Sets a single field in a dict on a test with a dict parameter.
+ Uses the exposed '_ParameterizedTestIter.testcases' field to
+ modify arguments from previous decorators to allow decorator chains.
+
+ Args:
+ key: key to map to
+ value: value to set
+
+ Returns:
+ The test decorator
+ """
+ def decorator(test_method):
+ # If decorating result of another dict_decorator
+ if isinstance(test_method, abc.Iterable):
+ actual_tests = []
+ for old_test in test_method.testcases:
+ # each test is a ('test_suffix', dict) tuple
+ new_dict = old_test[1].copy()
+ new_dict[key] = value
+ test_suffix = '%s_%s_%s' % (old_test[0], key, value)
+ actual_tests.append((test_suffix, new_dict))
+
+ test_method.testcases = actual_tests
+ return test_method
+ else:
+ test_suffix = ('_%s_%s') % (key, value)
+ tests_to_make = ((test_suffix, {key: value}),)
+ # 'test_method' here is the original test method
+ return parameterized.named_parameters(*tests_to_make)(test_method)
+ return decorator
+
+
+class ParameterizedTestsTest(absltest.TestCase):
+ # The test testcases are nested so they're not
+ # picked up by the normal test case loader code.
+
+ class GoodAdditionParams(parameterized.TestCase):
+
+ @parameterized.parameters(
+ (1, 2, 3),
+ (4, 5, 9))
+ def test_addition(self, op1, op2, result):
+ self.arguments = (op1, op2, result)
+ self.assertEqual(result, op1 + op2)
+
+ # This class does not inherit from TestCase.
+ class BadAdditionParams(absltest.TestCase):
+
+ @parameterized.parameters(
+ (1, 2, 3),
+ (4, 5, 9))
+ def test_addition(self, op1, op2, result):
+ pass # Always passes, but not called w/out TestCase.
+
+ class MixedAdditionParams(parameterized.TestCase):
+
+ @parameterized.parameters(
+ (1, 2, 1),
+ (4, 5, 9))
+ def test_addition(self, op1, op2, result):
+ self.arguments = (op1, op2, result)
+ self.assertEqual(result, op1 + op2)
+
+ class DictionaryArguments(parameterized.TestCase):
+
+ @parameterized.parameters(
+ {'op1': 1, 'op2': 2, 'result': 3},
+ {'op1': 4, 'op2': 5, 'result': 9})
+ def test_addition(self, op1, op2, result):
+ self.assertEqual(result, op1 + op2)
+
+ class NoParameterizedTests(parameterized.TestCase):
+ # iterable member with non-matching name
+ a = 'BCD'
+ # member with matching name, but not a generator
+ testInstanceMember = None # pylint: disable=invalid-name
+ test_instance_member = None
+
+ # member with a matching name and iterator, but not a generator
+ testString = 'foo' # pylint: disable=invalid-name
+ test_string = 'foo'
+
+ # generator, but no matching name
+ def someGenerator(self): # pylint: disable=invalid-name
+ yield
+ yield
+ yield
+
+ def some_generator(self):
+ yield
+ yield
+ yield
+
+ # Generator function, but not a generator instance.
+ def testGenerator(self):
+ yield
+ yield
+ yield
+
+ def test_generator(self):
+ yield
+ yield
+ yield
+
+ def testNormal(self):
+ self.assertEqual(3, 1 + 2)
+
+ def test_normal(self):
+ self.assertEqual(3, 1 + 2)
+
+ class ArgumentsWithAddresses(parameterized.TestCase):
+
+ @parameterized.parameters(
+ (object(),),
+ (MyOwnClass(),),
+ )
+ def test_something(self, case):
+ pass
+
+ class CamelCaseNamedTests(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('Interesting', 0),
+ )
+ def testSingle(self, case):
+ pass
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'Interesting', 'case': 0},
+ )
+ def testDictSingle(self, case):
+ pass
+
+ @parameterized.named_parameters(
+ ('Interesting', 0),
+ ('Boring', 1),
+ )
+ def testSomething(self, case):
+ pass
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'Interesting', 'case': 0},
+ {'testcase_name': 'Boring', 'case': 1},
+ )
+ def testDictSomething(self, case):
+ pass
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'Interesting', 'case': 0},
+ ('Boring', 1),
+ )
+ def testMixedSomething(self, case):
+ pass
+
+ def testWithoutParameters(self):
+ pass
+
+ class NamedTests(parameterized.TestCase):
+ """Example tests using PEP-8 style names instead of camel-case."""
+
+ @parameterized.named_parameters(
+ ('interesting', 0),
+ )
+ def test_single(self, case):
+ pass
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'interesting', 'case': 0},
+ )
+ def test_dict_single(self, case):
+ pass
+
+ @parameterized.named_parameters(
+ ('interesting', 0),
+ ('boring', 1),
+ )
+ def test_something(self, case):
+ pass
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'interesting', 'case': 0},
+ {'testcase_name': 'boring', 'case': 1},
+ )
+ def test_dict_something(self, case):
+ pass
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'interesting', 'case': 0},
+ ('boring', 1),
+ )
+ def test_mixed_something(self, case):
+ pass
+
+ def test_without_parameters(self):
+ pass
+
+ class ChainedTests(parameterized.TestCase):
+
+ @dict_decorator('cone', 'waffle')
+ @dict_decorator('flavor', 'strawberry')
+ def test_chained(self, dictionary):
+ self.assertDictEqual(dictionary, {'cone': 'waffle',
+ 'flavor': 'strawberry'})
+
+ class SingletonListExtraction(parameterized.TestCase):
+
+ @parameterized.parameters(
+ (i, i * 2) for i in range(10))
+ def test_something(self, unused_1, unused_2):
+ pass
+
+ class SingletonArgumentExtraction(parameterized.TestCase):
+
+ @parameterized.parameters(1, 2, 3, 4, 5, 6)
+ def test_numbers(self, unused_1):
+ pass
+
+ @parameterized.parameters('foo', 'bar', 'baz')
+ def test_strings(self, unused_1):
+ pass
+
+ class SingletonDictArgument(parameterized.TestCase):
+
+ @parameterized.parameters({'op1': 1, 'op2': 2})
+ def test_something(self, op1, op2):
+ del op1, op2
+
+ @parameterized.parameters(
+ (1, 2, 3),
+ (4, 5, 9))
+ class DecoratedClass(parameterized.TestCase):
+
+ def test_add(self, arg1, arg2, arg3):
+ self.assertEqual(arg1 + arg2, arg3)
+
+ def test_subtract_fail(self, arg1, arg2, arg3):
+ self.assertEqual(arg3 + arg2, arg1)
+
+ @parameterized.parameters(
+ (a, b, a+b) for a in range(1, 5) for b in range(1, 5))
+ class GeneratorDecoratedClass(parameterized.TestCase):
+
+ def test_add(self, arg1, arg2, arg3):
+ self.assertEqual(arg1 + arg2, arg3)
+
+ def test_subtract_fail(self, arg1, arg2, arg3):
+ self.assertEqual(arg3 + arg2, arg1)
+
+ @parameterized.parameters(
+ (1, 2, 3),
+ (4, 5, 9),
+ )
+ class DecoratedBareClass(absltest.TestCase):
+
+ def test_add(self, arg1, arg2, arg3):
+ self.assertEqual(arg1 + arg2, arg3)
+
+ class OtherDecoratorUnnamed(parameterized.TestCase):
+
+ @dummy_decorator
+ @parameterized.parameters((1), (2))
+ def test_other_then_parameterized(self, arg1):
+ pass
+
+ @parameterized.parameters((1), (2))
+ @dummy_decorator
+ def test_parameterized_then_other(self, arg1):
+ pass
+
+ class OtherDecoratorNamed(parameterized.TestCase):
+
+ @dummy_decorator
+ @parameterized.named_parameters(('a', 1), ('b', 2))
+ def test_other_then_parameterized(self, arg1):
+ pass
+
+ @parameterized.named_parameters(('a', 1), ('b', 2))
+ @dummy_decorator
+ def test_parameterized_then_other(self, arg1):
+ pass
+
+ class OtherDecoratorNamedWithDict(parameterized.TestCase):
+
+ @dummy_decorator
+ @parameterized.named_parameters(
+ {'testcase_name': 'a', 'arg1': 1},
+ {'testcase_name': 'b', 'arg1': 2})
+ def test_other_then_parameterized(self, arg1):
+ pass
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'a', 'arg1': 1},
+ {'testcase_name': 'b', 'arg1': 2})
+ @dummy_decorator
+ def test_parameterized_then_other(self, arg1):
+ pass
+
+ class UniqueDescriptiveNamesTest(parameterized.TestCase):
+
+ @parameterized.parameters(13, 13)
+ def test_normal(self, number):
+ del number
+
+ class MultiGeneratorsTestCase(parameterized.TestCase):
+
+ @parameterized.parameters((i for i in (1, 2, 3)), (i for i in (3, 2, 1)))
+ def test_sum(self, a, b, c):
+ self.assertEqual(6, sum([a, b, c]))
+
+ class NamedParametersReusableTestCase(parameterized.TestCase):
+ named_params_a = (
+ {'testcase_name': 'dict_a', 'unused_obj': 0},
+ ('list_a', 1),
+ )
+ named_params_b = (
+ {'testcase_name': 'dict_b', 'unused_obj': 2},
+ ('list_b', 3),
+ )
+ named_params_c = (
+ {'testcase_name': 'dict_c', 'unused_obj': 4},
+ ('list_b', 5),
+ )
+
+ @parameterized.named_parameters(*(named_params_a + named_params_b))
+ def testSomething(self, unused_obj):
+ pass
+
+ @parameterized.named_parameters(*(named_params_a + named_params_c))
+ def testSomethingElse(self, unused_obj):
+ pass
+
+ class SuperclassTestCase(parameterized.TestCase):
+
+ @parameterized.parameters('foo', 'bar')
+ def test_name(self, name):
+ del name
+
+ class SubclassTestCase(SuperclassTestCase):
+ pass
+
+ @unittest.skipIf(
+ (sys.version_info[:2] == (3, 7) and sys.version_info[2] in {0, 1, 2}),
+ 'Python 3.7.0 to 3.7.2 have a bug that breaks this test, see '
+ 'https://bugs.python.org/issue35767')
+ def test_missing_inheritance(self):
+ ts = unittest.makeSuite(self.BadAdditionParams)
+ self.assertEqual(1, ts.countTestCases())
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(1, res.testsRun)
+ self.assertFalse(res.wasSuccessful())
+ self.assertIn('without having inherited', str(res.errors[0]))
+
+ def test_correct_extraction_numbers(self):
+ ts = unittest.makeSuite(self.GoodAdditionParams)
+ self.assertEqual(2, ts.countTestCases())
+
+ def test_successful_execution(self):
+ ts = unittest.makeSuite(self.GoodAdditionParams)
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(2, res.testsRun)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_correct_arguments(self):
+ ts = unittest.makeSuite(self.GoodAdditionParams)
+ res = unittest.TestResult()
+
+ params = set([
+ (1, 2, 3),
+ (4, 5, 9)])
+ for test in ts:
+ test(res)
+ self.assertIn(test.arguments, params)
+ params.remove(test.arguments)
+ self.assertEmpty(params)
+
+ def test_recorded_failures(self):
+ ts = unittest.makeSuite(self.MixedAdditionParams)
+ self.assertEqual(2, ts.countTestCases())
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(2, res.testsRun)
+ self.assertFalse(res.wasSuccessful())
+ self.assertLen(res.failures, 1)
+ self.assertEmpty(res.errors)
+
+ def test_short_description(self):
+ ts = unittest.makeSuite(self.GoodAdditionParams)
+ short_desc = list(ts)[0].shortDescription()
+
+ location = unittest.util.strclass(self.GoodAdditionParams).replace(
+ '__main__.', '')
+ expected = ('{}.test_addition0 (1, 2, 3)\n'.format(location) +
+ 'test_addition(1, 2, 3)')
+ self.assertEqual(expected, short_desc)
+
+ def test_short_description_addresses_removed(self):
+ ts = unittest.makeSuite(self.ArgumentsWithAddresses)
+ short_desc = list(ts)[0].shortDescription().split('\n')
+ self.assertEqual(
+ 'test_something(<object>)', short_desc[1])
+ short_desc = list(ts)[1].shortDescription().split('\n')
+ self.assertEqual(
+ 'test_something(<__main__.MyOwnClass>)', short_desc[1])
+
+ def test_id(self):
+ ts = unittest.makeSuite(self.ArgumentsWithAddresses)
+ self.assertEqual(
+ (unittest.util.strclass(self.ArgumentsWithAddresses) +
+ '.test_something0 (<object>)'),
+ list(ts)[0].id())
+ ts = unittest.makeSuite(self.GoodAdditionParams)
+ self.assertEqual(
+ (unittest.util.strclass(self.GoodAdditionParams) +
+ '.test_addition0 (1, 2, 3)'),
+ list(ts)[0].id())
+
+ def test_str(self):
+ ts = unittest.makeSuite(self.GoodAdditionParams)
+ test = list(ts)[0]
+
+ expected = 'test_addition0 (1, 2, 3) ({})'.format(
+ unittest.util.strclass(self.GoodAdditionParams))
+ self.assertEqual(expected, str(test))
+
+ def test_dict_parameters(self):
+ ts = unittest.makeSuite(self.DictionaryArguments)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(2, res.testsRun)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_no_parameterized_tests(self):
+ ts = unittest.makeSuite(self.NoParameterizedTests)
+ self.assertEqual(4, ts.countTestCases())
+ short_descs = [x.shortDescription() for x in list(ts)]
+ full_class_name = unittest.util.strclass(self.NoParameterizedTests)
+ full_class_name = full_class_name.replace('__main__.', '')
+ self.assertSameElements(
+ [
+ '{}.testGenerator'.format(full_class_name),
+ '{}.test_generator'.format(full_class_name),
+ '{}.testNormal'.format(full_class_name),
+ '{}.test_normal'.format(full_class_name),
+ ],
+ short_descs)
+
+ def test_successful_product_test_testgrid(self):
+
+ class GoodProductTestCase(parameterized.TestCase):
+
+ @parameterized.product(
+ num=(0, 20, 80),
+ modulo=(2, 4),
+ expected=(0,)
+ )
+ def testModuloResult(self, num, modulo, expected):
+ self.assertEqual(expected, num % modulo)
+
+ ts = unittest.makeSuite(GoodProductTestCase)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(ts.countTestCases(), 6)
+ self.assertEqual(res.testsRun, 6)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_successful_product_test_kwarg_seqs(self):
+
+ class GoodProductTestCase(parameterized.TestCase):
+
+ @parameterized.product((dict(num=0), dict(num=20), dict(num=0)),
+ (dict(modulo=2), dict(modulo=4)),
+ (dict(expected=0),))
+ def testModuloResult(self, num, modulo, expected):
+ self.assertEqual(expected, num % modulo)
+
+ ts = unittest.makeSuite(GoodProductTestCase)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(ts.countTestCases(), 6)
+ self.assertEqual(res.testsRun, 6)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_successful_product_test_kwarg_seq_and_testgrid(self):
+
+ class GoodProductTestCase(parameterized.TestCase):
+
+ @parameterized.product((dict(
+ num=5, modulo=3, expected=2), dict(num=7, modulo=4, expected=3)),
+ dtype=(int, float))
+ def testModuloResult(self, num, dtype, modulo, expected):
+ self.assertEqual(expected, dtype(num) % modulo)
+
+ ts = unittest.makeSuite(GoodProductTestCase)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(ts.countTestCases(), 4)
+ self.assertEqual(res.testsRun, 4)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_inconsistent_arg_names_in_kwargs_seq(self):
+ with self.assertRaisesRegex(AssertionError, 'must all have the same keys'):
+
+ class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable
+
+ @parameterized.product((dict(num=5, modulo=3), dict(num=7, modula=2)),
+ dtype=(int, float))
+ def test_something(self):
+ pass # not called because argnames are not the same
+
+ def test_duplicate_arg_names_in_kwargs_seqs(self):
+ with self.assertRaisesRegex(AssertionError, 'must all have distinct'):
+
+ class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable
+
+ @parameterized.product((dict(num=5, modulo=3), dict(num=7, modulo=4)),
+ (dict(foo='bar', num=5), dict(foo='baz', num=7)),
+ dtype=(int, float))
+ def test_something(self):
+ pass # not called because `num` is specified twice
+
+ def test_duplicate_arg_names_in_kwargs_seq_and_testgrid(self):
+ with self.assertRaisesRegex(AssertionError, 'duplicate argument'):
+
+ class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable
+
+ @parameterized.product(
+ (dict(num=5, modulo=3), dict(num=7, modulo=4)),
+ (dict(foo='bar'), dict(foo='baz')),
+ dtype=(int, float),
+ foo=('a', 'b'),
+ )
+ def test_something(self):
+ pass # not called because `foo` is specified twice
+
+ def test_product_recorded_failures(self):
+
+ class MixedProductTestCase(parameterized.TestCase):
+
+ @parameterized.product(
+ num=(0, 10, 20),
+ modulo=(2, 4),
+ expected=(0,)
+ )
+ def testModuloResult(self, num, modulo, expected):
+ self.assertEqual(expected, num % modulo)
+
+ ts = unittest.makeSuite(MixedProductTestCase)
+ self.assertEqual(6, ts.countTestCases())
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(res.testsRun, 6)
+ self.assertFalse(res.wasSuccessful())
+ self.assertLen(res.failures, 1)
+ self.assertEmpty(res.errors)
+
+ def test_mismatched_product_parameter(self):
+
+ class MismatchedProductParam(parameterized.TestCase):
+
+ @parameterized.product(
+ a=(1, 2),
+ mismatch=(1, 2)
+ )
+ # will fail because of mismatch in parameter names.
+ def test_something(self, a, b):
+ pass
+
+ ts = unittest.makeSuite(MismatchedProductParam)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(res.testsRun, 4)
+ self.assertFalse(res.wasSuccessful())
+ self.assertLen(res.errors, 4)
+
+ def test_no_test_error_empty_product_parameter(self):
+ with self.assertRaises(parameterized.NoTestsError):
+
+ class EmptyProductParam(parameterized.TestCase): # pylint: disable=unused-variable
+
+ @parameterized.product(arg1=[1, 2], arg2=[])
+ def test_something(self, arg1, arg2):
+ pass # not called because arg2 has empty list of values.
+
+ def test_bad_product_parameters(self):
+ with self.assertRaisesRegex(AssertionError, 'must be given as list or'):
+
+ class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable
+
+ @parameterized.product(arg1=[1, 2], arg2='abcd')
+ def test_something(self, arg1, arg2):
+ pass # not called because arg2 is not list or tuple.
+
+ def test_generator_tests_disallowed(self):
+ with self.assertRaisesRegex(RuntimeError, 'generated.*without'):
+ class GeneratorTests(parameterized.TestCase): # pylint: disable=unused-variable
+ test_generator_method = (lambda self: None for _ in range(10))
+
+ def test_named_parameters_run(self):
+ ts = unittest.makeSuite(self.NamedTests)
+ self.assertEqual(9, ts.countTestCases())
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(9, res.testsRun)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_named_parameters_id(self):
+ ts = sorted(unittest.makeSuite(self.CamelCaseNamedTests),
+ key=lambda t: t.id())
+ self.assertLen(ts, 9)
+ full_class_name = unittest.util.strclass(self.CamelCaseNamedTests)
+ self.assertEqual(
+ full_class_name + '.testDictSingleInteresting',
+ ts[0].id())
+ self.assertEqual(
+ full_class_name + '.testDictSomethingBoring',
+ ts[1].id())
+ self.assertEqual(
+ full_class_name + '.testDictSomethingInteresting',
+ ts[2].id())
+ self.assertEqual(
+ full_class_name + '.testMixedSomethingBoring',
+ ts[3].id())
+ self.assertEqual(
+ full_class_name + '.testMixedSomethingInteresting',
+ ts[4].id())
+ self.assertEqual(
+ full_class_name + '.testSingleInteresting',
+ ts[5].id())
+ self.assertEqual(
+ full_class_name + '.testSomethingBoring',
+ ts[6].id())
+ self.assertEqual(
+ full_class_name + '.testSomethingInteresting',
+ ts[7].id())
+ self.assertEqual(
+ full_class_name + '.testWithoutParameters',
+ ts[8].id())
+
+ def test_named_parameters_id_with_underscore_case(self):
+ ts = sorted(unittest.makeSuite(self.NamedTests),
+ key=lambda t: t.id())
+ self.assertLen(ts, 9)
+ full_class_name = unittest.util.strclass(self.NamedTests)
+ self.assertEqual(
+ full_class_name + '.test_dict_single_interesting',
+ ts[0].id())
+ self.assertEqual(
+ full_class_name + '.test_dict_something_boring',
+ ts[1].id())
+ self.assertEqual(
+ full_class_name + '.test_dict_something_interesting',
+ ts[2].id())
+ self.assertEqual(
+ full_class_name + '.test_mixed_something_boring',
+ ts[3].id())
+ self.assertEqual(
+ full_class_name + '.test_mixed_something_interesting',
+ ts[4].id())
+ self.assertEqual(
+ full_class_name + '.test_single_interesting',
+ ts[5].id())
+ self.assertEqual(
+ full_class_name + '.test_something_boring',
+ ts[6].id())
+ self.assertEqual(
+ full_class_name + '.test_something_interesting',
+ ts[7].id())
+ self.assertEqual(
+ full_class_name + '.test_without_parameters',
+ ts[8].id())
+
+ def test_named_parameters_short_description(self):
+ ts = sorted(unittest.makeSuite(self.NamedTests),
+ key=lambda t: t.id())
+ actual = {t._testMethodName: t.shortDescription() for t in ts}
+ expected = {
+ 'test_dict_single_interesting': 'case=0',
+ 'test_dict_something_boring': 'case=1',
+ 'test_dict_something_interesting': 'case=0',
+ 'test_mixed_something_boring': '1',
+ 'test_mixed_something_interesting': 'case=0',
+ 'test_something_boring': '1',
+ 'test_something_interesting': '0',
+ }
+ for test_name, param_repr in expected.items():
+ short_desc = actual[test_name].split('\n')
+ self.assertIn(test_name, short_desc[0])
+ self.assertEqual('{}({})'.format(test_name, param_repr), short_desc[1])
+
+ def test_load_tuple_named_test(self):
+ loader = unittest.TestLoader()
+ ts = list(loader.loadTestsFromName('NamedTests.test_something_interesting',
+ module=self))
+ self.assertLen(ts, 1)
+ self.assertEndsWith(ts[0].id(), '.test_something_interesting')
+
+ def test_load_dict_named_test(self):
+ loader = unittest.TestLoader()
+ ts = list(
+ loader.loadTestsFromName(
+ 'NamedTests.test_dict_something_interesting', module=self))
+ self.assertLen(ts, 1)
+ self.assertEndsWith(ts[0].id(), '.test_dict_something_interesting')
+
+ def test_load_mixed_named_test(self):
+ loader = unittest.TestLoader()
+ ts = list(
+ loader.loadTestsFromName(
+ 'NamedTests.test_mixed_something_interesting', module=self))
+ self.assertLen(ts, 1)
+ self.assertEndsWith(ts[0].id(), '.test_mixed_something_interesting')
+
+ def test_duplicate_named_test_fails(self):
+ with self.assertRaises(parameterized.DuplicateTestNameError):
+
+ class _(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('Interesting', 0),
+ ('Interesting', 1),
+ )
+ def test_something(self, unused_obj):
+ pass
+
+ def test_duplicate_dict_named_test_fails(self):
+ with self.assertRaises(parameterized.DuplicateTestNameError):
+
+ class _(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'Interesting', 'unused_obj': 0},
+ {'testcase_name': 'Interesting', 'unused_obj': 1},
+ )
+ def test_dict_something(self, unused_obj):
+ pass
+
+ def test_duplicate_mixed_named_test_fails(self):
+ with self.assertRaises(parameterized.DuplicateTestNameError):
+
+ class _(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'Interesting', 'unused_obj': 0},
+ ('Interesting', 1),
+ )
+ def test_mixed_something(self, unused_obj):
+ pass
+
+ def test_named_test_with_no_name_fails(self):
+ with self.assertRaises(RuntimeError):
+
+ class _(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ (0,),
+ )
+ def test_something(self, unused_obj):
+ pass
+
+ def test_named_test_dict_with_no_name_fails(self):
+ with self.assertRaises(RuntimeError):
+
+ class _(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'unused_obj': 0},
+ )
+ def test_something(self, unused_obj):
+ pass
+
+ def test_parameterized_test_iter_has_testcases_property(self):
+ @parameterized.parameters(1, 2, 3, 4, 5, 6)
+ def test_something(unused_self, unused_obj): # pylint: disable=invalid-name
+ pass
+
+ expected_testcases = [1, 2, 3, 4, 5, 6]
+ self.assertTrue(hasattr(test_something, 'testcases'))
+ self.assertCountEqual(expected_testcases, test_something.testcases)
+
+ def test_chained_decorator(self):
+ ts = unittest.makeSuite(self.ChainedTests)
+ self.assertEqual(1, ts.countTestCases())
+ test = next(t for t in ts)
+ self.assertTrue(hasattr(test, 'test_chained_flavor_strawberry_cone_waffle'))
+ res = unittest.TestResult()
+
+ ts.run(res)
+ self.assertEqual(1, res.testsRun)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_singleton_list_extraction(self):
+ ts = unittest.makeSuite(self.SingletonListExtraction)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(10, res.testsRun)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_singleton_argument_extraction(self):
+ ts = unittest.makeSuite(self.SingletonArgumentExtraction)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(9, res.testsRun)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_singleton_dict_argument(self):
+ ts = unittest.makeSuite(self.SingletonDictArgument)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(1, res.testsRun)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_decorated_bare_class(self):
+ ts = unittest.makeSuite(self.DecoratedBareClass)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(2, res.testsRun)
+ self.assertTrue(res.wasSuccessful(), msg=str(res.failures))
+
+ def test_decorated_class(self):
+ ts = unittest.makeSuite(self.DecoratedClass)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(4, res.testsRun)
+ self.assertLen(res.failures, 2)
+
+ def test_generator_decorated_class(self):
+ ts = unittest.makeSuite(self.GeneratorDecoratedClass)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(32, res.testsRun)
+ self.assertLen(res.failures, 16)
+
+ def test_no_duplicate_decorations(self):
+ with self.assertRaises(AssertionError):
+
+ @parameterized.parameters(1, 2, 3, 4)
+ class _(parameterized.TestCase):
+
+ @parameterized.parameters(5, 6, 7, 8)
+ def test_something(self, unused_obj):
+ pass
+
+ def test_double_class_decorations_not_supported(self):
+
+ @parameterized.parameters('foo', 'bar')
+ class SuperclassWithClassDecorator(parameterized.TestCase):
+
+ def test_name(self, name):
+ del name
+
+ with self.assertRaises(AssertionError):
+
+ @parameterized.parameters('foo', 'bar')
+ class SubclassWithClassDecorator(SuperclassWithClassDecorator):
+ pass
+
+ del SubclassWithClassDecorator
+
+ def test_other_decorator_ordering_unnamed(self):
+ ts = unittest.makeSuite(self.OtherDecoratorUnnamed)
+ res = unittest.TestResult()
+ ts.run(res)
+ # Two for when the parameterized tests call the skip wrapper.
+ # One for when the skip wrapper is called first and doesn't iterate.
+ self.assertEqual(3, res.testsRun)
+ self.assertFalse(res.wasSuccessful())
+ self.assertEmpty(res.failures)
+ # One error from test_other_then_parameterized.
+ self.assertLen(res.errors, 1)
+
+ def test_other_decorator_ordering_named(self):
+ ts = unittest.makeSuite(self.OtherDecoratorNamed)
+ # Verify it generates the test method names from the original test method.
+ for test in ts: # There is only one test.
+ ts_attributes = dir(test)
+ self.assertIn('test_parameterized_then_other_a', ts_attributes)
+ self.assertIn('test_parameterized_then_other_b', ts_attributes)
+
+ res = unittest.TestResult()
+ ts.run(res)
+ # Two for when the parameterized tests call the skip wrapper.
+ # One for when the skip wrapper is called first and doesn't iterate.
+ self.assertEqual(3, res.testsRun)
+ self.assertFalse(res.wasSuccessful())
+ self.assertEmpty(res.failures)
+ # One error from test_other_then_parameterized.
+ self.assertLen(res.errors, 1)
+
+ def test_other_decorator_ordering_named_with_dict(self):
+ ts = unittest.makeSuite(self.OtherDecoratorNamedWithDict)
+ # Verify it generates the test method names from the original test method.
+ for test in ts: # There is only one test.
+ ts_attributes = dir(test)
+ self.assertIn('test_parameterized_then_other_a', ts_attributes)
+ self.assertIn('test_parameterized_then_other_b', ts_attributes)
+
+ res = unittest.TestResult()
+ ts.run(res)
+ # Two for when the parameterized tests call the skip wrapper.
+ # One for when the skip wrapper is called first and doesn't iterate.
+ self.assertEqual(3, res.testsRun)
+ self.assertFalse(res.wasSuccessful())
+ self.assertEmpty(res.failures)
+ # One error from test_other_then_parameterized.
+ self.assertLen(res.errors, 1)
+
+ def test_no_test_error_empty_parameters(self):
+ with self.assertRaises(parameterized.NoTestsError):
+
+ @parameterized.parameters()
+ def test_something():
+ pass
+
+ del test_something
+
+ def test_no_test_error_empty_generator(self):
+ with self.assertRaises(parameterized.NoTestsError):
+
+ @parameterized.parameters((i for i in []))
+ def test_something():
+ pass
+
+ del test_something
+
+ def test_unique_descriptive_names(self):
+
+ class RecordSuccessTestsResult(unittest.TestResult):
+
+ def __init__(self, *args, **kwargs):
+ super(RecordSuccessTestsResult, self).__init__(*args, **kwargs)
+ self.successful_tests = []
+
+ def addSuccess(self, test):
+ self.successful_tests.append(test)
+
+ ts = unittest.makeSuite(self.UniqueDescriptiveNamesTest)
+ res = RecordSuccessTestsResult()
+ ts.run(res)
+ self.assertTrue(res.wasSuccessful())
+ self.assertEqual(2, res.testsRun)
+ test_ids = [test.id() for test in res.successful_tests]
+ full_class_name = unittest.util.strclass(self.UniqueDescriptiveNamesTest)
+ expected_test_ids = [
+ full_class_name + '.test_normal0 (13)',
+ full_class_name + '.test_normal1 (13)',
+ ]
+ self.assertTrue(test_ids)
+ self.assertCountEqual(expected_test_ids, test_ids)
+
+ def test_multi_generators(self):
+ ts = unittest.makeSuite(self.MultiGeneratorsTestCase)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(2, res.testsRun)
+ self.assertTrue(res.wasSuccessful(), msg=str(res.failures))
+
+ def test_named_parameters_reusable(self):
+ ts = unittest.makeSuite(self.NamedParametersReusableTestCase)
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(8, res.testsRun)
+ self.assertTrue(res.wasSuccessful(), msg=str(res.failures))
+
+ def test_subclass_inherits_superclass_test_params_reprs(self):
+ self.assertEqual(
+ {'test_name0': "('foo')", 'test_name1': "('bar')"},
+ self.SuperclassTestCase._test_params_reprs)
+ self.assertEqual(
+ {'test_name0': "('foo')", 'test_name1': "('bar')"},
+ self.SubclassTestCase._test_params_reprs)
+
+
+def _decorate_with_side_effects(func, self):
+ self.sideeffect = True
+ func(self)
+
+
+class CoopMetaclassCreationTest(absltest.TestCase):
+
+ class TestBase(absltest.TestCase):
+
+ # This test simulates a metaclass that sets some attribute ('sideeffect')
+ # on each member of the class that starts with 'test'. The test code then
+ # checks that this attribute exists when the custom metaclass and
+ # TestGeneratorMetaclass are combined with cooperative inheritance.
+
+ # The attribute has to be set in the __init__ method of the metaclass,
+ # since the TestGeneratorMetaclass already overrides __new__. Only one
+ # base metaclass can override __new__, but all can provide custom __init__
+ # methods.
+
+ class __metaclass__(type): # pylint: disable=g-bad-name
+
+ def __init__(cls, name, bases, dct):
+ type.__init__(cls, name, bases, dct)
+ for member_name, obj in dct.items():
+ if member_name.startswith('test'):
+ setattr(cls, member_name,
+ lambda self, f=obj: _decorate_with_side_effects(f, self))
+
+ class MyParams(parameterized.CoopTestCase(TestBase)):
+
+ @parameterized.parameters(
+ (1, 2, 3),
+ (4, 5, 9))
+ def test_addition(self, op1, op2, result):
+ self.assertEqual(result, op1 + op2)
+
+ class MySuite(unittest.TestSuite):
+ # Under Python 3.4 the TestCases in the suite's list of tests to run are
+ # destroyed and replaced with None after successful execution by default.
+ # This disables that behavior.
+ _cleanup = False
+
+ def test_successful_execution(self):
+ ts = unittest.makeSuite(self.MyParams)
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertEqual(2, res.testsRun)
+ self.assertTrue(res.wasSuccessful())
+
+ def test_metaclass_side_effects(self):
+ ts = unittest.makeSuite(self.MyParams, suiteClass=self.MySuite)
+
+ res = unittest.TestResult()
+ ts.run(res)
+ self.assertTrue(list(ts)[0].sideeffect)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/xml_reporter_helper_test.py b/absl/testing/tests/xml_reporter_helper_test.py
new file mode 100644
index 0000000..661bbdc
--- /dev/null
+++ b/absl/testing/tests/xml_reporter_helper_test.py
@@ -0,0 +1,97 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+
+from absl import flags
+from absl.testing import absltest
+
+
+FLAGS = flags.FLAGS
+flags.DEFINE_boolean('set_up_module_error', False,
+ 'Cause setupModule to error.')
+flags.DEFINE_boolean('tear_down_module_error', False,
+ 'Cause tearDownModule to error.')
+
+flags.DEFINE_boolean('set_up_class_error', False, 'Cause setUpClass to error.')
+flags.DEFINE_boolean('tear_down_class_error', False,
+ 'Cause tearDownClass to error.')
+
+flags.DEFINE_boolean('set_up_error', False, 'Cause setUp to error.')
+flags.DEFINE_boolean('tear_down_error', False, 'Cause tearDown to error.')
+flags.DEFINE_boolean('test_error', False, 'Cause the test to error.')
+
+flags.DEFINE_boolean('set_up_fail', False, 'Cause setUp to fail.')
+flags.DEFINE_boolean('tear_down_fail', False, 'Cause tearDown to fail.')
+flags.DEFINE_boolean('test_fail', False, 'Cause the test to fail.')
+
+flags.DEFINE_float('random_error', 0.0,
+ '0 - 1.0: fraction of a random failure at any step',
+ lower_bound=0.0, upper_bound=1.0)
+
+
+def _random_error():
+ return random.random() < FLAGS.random_error
+
+
+def setUpModule():
+ if FLAGS.set_up_module_error or _random_error():
+ raise Exception('setUpModule Errored!')
+
+
+def tearDownModule():
+ if FLAGS.tear_down_module_error or _random_error():
+ raise Exception('tearDownModule Errored!')
+
+
+class FailableTest(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ if FLAGS.set_up_class_error or _random_error():
+ raise Exception('setUpClass Errored!')
+
+ @classmethod
+ def tearDownClass(cls):
+ if FLAGS.tear_down_class_error or _random_error():
+ raise Exception('tearDownClass Errored!')
+
+ def setUp(self):
+ if FLAGS.set_up_error or _random_error():
+ raise Exception('setUp Errored!')
+
+ if FLAGS.set_up_fail:
+ self.fail('setUp Failed!')
+
+ def tearDown(self):
+ if FLAGS.tear_down_error or _random_error():
+ raise Exception('tearDown Errored!')
+
+ if FLAGS.tear_down_fail:
+ self.fail('tearDown Failed!')
+
+ def test(self):
+ if FLAGS.test_error or _random_error():
+ raise Exception('test Errored!')
+
+ if FLAGS.test_fail:
+ self.fail('test Failed!')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/tests/xml_reporter_test.py b/absl/testing/tests/xml_reporter_test.py
new file mode 100644
index 0000000..0261f64
--- /dev/null
+++ b/absl/testing/tests/xml_reporter_test.py
@@ -0,0 +1,1108 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import io
+import os
+import re
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+import unittest
+from unittest import mock
+from xml.etree import ElementTree
+from xml.parsers import expat
+
+from absl import logging
+from absl.testing import _bazelize_command
+from absl.testing import absltest
+from absl.testing import parameterized
+from absl.testing import xml_reporter
+
+
+class StringIOWriteLn(io.StringIO):
+
+ def writeln(self, line):
+ self.write(line + '\n')
+
+
+class MockTest(absltest.TestCase):
+ failureException = AssertionError
+
+ def __init__(self, name):
+ super(MockTest, self).__init__()
+ self.name = name
+
+ def id(self):
+ return self.name
+
+ def runTest(self):
+ return
+
+ def shortDescription(self):
+ return "This is this test's description."
+
+
+# str(exception_type) is different between Python 2 and 3.
+def xml_escaped_exception_type(exception_type):
+ return xml_reporter._escape_xml_attr(str(exception_type))
+
+
+OUTPUT_STRING = '\n'.join([
+ r'<\?xml version="1.0"\?>',
+ ('<testsuites name="" tests="%(tests)d" failures="%(failures)d"'
+ ' errors="%(errors)d" time="%(run_time).1f" timestamp="%(start_time)s">'),
+ ('<testsuite name="%(suite_name)s" tests="%(tests)d"'
+ ' failures="%(failures)d" errors="%(errors)d" time="%(run_time).1f"'
+ ' timestamp="%(start_time)s">'),
+ (' <testcase name="%(test_name)s" status="%(status)s" result="%(result)s"'
+ ' time="%(run_time).1f" classname="%(classname)s"'
+ ' timestamp="%(start_time)s">%(message)s'),
+ ' </testcase>', '</testsuite>',
+ '</testsuites>',
+])
+
+FAILURE_MESSAGE = r"""
+ <failure message="e" type="{}"><!\[CDATA\[Traceback \(most recent call last\):
+ File ".*xml_reporter_test\.py", line \d+, in get_sample_failure
+ self.fail\(\'e\'\)
+AssertionError: e
+\]\]></failure>""".format(xml_escaped_exception_type(AssertionError))
+
+ERROR_MESSAGE = r"""
+ <error message="invalid&#x20;literal&#x20;for&#x20;int\(\)&#x20;with&#x20;base&#x20;10:&#x20;(&apos;)?a(&apos;)?" type="{}"><!\[CDATA\[Traceback \(most recent call last\):
+ File ".*xml_reporter_test\.py", line \d+, in get_sample_error
+ int\('a'\)
+ValueError: invalid literal for int\(\) with base 10: '?a'?
+\]\]></error>""".format(xml_escaped_exception_type(ValueError))
+
+UNICODE_MESSAGE = r"""
+ <%s message="{0}" type="{1}"><!\[CDATA\[Traceback \(most recent call last\):
+ File ".*xml_reporter_test\.py", line \d+, in get_unicode_sample_failure
+ raise AssertionError\(u'\\xe9'\)
+AssertionError: {0}
+\]\]></%s>""".format(
+ r'\xe9',
+ xml_escaped_exception_type(AssertionError))
+
+NEWLINE_MESSAGE = r"""
+ <%s message="{0}" type="{1}"><!\[CDATA\[Traceback \(most recent call last\):
+ File ".*xml_reporter_test\.py", line \d+, in get_newline_message_sample_failure
+ raise AssertionError\(\'{2}'\)
+AssertionError: {3}
+\]\]></%s>""".format(
+ 'new&#xA;line',
+ xml_escaped_exception_type(AssertionError),
+ r'new\\nline',
+ 'new\nline')
+
+UNEXPECTED_SUCCESS_MESSAGE = '\n'.join([
+ '',
+ (r' <error message="" type=""><!\[CDATA\[Test case '
+ r'__main__.MockTest.unexpectedly_passing_test should have failed, '
+ r'but passed.\]\]></error>'),
+])
+
+UNICODE_ERROR_MESSAGE = UNICODE_MESSAGE % ('error', 'error')
+NEWLINE_ERROR_MESSAGE = NEWLINE_MESSAGE % ('error', 'error')
+
+
+class TextAndXMLTestResultTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.stream = StringIOWriteLn()
+ self.xml_stream = io.StringIO()
+
+ def _make_result(self, times):
+ timer = mock.Mock()
+ timer.side_effect = times
+ return xml_reporter._TextAndXMLTestResult(self.xml_stream, self.stream,
+ 'foo', 0, timer)
+
+ def _assert_match(self, regex, output):
+ fail_msg = 'Expected regex:\n{}\nTo match:\n{}'.format(regex, output)
+ self.assertRegex(output, regex, fail_msg)
+
+ def _assert_valid_xml(self, xml_output):
+ try:
+ expat.ParserCreate().Parse(xml_output)
+ except expat.ExpatError as e:
+ raise AssertionError('Bad XML output: {}\n{}'.format(e, xml_output))
+
+ def _simulate_error_test(self, test, result):
+ result.startTest(test)
+ result.addError(test, self.get_sample_error())
+ result.stopTest(test)
+
+ def _simulate_failing_test(self, test, result):
+ result.startTest(test)
+ result.addFailure(test, self.get_sample_failure())
+ result.stopTest(test)
+
+ def _simulate_passing_test(self, test, result):
+ result.startTest(test)
+ result.addSuccess(test)
+ result.stopTest(test)
+
+ def _iso_timestamp(self, timestamp):
+ return datetime.datetime.utcfromtimestamp(timestamp).isoformat() + '+00:00'
+
+ def test_with_passing_test(self):
+ start_time = 0
+ end_time = 2
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.passing_test')
+ result.startTestRun()
+ result.startTest(test)
+ result.addSuccess(test)
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 0,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'passing_test',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': ''
+ }
+ self._assert_match(expected_re, self.xml_stream.getvalue())
+
+ def test_with_passing_subtest(self):
+ start_time = 0
+ end_time = 2
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.passing_test')
+ subtest = unittest.case._SubTest(test, 'msg', None)
+ result.startTestRun()
+ result.startTest(test)
+ result.addSubTest(test, subtest, None)
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 0,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': r'passing_test&#x20;\[msg\]',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': ''
+ }
+ self._assert_match(expected_re, self.xml_stream.getvalue())
+
+ def test_with_passing_subtest_with_dots_in_parameter_name(self):
+ start_time = 0
+ end_time = 2
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.passing_test')
+ subtest = unittest.case._SubTest(test, 'msg', {'case': 'a.b.c'})
+ result.startTestRun()
+ result.startTest(test)
+ result.addSubTest(test, subtest, None)
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name':
+ 'MockTest',
+ 'tests':
+ 1,
+ 'failures':
+ 0,
+ 'errors':
+ 0,
+ 'run_time':
+ run_time,
+ 'start_time':
+ re.escape(self._iso_timestamp(start_time),),
+ 'test_name':
+ r'passing_test&#x20;\[msg\]&#x20;\(case=&apos;a.b.c&apos;\)',
+ 'classname':
+ '__main__.MockTest',
+ 'status':
+ 'run',
+ 'result':
+ 'completed',
+ 'attributes':
+ '',
+ 'message':
+ ''
+ }
+ self._assert_match(expected_re, self.xml_stream.getvalue())
+
+ def get_sample_error(self):
+ try:
+ int('a')
+ except ValueError:
+ error_values = sys.exc_info()
+ return error_values
+
+ def get_sample_failure(self):
+ try:
+ self.fail('e')
+ except AssertionError:
+ error_values = sys.exc_info()
+ return error_values
+
+ def get_newline_message_sample_failure(self):
+ try:
+ raise AssertionError('new\nline')
+ except AssertionError:
+ error_values = sys.exc_info()
+ return error_values
+
+ def get_unicode_sample_failure(self):
+ try:
+ raise AssertionError(u'\xe9')
+ except AssertionError:
+ error_values = sys.exc_info()
+ return error_values
+
+ def get_terminal_escape_sample_failure(self):
+ try:
+ raise AssertionError('\x1b')
+ except AssertionError:
+ error_values = sys.exc_info()
+ return error_values
+
+ def test_with_failing_test(self):
+ start_time = 10
+ end_time = 20
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.failing_test')
+ result.startTestRun()
+ result.startTest(test)
+ result.addFailure(test, self.get_sample_failure())
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 1,
+ 'errors': 0,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'failing_test',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': FAILURE_MESSAGE
+ }
+ self._assert_match(expected_re, self.xml_stream.getvalue())
+
+ def test_with_failing_subtest(self):
+ start_time = 10
+ end_time = 20
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.failing_test')
+ subtest = unittest.case._SubTest(test, 'msg', None)
+ result.startTestRun()
+ result.startTest(test)
+ result.addSubTest(test, subtest, self.get_sample_failure())
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 1,
+ 'errors': 0,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': r'failing_test&#x20;\[msg\]',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': FAILURE_MESSAGE
+ }
+ self._assert_match(expected_re, self.xml_stream.getvalue())
+
+ def test_with_error_test(self):
+ start_time = 100
+ end_time = 200
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.failing_test')
+ result.startTestRun()
+ result.startTest(test)
+ result.addError(test, self.get_sample_error())
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+ xml = self.xml_stream.getvalue()
+
+ self._assert_valid_xml(xml)
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 1,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'failing_test',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': ERROR_MESSAGE
+ }
+ self._assert_match(expected_re, xml)
+
+ def test_with_error_subtest(self):
+ start_time = 10
+ end_time = 20
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.error_test')
+ subtest = unittest.case._SubTest(test, 'msg', None)
+ result.startTestRun()
+ result.startTest(test)
+ result.addSubTest(test, subtest, self.get_sample_error())
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 1,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': r'error_test&#x20;\[msg\]',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': ERROR_MESSAGE
+ }
+ self._assert_match(expected_re, self.xml_stream.getvalue())
+
+ def test_with_fail_and_error_test(self):
+ """Tests a failure and subsequent error within a single result."""
+ start_time = 123
+ end_time = 456
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.failing_test')
+ result.startTestRun()
+ result.startTest(test)
+ result.addFailure(test, self.get_sample_failure())
+ # This could happen in tearDown
+ result.addError(test, self.get_sample_error())
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+ xml = self.xml_stream.getvalue()
+
+ self._assert_valid_xml(xml)
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 1, # Only the failure is tallied (because it was first).
+ 'errors': 0,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'failing_test',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ # Messages from failure and error should be concatenated in order.
+ 'message': FAILURE_MESSAGE + ERROR_MESSAGE
+ }
+ self._assert_match(expected_re, xml)
+
+ def test_with_error_and_fail_test(self):
+ """Tests an error and subsequent failure within a single result."""
+ start_time = 123
+ end_time = 456
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.failing_test')
+ result.startTestRun()
+ result.startTest(test)
+ result.addError(test, self.get_sample_error())
+ result.addFailure(test, self.get_sample_failure())
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+ xml = self.xml_stream.getvalue()
+
+ self._assert_valid_xml(xml)
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 1, # Only the error is tallied (because it was first).
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'failing_test',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ # Messages from error and failure should be concatenated in order.
+ 'message': ERROR_MESSAGE + FAILURE_MESSAGE
+ }
+ self._assert_match(expected_re, xml)
+
+ def test_with_newline_error_test(self):
+ start_time = 100
+ end_time = 200
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.failing_test')
+ result.startTestRun()
+ result.startTest(test)
+ result.addError(test, self.get_newline_message_sample_failure())
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+ xml = self.xml_stream.getvalue()
+
+ self._assert_valid_xml(xml)
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 1,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'failing_test',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': NEWLINE_ERROR_MESSAGE
+ } + '\n'
+ self._assert_match(expected_re, xml)
+
+ def test_with_unicode_error_test(self):
+ start_time = 100
+ end_time = 200
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.failing_test')
+ result.startTestRun()
+ result.startTest(test)
+ result.addError(test, self.get_unicode_sample_failure())
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+ xml = self.xml_stream.getvalue()
+
+ self._assert_valid_xml(xml)
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 1,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'failing_test',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': UNICODE_ERROR_MESSAGE
+ }
+ self._assert_match(expected_re, xml)
+
+ def test_with_terminal_escape_error(self):
+ start_time = 100
+ end_time = 200
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.failing_test')
+ result.startTestRun()
+ result.startTest(test)
+ result.addError(test, self.get_terminal_escape_sample_failure())
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+
+ self._assert_valid_xml(self.xml_stream.getvalue())
+
+ def test_with_expected_failure_test(self):
+ start_time = 100
+ end_time = 200
+ result = self._make_result((start_time, start_time, end_time, end_time))
+ error_values = ''
+
+ try:
+ raise RuntimeError('Test expectedFailure')
+ except RuntimeError:
+ error_values = sys.exc_info()
+
+ test = MockTest('__main__.MockTest.expected_failing_test')
+ result.startTestRun()
+ result.startTest(test)
+ result.addExpectedFailure(test, error_values)
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 0,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'expected_failing_test',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': ''
+ }
+ self._assert_match(re.compile(expected_re, re.DOTALL),
+ self.xml_stream.getvalue())
+
+ def test_with_unexpected_success_error_test(self):
+ start_time = 100
+ end_time = 200
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.unexpectedly_passing_test')
+ result.startTestRun()
+ result.startTest(test)
+ result.addUnexpectedSuccess(test)
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 1,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'unexpectedly_passing_test',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': UNEXPECTED_SUCCESS_MESSAGE
+ }
+ self._assert_match(expected_re, self.xml_stream.getvalue())
+
+ def test_with_skipped_test(self):
+ start_time = 100
+ end_time = 100
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.skipped_test_with_reason')
+ result.startTestRun()
+ result.startTest(test)
+ result.addSkip(test, 'b"r')
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 0,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'skipped_test_with_reason',
+ 'classname': '__main__.MockTest',
+ 'status': 'notrun',
+ 'result': 'suppressed',
+ 'message': ''
+ }
+ self._assert_match(expected_re, self.xml_stream.getvalue())
+
+ def test_suite_time(self):
+ start_time1 = 100
+ end_time1 = 200
+ start_time2 = 400
+ end_time2 = 700
+ name = '__main__.MockTest.failing_test'
+ result = self._make_result((start_time1, start_time1, end_time1,
+ start_time2, end_time2, end_time2))
+
+ test = MockTest('%s1' % name)
+ result.startTestRun()
+ result.startTest(test)
+ result.addSuccess(test)
+ result.stopTest(test)
+
+ test = MockTest('%s2' % name)
+ result.startTest(test)
+ result.addSuccess(test)
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = max(end_time1, end_time2) - min(start_time1, start_time2)
+ timestamp = self._iso_timestamp(start_time1)
+ expected_prefix = """<?xml version="1.0"?>
+<testsuites name="" tests="2" failures="0" errors="0" time="%.1f" timestamp="%s">
+<testsuite name="MockTest" tests="2" failures="0" errors="0" time="%.1f" timestamp="%s">
+""" % (run_time, timestamp, run_time, timestamp)
+ xml_output = self.xml_stream.getvalue()
+ self.assertTrue(
+ xml_output.startswith(expected_prefix),
+ '%s not found in %s' % (expected_prefix, xml_output))
+
+ def test_with_no_suite_name(self):
+ start_time = 1000
+ end_time = 1200
+ result = self._make_result((start_time, start_time, end_time, end_time))
+
+ test = MockTest('__main__.MockTest.bad_name')
+ result.startTestRun()
+ result.startTest(test)
+ result.addSuccess(test)
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'MockTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 0,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': 'bad_name',
+ 'classname': '__main__.MockTest',
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': ''
+ }
+ self._assert_match(expected_re, self.xml_stream.getvalue())
+
+ def test_unnamed_parameterized_testcase(self):
+ """Test unnamed parameterized test cases.
+
+ Unnamed parameterized test cases might have non-alphanumeric characters in
+ their test method names. This test ensures xml_reporter handles them
+ correctly.
+ """
+
+ class ParameterizedTest(parameterized.TestCase):
+
+ @parameterized.parameters(('a (b.c)',))
+ def test_prefix(self, case):
+ self.assertTrue(case.startswith('a'))
+
+ start_time = 1000
+ end_time = 1200
+ result = self._make_result((start_time, start_time, end_time, end_time))
+ test = ParameterizedTest(methodName='test_prefix0')
+ result.startTestRun()
+ result.startTest(test)
+ result.addSuccess(test)
+ result.stopTest(test)
+ result.stopTestRun()
+ result.printErrors()
+
+ run_time = end_time - start_time
+ classname = xml_reporter._escape_xml_attr(
+ unittest.util.strclass(test.__class__))
+ expected_re = OUTPUT_STRING % {
+ 'suite_name': 'ParameterizedTest',
+ 'tests': 1,
+ 'failures': 0,
+ 'errors': 0,
+ 'run_time': run_time,
+ 'start_time': re.escape(self._iso_timestamp(start_time),),
+ 'test_name': re.escape('test_prefix0&#x20;(&apos;a&#x20;(b.c)&apos;)'),
+ 'classname': classname,
+ 'status': 'run',
+ 'result': 'completed',
+ 'attributes': '',
+ 'message': ''
+ }
+ self._assert_match(expected_re, self.xml_stream.getvalue())
+
+ def teststop_test_without_pending_test(self):
+ end_time = 1200
+ result = self._make_result((end_time,))
+
+ test = MockTest('__main__.MockTest.bad_name')
+ result.stopTest(test)
+ result.stopTestRun()
+ # Just verify that this doesn't crash
+
+ def test_text_and_xmltest_runner(self):
+ runner = xml_reporter.TextAndXMLTestRunner(self.xml_stream, self.stream,
+ 'foo', 1)
+ result1 = runner._makeResult()
+ result2 = xml_reporter._TextAndXMLTestResult(None, None, None, 0, None)
+ self.failUnless(type(result1) is type(result2))
+
+ def test_timing_with_time_stub(self):
+ """Make sure that timing is correct even if time.time is stubbed out."""
+ try:
+ saved_time = time.time
+ time.time = lambda: -1
+ reporter = xml_reporter._TextAndXMLTestResult(self.xml_stream,
+ self.stream,
+ 'foo', 0)
+ test = MockTest('bar')
+ reporter.startTest(test)
+ self.failIf(reporter.start_time == -1)
+ finally:
+ time.time = saved_time
+
+ def test_concurrent_add_and_delete_pending_test_case_result(self):
+ """Make sure adding/deleting pending test case results are thread safe."""
+ result = xml_reporter._TextAndXMLTestResult(None, self.stream, None, 0,
+ None)
+ def add_and_delete_pending_test_case_result(test_name):
+ test = MockTest(test_name)
+ result.addSuccess(test)
+ result.delete_pending_test_case_result(test)
+
+ for i in range(50):
+ add_and_delete_pending_test_case_result('add_and_delete_test%s' % i)
+ self.assertEqual(result.pending_test_case_results, {})
+
+ def test_concurrent_test_runs(self):
+ """Make sure concurrent test runs do not race each other."""
+ num_passing_tests = 20
+ num_failing_tests = 20
+ num_error_tests = 20
+ total_num_tests = num_passing_tests + num_failing_tests + num_error_tests
+
+ times = [0] + [i for i in range(2 * total_num_tests)
+ ] + [2 * total_num_tests - 1]
+ result = self._make_result(times)
+ threads = []
+ names = []
+ result.startTestRun()
+ for i in range(num_passing_tests):
+ name = 'passing_concurrent_test_%s' % i
+ names.append(name)
+ test_name = '__main__.MockTest.%s' % name
+ # xml_reporter uses id(test) as the test identifier.
+ # In a real testing scenario, all the test instances are created before
+ # running them. So all ids will be unique.
+ # We must do the same here: create test instance beforehand.
+ test = MockTest(test_name)
+ threads.append(threading.Thread(
+ target=self._simulate_passing_test, args=(test, result)))
+ for i in range(num_failing_tests):
+ name = 'failing_concurrent_test_%s' % i
+ names.append(name)
+ test_name = '__main__.MockTest.%s' % name
+ test = MockTest(test_name)
+ threads.append(threading.Thread(
+ target=self._simulate_failing_test, args=(test, result)))
+ for i in range(num_error_tests):
+ name = 'error_concurrent_test_%s' % i
+ names.append(name)
+ test_name = '__main__.MockTest.%s' % name
+ test = MockTest(test_name)
+ threads.append(threading.Thread(
+ target=self._simulate_error_test, args=(test, result)))
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ result.stopTestRun()
+ result.printErrors()
+ tests_not_in_xml = []
+ for tn in names:
+ if tn not in self.xml_stream.getvalue():
+ tests_not_in_xml.append(tn)
+ msg = ('Expected xml_stream to contain all test %s results, but %s tests '
+ 'are missing. List of missing tests: %s' % (
+ total_num_tests, len(tests_not_in_xml), tests_not_in_xml))
+ self.assertEqual([], tests_not_in_xml, msg)
+
+ def test_add_failure_during_stop_test(self):
+ """Tests an addFailure() call from within a stopTest() call stack."""
+ result = self._make_result((0, 2))
+ test = MockTest('__main__.MockTest.failing_test')
+ result.startTestRun()
+ result.startTest(test)
+
+ # Replace parent stopTest method from unittest.TextTestResult with
+ # a version that calls self.addFailure().
+ with mock.patch.object(
+ unittest.TextTestResult,
+ 'stopTest',
+ side_effect=lambda t: result.addFailure(t, self.get_sample_failure())):
+ # Run stopTest in a separate thread since we are looking to verify that
+ # it does not deadlock, and would otherwise prevent the test from
+ # completing.
+ stop_test_thread = threading.Thread(target=result.stopTest, args=(test,))
+ stop_test_thread.daemon = True
+ stop_test_thread.start()
+
+ stop_test_thread.join(10.0)
+ self.assertFalse(stop_test_thread.is_alive(),
+ 'result.stopTest(test) call failed to complete')
+
+
+class XMLTest(absltest.TestCase):
+
+ def test_escape_xml(self):
+ self.assertEqual(xml_reporter._escape_xml_attr('"Hi" <\'>\t\r\n'),
+ '&quot;Hi&quot;&#x20;&lt;&apos;&gt;&#x9;&#xD;&#xA;')
+
+
+class XmlReporterFixtureTest(absltest.TestCase):
+
+ def _get_helper(self):
+ binary_name = 'absl/testing/tests/xml_reporter_helper_test'
+ return _bazelize_command.get_executable_path(binary_name)
+
+ def _run_test_and_get_xml(self, flag):
+ """Runs xml_reporter_helper_test and returns an Element instance.
+
+ Runs xml_reporter_helper_test in a new process so that it can
+ exercise the entire test infrastructure, and easily test issues in
+ the test fixture.
+
+ Args:
+ flag: flag to pass to xml_reporter_helper_test
+
+ Returns:
+ The Element instance of the XML output.
+ """
+
+ xml_fhandle, xml_fname = tempfile.mkstemp()
+ os.close(xml_fhandle)
+
+ try:
+ binary = self._get_helper()
+ args = [binary, flag, '--xml_output_file=%s' % xml_fname]
+ ret = subprocess.call(args)
+ self.assertEqual(ret, 0)
+
+ xml = ElementTree.parse(xml_fname).getroot()
+ finally:
+ os.remove(xml_fname)
+
+ return xml
+
+ def _run_test(self, flag, num_errors, num_failures, suites):
+ xml_fhandle, xml_fname = tempfile.mkstemp()
+ os.close(xml_fhandle)
+
+ try:
+ binary = self._get_helper()
+ args = [binary, flag, '--xml_output_file=%s' % xml_fname]
+ ret = subprocess.call(args)
+ self.assertNotEqual(ret, 0)
+
+ xml = ElementTree.parse(xml_fname).getroot()
+ logging.info('xml output is:\n%s', ElementTree.tostring(xml))
+ finally:
+ os.remove(xml_fname)
+
+ self.assertEqual(int(xml.attrib['errors']), num_errors)
+ self.assertEqual(int(xml.attrib['failures']), num_failures)
+ self.assertLen(xml, len(suites))
+ actual_suites = sorted(
+ xml.findall('testsuite'), key=lambda x: x.attrib['name'])
+ suites = sorted(suites, key=lambda x: x['name'])
+ for actual_suite, expected_suite in zip(actual_suites, suites):
+ self.assertEqual(actual_suite.attrib['name'], expected_suite['name'])
+ self.assertLen(actual_suite, len(expected_suite['cases']))
+ actual_cases = sorted(actual_suite.findall('testcase'),
+ key=lambda x: x.attrib['name'])
+ expected_cases = sorted(expected_suite['cases'], key=lambda x: x['name'])
+ for actual_case, expected_case in zip(actual_cases, expected_cases):
+ self.assertEqual(actual_case.attrib['name'], expected_case['name'])
+ self.assertEqual(actual_case.attrib['classname'],
+ expected_case['classname'])
+ if 'error' in expected_case:
+ actual_error = actual_case.find('error')
+ self.assertEqual(actual_error.attrib['message'],
+ expected_case['error'])
+ if 'failure' in expected_case:
+ actual_failure = actual_case.find('failure')
+ self.assertEqual(actual_failure.attrib['message'],
+ expected_case['failure'])
+
+ return xml
+
+ def test_set_up_module_error(self):
+ self._run_test(
+ flag='--set_up_module_error',
+ num_errors=1,
+ num_failures=0,
+ suites=[{'name': '__main__',
+ 'cases': [{'name': 'setUpModule',
+ 'classname': '__main__',
+ 'error': 'setUpModule Errored!'}]}])
+
+ def test_tear_down_module_error(self):
+ self._run_test(
+ flag='--tear_down_module_error',
+ num_errors=1,
+ num_failures=0,
+ suites=[{'name': 'FailableTest',
+ 'cases': [{'name': 'test',
+ 'classname': '__main__.FailableTest'}]},
+ {'name': '__main__',
+ 'cases': [{'name': 'tearDownModule',
+ 'classname': '__main__',
+ 'error': 'tearDownModule Errored!'}]}])
+
+ def test_set_up_class_error(self):
+ self._run_test(
+ flag='--set_up_class_error',
+ num_errors=1,
+ num_failures=0,
+ suites=[{'name': 'FailableTest',
+ 'cases': [{'name': 'setUpClass',
+ 'classname': '__main__.FailableTest',
+ 'error': 'setUpClass Errored!'}]}])
+
+ def test_tear_down_class_error(self):
+ self._run_test(
+ flag='--tear_down_class_error',
+ num_errors=1,
+ num_failures=0,
+ suites=[{'name': 'FailableTest',
+ 'cases': [{'name': 'test',
+ 'classname': '__main__.FailableTest'},
+ {'name': 'tearDownClass',
+ 'classname': '__main__.FailableTest',
+ 'error': 'tearDownClass Errored!'}]}])
+
+ def test_set_up_error(self):
+ self._run_test(
+ flag='--set_up_error',
+ num_errors=1,
+ num_failures=0,
+ suites=[{'name': 'FailableTest',
+ 'cases': [{'name': 'test',
+ 'classname': '__main__.FailableTest',
+ 'error': 'setUp Errored!'}]}])
+
+ def test_tear_down_error(self):
+ self._run_test(
+ flag='--tear_down_error',
+ num_errors=1,
+ num_failures=0,
+ suites=[{'name': 'FailableTest',
+ 'cases': [{'name': 'test',
+ 'classname': '__main__.FailableTest',
+ 'error': 'tearDown Errored!'}]}])
+
+ def test_test_error(self):
+ self._run_test(
+ flag='--test_error',
+ num_errors=1,
+ num_failures=0,
+ suites=[{'name': 'FailableTest',
+ 'cases': [{'name': 'test',
+ 'classname': '__main__.FailableTest',
+ 'error': 'test Errored!'}]}])
+
+ def test_set_up_failure(self):
+ self._run_test(
+ flag='--set_up_fail',
+ num_errors=0,
+ num_failures=1,
+ suites=[{'name': 'FailableTest',
+ 'cases': [{'name': 'test',
+ 'classname': '__main__.FailableTest',
+ 'failure': 'setUp Failed!'}]}])
+
+ def test_tear_down_failure(self):
+ self._run_test(
+ flag='--tear_down_fail',
+ num_errors=0,
+ num_failures=1,
+ suites=[{'name': 'FailableTest',
+ 'cases': [{'name': 'test',
+ 'classname': '__main__.FailableTest',
+ 'failure': 'tearDown Failed!'}]}])
+
+ def test_test_fail(self):
+ self._run_test(
+ flag='--test_fail',
+ num_errors=0,
+ num_failures=1,
+ suites=[{'name': 'FailableTest',
+ 'cases': [{'name': 'test',
+ 'classname': '__main__.FailableTest',
+ 'failure': 'test Failed!'}]}])
+
+ def test_test_randomization_seed_logging(self):
+ # We expect the resulting XML to start as follows:
+ # <testsuites ...>
+ # <properties>
+ # <property name="test_randomize_ordering_seed" value="17" />
+ # ...
+ #
+ # which we validate here.
+ out = self._run_test_and_get_xml('--test_randomize_ordering_seed=17')
+ expected_attrib = {'name': 'test_randomize_ordering_seed', 'value': '17'}
+ property_attributes = [
+ prop.attrib for prop in out.findall('./properties/property')]
+ self.assertIn(expected_attrib, property_attributes)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/testing/xml_reporter.py b/absl/testing/xml_reporter.py
new file mode 100644
index 0000000..da56e39
--- /dev/null
+++ b/absl/testing/xml_reporter.py
@@ -0,0 +1,562 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A Python test reporter that generates test reports in JUnit XML format."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+import re
+import sys
+import threading
+import time
+import traceback
+import unittest
+from xml.sax import saxutils
+from absl.testing import _pretty_print_reporter
+
+
+# See http://www.w3.org/TR/REC-xml/#NT-Char
+_bad_control_character_codes = set(range(0, 0x20)) - {0x9, 0xA, 0xD}
+
+
+_control_character_conversions = {
+ chr(i): '\\x{:02x}'.format(i) for i in _bad_control_character_codes}
+
+
+_escape_xml_attr_conversions = {
+ '"': '&quot;',
+ "'": '&apos;',
+ '\n': '&#xA;',
+ '\t': '&#x9;',
+ '\r': '&#xD;',
+ ' ': '&#x20;'}
+_escape_xml_attr_conversions.update(_control_character_conversions)
+
+
+# When class or module level function fails, unittest/suite.py adds a
+# _ErrorHolder instance instead of a real TestCase, and it has a description
+# like "setUpClass (__main__.MyTestCase)".
+_CLASS_OR_MODULE_LEVEL_TEST_DESC_REGEX = re.compile(r'^(\w+) \((\S+)\)$')
+
+
+# NOTE: while saxutils.quoteattr() theoretically does the same thing; it
+# seems to often end up being too smart for it's own good not escaping properly.
+# This function is much more reliable.
+def _escape_xml_attr(content):
+ """Escapes xml attributes."""
+ # Note: saxutils doesn't escape the quotes.
+ return saxutils.escape(content, _escape_xml_attr_conversions)
+
+
+def _escape_cdata(s):
+ """Escapes a string to be used as XML CDATA.
+
+ CDATA characters are treated strictly as character data, not as XML markup,
+ but there are still certain restrictions on them.
+
+ Args:
+ s: the string to be escaped.
+ Returns:
+ An escaped version of the input string.
+ """
+ for char, escaped in _control_character_conversions.items():
+ s = s.replace(char, escaped)
+ return s.replace(']]>', ']] >')
+
+
+def _iso8601_timestamp(timestamp):
+ """Produces an ISO8601 datetime.
+
+ Args:
+ timestamp: an Epoch based timestamp in seconds.
+
+ Returns:
+ A iso8601 format timestamp if the input is a valid timestamp, None otherwise
+ """
+ if timestamp is None or timestamp < 0:
+ return None
+ return datetime.datetime.fromtimestamp(
+ timestamp, tz=datetime.timezone.utc).isoformat()
+
+
+def _print_xml_element_header(element, attributes, stream, indentation=''):
+ """Prints an XML header of an arbitrary element.
+
+ Args:
+ element: element name (testsuites, testsuite, testcase)
+ attributes: 2-tuple list with (attributes, values) already escaped
+ stream: output stream to write test report XML to
+ indentation: indentation added to the element header
+ """
+ stream.write('%s<%s' % (indentation, element))
+ for attribute in attributes:
+ if (len(attribute) == 2 and attribute[0] is not None and
+ attribute[1] is not None):
+ stream.write(' %s="%s"' % (attribute[0], attribute[1]))
+ stream.write('>\n')
+
+# Copy time.time which ensures the real time is used internally.
+# This prevents bad interactions with tests that stub out time.
+_time_copy = time.time
+
+if hasattr(traceback, '_some_str'):
+ # Use the traceback module str function to format safely.
+ _safe_str = traceback._some_str
+else:
+ _safe_str = str # pylint: disable=invalid-name
+
+
+class _TestCaseResult(object):
+ """Private helper for _TextAndXMLTestResult that represents a test result.
+
+ Attributes:
+ test: A TestCase instance of an individual test method.
+ name: The name of the individual test method.
+ full_class_name: The full name of the test class.
+ run_time: The duration (in seconds) it took to run the test.
+ start_time: Epoch relative timestamp of when test started (in seconds)
+ errors: A list of error 4-tuples. Error tuple entries are
+ 1) a string identifier of either "failure" or "error"
+ 2) an exception_type
+ 3) an exception_message
+ 4) a string version of a sys.exc_info()-style tuple of values
+ ('error', err[0], err[1], self._exc_info_to_string(err))
+ If the length of errors is 0, then the test is either passed or
+ skipped.
+ skip_reason: A string explaining why the test was skipped.
+ """
+
+ def __init__(self, test):
+ self.run_time = -1
+ self.start_time = -1
+ self.skip_reason = None
+ self.errors = []
+ self.test = test
+
+ # Parse the test id to get its test name and full class path.
+ # Unfortunately there is no better way of knowning the test and class.
+ # Worse, unittest uses _ErrorHandler instances to represent class / module
+ # level failures.
+ test_desc = test.id() or str(test)
+ # Check if it's something like "setUpClass (__main__.TestCase)".
+ match = _CLASS_OR_MODULE_LEVEL_TEST_DESC_REGEX.match(test_desc)
+ if match:
+ name = match.group(1)
+ full_class_name = match.group(2)
+ else:
+ class_name = unittest.util.strclass(test.__class__)
+ if isinstance(test, unittest.case._SubTest):
+ # If the test case is a _SubTest, the real TestCase instance is
+ # available as _SubTest.test_case.
+ class_name = unittest.util.strclass(test.test_case.__class__)
+ if test_desc.startswith(class_name + '.'):
+ # In a typical unittest.TestCase scenario, test.id() returns with
+ # a class name formatted using unittest.util.strclass.
+ name = test_desc[len(class_name)+1:]
+ full_class_name = class_name
+ else:
+ # Otherwise make a best effort to guess the test name and full class
+ # path.
+ parts = test_desc.rsplit('.', 1)
+ name = parts[-1]
+ full_class_name = parts[0] if len(parts) == 2 else ''
+ self.name = _escape_xml_attr(name)
+ self.full_class_name = _escape_xml_attr(full_class_name)
+
+ def set_run_time(self, time_in_secs):
+ self.run_time = time_in_secs
+
+ def set_start_time(self, time_in_secs):
+ self.start_time = time_in_secs
+
+ def print_xml_summary(self, stream):
+ """Prints an XML Summary of a TestCase.
+
+ Status and result are populated as per JUnit XML test result reporter.
+ A test that has been skipped will always have a skip reason,
+ as every skip method in Python's unittest requires the reason arg to be
+ passed.
+
+ Args:
+ stream: output stream to write test report XML to
+ """
+
+ if self.skip_reason is None:
+ status = 'run'
+ result = 'completed'
+ else:
+ status = 'notrun'
+ result = 'suppressed'
+
+ test_case_attributes = [
+ ('name', '%s' % self.name),
+ ('status', '%s' % status),
+ ('result', '%s' % result),
+ ('time', '%.1f' % self.run_time),
+ ('classname', self.full_class_name),
+ ('timestamp', _iso8601_timestamp(self.start_time)),
+ ]
+ _print_xml_element_header('testcase', test_case_attributes, stream, ' ')
+ self._print_testcase_details(stream)
+ stream.write(' </testcase>\n')
+
+ def _print_testcase_details(self, stream):
+ for error in self.errors:
+ outcome, exception_type, message, error_msg = error # pylint: disable=unpacking-non-sequence
+ message = _escape_xml_attr(_safe_str(message))
+ exception_type = _escape_xml_attr(str(exception_type))
+ error_msg = _escape_cdata(error_msg)
+ stream.write(' <%s message="%s" type="%s"><![CDATA[%s]]></%s>\n'
+ % (outcome, message, exception_type, error_msg, outcome))
+
+
+class _TestSuiteResult(object):
+ """Private helper for _TextAndXMLTestResult."""
+
+ def __init__(self):
+ self.suites = {}
+ self.failure_counts = {}
+ self.error_counts = {}
+ self.overall_start_time = -1
+ self.overall_end_time = -1
+ self._testsuites_properties = {}
+
+ def add_test_case_result(self, test_case_result):
+ suite_name = type(test_case_result.test).__name__
+ if suite_name == '_ErrorHolder':
+ # _ErrorHolder is a special case created by unittest for class / module
+ # level functions.
+ suite_name = test_case_result.full_class_name.rsplit('.')[-1]
+ if isinstance(test_case_result.test, unittest.case._SubTest):
+ # If the test case is a _SubTest, the real TestCase instance is
+ # available as _SubTest.test_case.
+ suite_name = type(test_case_result.test.test_case).__name__
+
+ self._setup_test_suite(suite_name)
+ self.suites[suite_name].append(test_case_result)
+ for error in test_case_result.errors:
+ # Only count the first failure or error so that the sum is equal to the
+ # total number of *testcases* that have failures or errors.
+ if error[0] == 'failure':
+ self.failure_counts[suite_name] += 1
+ break
+ elif error[0] == 'error':
+ self.error_counts[suite_name] += 1
+ break
+
+ def print_xml_summary(self, stream):
+ overall_test_count = sum(len(x) for x in self.suites.values())
+ overall_failures = sum(self.failure_counts.values())
+ overall_errors = sum(self.error_counts.values())
+ overall_attributes = [
+ ('name', ''),
+ ('tests', '%d' % overall_test_count),
+ ('failures', '%d' % overall_failures),
+ ('errors', '%d' % overall_errors),
+ ('time', '%.1f' % (self.overall_end_time - self.overall_start_time)),
+ ('timestamp', _iso8601_timestamp(self.overall_start_time)),
+ ]
+ _print_xml_element_header('testsuites', overall_attributes, stream)
+ if self._testsuites_properties:
+ stream.write(' <properties>\n')
+ for name, value in sorted(self._testsuites_properties.items()):
+ stream.write(' <property name="%s" value="%s"></property>\n' %
+ (_escape_xml_attr(name), _escape_xml_attr(str(value))))
+ stream.write(' </properties>\n')
+
+ for suite_name in self.suites:
+ suite = self.suites[suite_name]
+ suite_end_time = max(x.start_time + x.run_time for x in suite)
+ suite_start_time = min(x.start_time for x in suite)
+ failures = self.failure_counts[suite_name]
+ errors = self.error_counts[suite_name]
+ suite_attributes = [
+ ('name', '%s' % suite_name),
+ ('tests', '%d' % len(suite)),
+ ('failures', '%d' % failures),
+ ('errors', '%d' % errors),
+ ('time', '%.1f' % (suite_end_time - suite_start_time)),
+ ('timestamp', _iso8601_timestamp(suite_start_time)),
+ ]
+ _print_xml_element_header('testsuite', suite_attributes, stream)
+
+ for test_case_result in suite:
+ test_case_result.print_xml_summary(stream)
+ stream.write('</testsuite>\n')
+ stream.write('</testsuites>\n')
+
+ def _setup_test_suite(self, suite_name):
+ """Adds a test suite to the set of suites tracked by this test run.
+
+ Args:
+ suite_name: string, The name of the test suite being initialized.
+ """
+ if suite_name in self.suites:
+ return
+ self.suites[suite_name] = []
+ self.failure_counts[suite_name] = 0
+ self.error_counts[suite_name] = 0
+
+ def set_end_time(self, timestamp_in_secs):
+ """Sets the start timestamp of this test suite.
+
+ Args:
+ timestamp_in_secs: timestamp in seconds since epoch
+ """
+ self.overall_end_time = timestamp_in_secs
+
+ def set_start_time(self, timestamp_in_secs):
+ """Sets the end timestamp of this test suite.
+
+ Args:
+ timestamp_in_secs: timestamp in seconds since epoch
+ """
+ self.overall_start_time = timestamp_in_secs
+
+
+class _TextAndXMLTestResult(_pretty_print_reporter.TextTestResult):
+ """Private TestResult class that produces both formatted text results and XML.
+
+ Used by TextAndXMLTestRunner.
+ """
+
+ _TEST_SUITE_RESULT_CLASS = _TestSuiteResult
+ _TEST_CASE_RESULT_CLASS = _TestCaseResult
+
+ def __init__(self, xml_stream, stream, descriptions, verbosity,
+ time_getter=_time_copy, testsuites_properties=None):
+ super(_TextAndXMLTestResult, self).__init__(stream, descriptions, verbosity)
+ self.xml_stream = xml_stream
+ self.pending_test_case_results = {}
+ self.suite = self._TEST_SUITE_RESULT_CLASS()
+ if testsuites_properties:
+ self.suite._testsuites_properties = testsuites_properties
+ self.time_getter = time_getter
+
+ # This lock guards any mutations on pending_test_case_results.
+ self._pending_test_case_results_lock = threading.RLock()
+
+ def startTest(self, test):
+ self.start_time = self.time_getter()
+ super(_TextAndXMLTestResult, self).startTest(test)
+
+ def stopTest(self, test):
+ # Grabbing the write lock to avoid conflicting with stopTestRun.
+ with self._pending_test_case_results_lock:
+ super(_TextAndXMLTestResult, self).stopTest(test)
+ result = self.get_pending_test_case_result(test)
+ if not result:
+ test_name = test.id() or str(test)
+ sys.stderr.write('No pending test case: %s\n' % test_name)
+ return
+ test_id = id(test)
+ run_time = self.time_getter() - self.start_time
+ result.set_run_time(run_time)
+ result.set_start_time(self.start_time)
+ self.suite.add_test_case_result(result)
+ del self.pending_test_case_results[test_id]
+
+ def startTestRun(self):
+ self.suite.set_start_time(self.time_getter())
+ super(_TextAndXMLTestResult, self).startTestRun()
+
+ def stopTestRun(self):
+ self.suite.set_end_time(self.time_getter())
+ # All pending_test_case_results will be added to the suite and removed from
+ # the pending_test_case_results dictionary. Grabbing the write lock to avoid
+ # results from being added during this process to avoid duplicating adds or
+ # accidentally erasing newly appended pending results.
+ with self._pending_test_case_results_lock:
+ # Errors in the test fixture (setUpModule, tearDownModule,
+ # setUpClass, tearDownClass) can leave a pending result which
+ # never gets added to the suite. The runner calls stopTestRun
+ # which gives us an opportunity to add these errors for
+ # reporting here.
+ for test_id in self.pending_test_case_results:
+ result = self.pending_test_case_results[test_id]
+ if hasattr(self, 'start_time'):
+ run_time = self.suite.overall_end_time - self.start_time
+ result.set_run_time(run_time)
+ result.set_start_time(self.start_time)
+ self.suite.add_test_case_result(result)
+ self.pending_test_case_results.clear()
+
+ def _exc_info_to_string(self, err, test=None):
+ """Converts a sys.exc_info()-style tuple of values into a string.
+
+ This method must be overridden because the method signature in
+ unittest.TestResult changed between Python 2.2 and 2.4.
+
+ Args:
+ err: A sys.exc_info() tuple of values for an error.
+ test: The test method.
+
+ Returns:
+ A formatted exception string.
+ """
+ if test:
+ return super(_TextAndXMLTestResult, self)._exc_info_to_string(err, test)
+ return ''.join(traceback.format_exception(*err))
+
+ def add_pending_test_case_result(self, test, error_summary=None,
+ skip_reason=None):
+ """Adds result information to a test case result which may still be running.
+
+ If a result entry for the test already exists, add_pending_test_case_result
+ will add error summary tuples and/or overwrite skip_reason for the result.
+ If it does not yet exist, a result entry will be created.
+ Note that a test result is considered to have been run and passed
+ only if there are no errors or skip_reason.
+
+ Args:
+ test: A test method as defined by unittest
+ error_summary: A 4-tuple with the following entries:
+ 1) a string identifier of either "failure" or "error"
+ 2) an exception_type
+ 3) an exception_message
+ 4) a string version of a sys.exc_info()-style tuple of values
+ ('error', err[0], err[1], self._exc_info_to_string(err))
+ If the length of errors is 0, then the test is either passed or
+ skipped.
+ skip_reason: a string explaining why the test was skipped
+ """
+ with self._pending_test_case_results_lock:
+ test_id = id(test)
+ if test_id not in self.pending_test_case_results:
+ self.pending_test_case_results[test_id] = self._TEST_CASE_RESULT_CLASS(
+ test)
+ if error_summary:
+ self.pending_test_case_results[test_id].errors.append(error_summary)
+ if skip_reason:
+ self.pending_test_case_results[test_id].skip_reason = skip_reason
+
+ def delete_pending_test_case_result(self, test):
+ with self._pending_test_case_results_lock:
+ test_id = id(test)
+ del self.pending_test_case_results[test_id]
+
+ def get_pending_test_case_result(self, test):
+ test_id = id(test)
+ return self.pending_test_case_results.get(test_id, None)
+
+ def addSuccess(self, test):
+ super(_TextAndXMLTestResult, self).addSuccess(test)
+ self.add_pending_test_case_result(test)
+
+ def addError(self, test, err):
+ super(_TextAndXMLTestResult, self).addError(test, err)
+ error_summary = ('error', err[0], err[1],
+ self._exc_info_to_string(err, test=test))
+ self.add_pending_test_case_result(test, error_summary=error_summary)
+
+ def addFailure(self, test, err):
+ super(_TextAndXMLTestResult, self).addFailure(test, err)
+ error_summary = ('failure', err[0], err[1],
+ self._exc_info_to_string(err, test=test))
+ self.add_pending_test_case_result(test, error_summary=error_summary)
+
+ def addSkip(self, test, reason):
+ super(_TextAndXMLTestResult, self).addSkip(test, reason)
+ self.add_pending_test_case_result(test, skip_reason=reason)
+
+ def addExpectedFailure(self, test, err):
+ super(_TextAndXMLTestResult, self).addExpectedFailure(test, err)
+ if callable(getattr(test, 'recordProperty', None)):
+ test.recordProperty('EXPECTED_FAILURE',
+ self._exc_info_to_string(err, test=test))
+ self.add_pending_test_case_result(test)
+
+ def addUnexpectedSuccess(self, test):
+ super(_TextAndXMLTestResult, self).addUnexpectedSuccess(test)
+ test_name = test.id() or str(test)
+ error_summary = ('error', '', '',
+ 'Test case %s should have failed, but passed.'
+ % (test_name))
+ self.add_pending_test_case_result(test, error_summary=error_summary)
+
+ def addSubTest(self, test, subtest, err): # pylint: disable=invalid-name
+ super(_TextAndXMLTestResult, self).addSubTest(test, subtest, err)
+ if err is not None:
+ if issubclass(err[0], test.failureException):
+ error_summary = ('failure', err[0], err[1],
+ self._exc_info_to_string(err, test=test))
+ else:
+ error_summary = ('error', err[0], err[1],
+ self._exc_info_to_string(err, test=test))
+ else:
+ error_summary = None
+ self.add_pending_test_case_result(subtest, error_summary=error_summary)
+
+ def printErrors(self):
+ super(_TextAndXMLTestResult, self).printErrors()
+ self.xml_stream.write('<?xml version="1.0"?>\n')
+ self.suite.print_xml_summary(self.xml_stream)
+
+
+class TextAndXMLTestRunner(unittest.TextTestRunner):
+ """A test runner that produces both formatted text results and XML.
+
+ It prints out the names of tests as they are run, errors as they
+ occur, and a summary of the results at the end of the test run.
+ """
+
+ _TEST_RESULT_CLASS = _TextAndXMLTestResult
+
+ _xml_stream = None
+ _testsuites_properties = {}
+
+ def __init__(self, xml_stream=None, *args, **kwargs):
+ """Initialize a TextAndXMLTestRunner.
+
+ Args:
+ xml_stream: file-like or None; XML-formatted test results are output
+ via this object's write() method. If None (the default), the
+ new instance behaves as described in the set_default_xml_stream method
+ documentation below.
+ *args: passed unmodified to unittest.TextTestRunner.__init__.
+ **kwargs: passed unmodified to unittest.TextTestRunner.__init__.
+ """
+ super(TextAndXMLTestRunner, self).__init__(*args, **kwargs)
+ if xml_stream is not None:
+ self._xml_stream = xml_stream
+ # else, do not set self._xml_stream to None -- this allows implicit fallback
+ # to the class attribute's value.
+
+ @classmethod
+ def set_default_xml_stream(cls, xml_stream):
+ """Sets the default XML stream for the class.
+
+ Args:
+ xml_stream: file-like or None; used for instances when xml_stream is None
+ or not passed to their constructors. If None is passed, instances
+ created with xml_stream=None will act as ordinary TextTestRunner
+ instances; this is the default state before any calls to this method
+ have been made.
+ """
+ cls._xml_stream = xml_stream
+
+ def _makeResult(self):
+ if self._xml_stream is None:
+ return super(TextAndXMLTestRunner, self)._makeResult()
+ else:
+ return self._TEST_RESULT_CLASS(
+ self._xml_stream, self.stream, self.descriptions, self.verbosity,
+ testsuites_properties=self._testsuites_properties)
+
+ @classmethod
+ def set_testsuites_property(cls, key, value):
+ cls._testsuites_properties[key] = value
diff --git a/absl/tests/__init__.py b/absl/tests/__init__.py
new file mode 100644
index 0000000..a3bd1cd
--- /dev/null
+++ b/absl/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/absl/tests/app_test.py b/absl/tests/app_test.py
new file mode 100644
index 0000000..1d8b764
--- /dev/null
+++ b/absl/tests/app_test.py
@@ -0,0 +1,359 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for app.py."""
+
+import contextlib
+import copy
+import enum
+import io
+import os
+import re
+import subprocess
+import sys
+import tempfile
+from unittest import mock
+
+from absl import app
+from absl import flags
+from absl.testing import _bazelize_command
+from absl.testing import absltest
+from absl.testing import flagsaver
+from absl.tests import app_test_helper
+
+
+FLAGS = flags.FLAGS
+
+
+_newline_regex = re.compile('(\r\n)|\r')
+
+
+@contextlib.contextmanager
+def patch_main_module_docstring(docstring):
+ old_doc = sys.modules['__main__'].__doc__
+ sys.modules['__main__'].__doc__ = docstring
+ yield
+ sys.modules['__main__'].__doc__ = old_doc
+
+
+def _normalize_newlines(s):
+ return re.sub('(\r\n)|\r', '\n', s)
+
+
+class UnitTests(absltest.TestCase):
+
+ def test_install_exception_handler(self):
+ with self.assertRaises(TypeError):
+ app.install_exception_handler(1)
+
+ def test_usage(self):
+ with mock.patch.object(
+ sys, 'stderr', new=io.StringIO()) as mock_stderr:
+ app.usage()
+ self.assertIn(__doc__, mock_stderr.getvalue())
+ # Assert that flags are written to stderr.
+ self.assertIn('\n --[no]helpfull:', mock_stderr.getvalue())
+
+ def test_usage_shorthelp(self):
+ with mock.patch.object(
+ sys, 'stderr', new=io.StringIO()) as mock_stderr:
+ app.usage(shorthelp=True)
+ # Assert that flags are NOT written to stderr.
+ self.assertNotIn(' --', mock_stderr.getvalue())
+
+ def test_usage_writeto_stderr(self):
+ with mock.patch.object(
+ sys, 'stdout', new=io.StringIO()) as mock_stdout:
+ app.usage(writeto_stdout=True)
+ self.assertIn(__doc__, mock_stdout.getvalue())
+
+ def test_usage_detailed_error(self):
+ with mock.patch.object(
+ sys, 'stderr', new=io.StringIO()) as mock_stderr:
+ app.usage(detailed_error='BAZBAZ')
+ self.assertIn('BAZBAZ', mock_stderr.getvalue())
+
+ def test_usage_exitcode(self):
+ with mock.patch.object(sys, 'stderr', new=sys.stderr):
+ try:
+ app.usage(exitcode=2)
+ self.fail('app.usage(exitcode=1) should raise SystemExit')
+ except SystemExit as e:
+ self.assertEqual(2, e.code)
+
+ def test_usage_expands_docstring(self):
+ with patch_main_module_docstring('Name: %s, %%s'):
+ with mock.patch.object(
+ sys, 'stderr', new=io.StringIO()) as mock_stderr:
+ app.usage()
+ self.assertIn('Name: {}, %s'.format(sys.argv[0]),
+ mock_stderr.getvalue())
+
+ def test_usage_does_not_expand_bad_docstring(self):
+ with patch_main_module_docstring('Name: %s, %%s, %@'):
+ with mock.patch.object(
+ sys, 'stderr', new=io.StringIO()) as mock_stderr:
+ app.usage()
+ self.assertIn('Name: %s, %%s, %@', mock_stderr.getvalue())
+
+ @flagsaver.flagsaver
+ def test_register_and_parse_flags_with_usage_exits_on_only_check_args(self):
+ done = app._register_and_parse_flags_with_usage.done
+ try:
+ app._register_and_parse_flags_with_usage.done = False
+ with self.assertRaises(SystemExit):
+ app._register_and_parse_flags_with_usage(
+ argv=['./program', '--only_check_args'])
+ finally:
+ app._register_and_parse_flags_with_usage.done = done
+
+ def test_register_and_parse_flags_with_usage_exits_on_second_run(self):
+ with self.assertRaises(SystemError):
+ app._register_and_parse_flags_with_usage()
+
+
+class FunctionalTests(absltest.TestCase):
+ """Functional tests that use runs app_test_helper."""
+
+ helper_type = 'pure_python'
+
+ def run_helper(self, expect_success,
+ expected_stdout_substring=None, expected_stderr_substring=None,
+ arguments=(),
+ env_overrides=None):
+ env = os.environ.copy()
+ env['APP_TEST_HELPER_TYPE'] = self.helper_type
+ env['PYTHONIOENCODING'] = 'utf8'
+ if env_overrides:
+ env.update(env_overrides)
+
+ helper = 'absl/tests/app_test_helper_{}'.format(self.helper_type)
+ process = subprocess.Popen(
+ [_bazelize_command.get_executable_path(helper)] + list(arguments),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE, env=env, universal_newlines=False)
+ stdout, stderr = process.communicate()
+ # In Python 2, we can't control the encoding used by universal_newline
+ # mode, which can cause UnicodeDecodeErrors when subprocess tries to
+ # convert the bytes to unicode, so we have to decode it manually.
+ stdout = _normalize_newlines(stdout.decode('utf8'))
+ stderr = _normalize_newlines(stderr.decode('utf8'))
+
+ message = (u'Command: {command}\n'
+ 'Exit Code: {exitcode}\n'
+ '===== stdout =====\n{stdout}'
+ '===== stderr =====\n{stderr}'
+ '=================='.format(
+ command=' '.join([helper] + list(arguments)),
+ exitcode=process.returncode,
+ stdout=stdout or '<no output>\n',
+ stderr=stderr or '<no output>\n'))
+ if expect_success:
+ self.assertEqual(0, process.returncode, msg=message)
+ else:
+ self.assertNotEqual(0, process.returncode, msg=message)
+
+ if expected_stdout_substring:
+ self.assertIn(expected_stdout_substring, stdout, message)
+ if expected_stderr_substring:
+ self.assertIn(expected_stderr_substring, stderr, message)
+
+ return process.returncode, stdout, stderr
+
+ def test_help(self):
+ _, _, stderr = self.run_helper(
+ False,
+ arguments=['--help'],
+ expected_stdout_substring=app_test_helper.__doc__)
+ self.assertNotIn('--', stderr)
+
+ def test_helpfull_basic(self):
+ self.run_helper(
+ False,
+ arguments=['--helpfull'],
+ # --logtostderr is from absl.logging module.
+ expected_stdout_substring='--[no]logtostderr')
+
+ def test_helpfull_unicode_flag_help(self):
+ _, stdout, _ = self.run_helper(
+ False,
+ arguments=['--helpfull'],
+ expected_stdout_substring='str_flag_with_unicode_args')
+
+ self.assertIn(u'smile:\U0001F604', stdout)
+
+ self.assertIn(u'thumb:\U0001F44D', stdout)
+
+ def test_helpshort(self):
+ _, _, stderr = self.run_helper(
+ False,
+ arguments=['--helpshort'],
+ expected_stdout_substring=app_test_helper.__doc__)
+ self.assertNotIn('--', stderr)
+
+ def test_custom_main(self):
+ self.run_helper(
+ True,
+ env_overrides={'APP_TEST_CUSTOM_MAIN_FUNC': 'custom_main'},
+ expected_stdout_substring='Function called: custom_main.')
+
+ def test_custom_argv(self):
+ self.run_helper(
+ True,
+ expected_stdout_substring='argv: ./program pos_arg1',
+ env_overrides={
+ 'APP_TEST_CUSTOM_ARGV': './program --noraise_exception pos_arg1',
+ 'APP_TEST_PRINT_ARGV': '1',
+ })
+
+ def test_gwq_status_file_on_exception(self):
+ if self.helper_type == 'pure_python':
+ # Pure python binary does not write to GWQ Status.
+ return
+
+ tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
+ self.run_helper(
+ False,
+ arguments=['--raise_exception'],
+ env_overrides={'GOOGLE_STATUS_DIR': tmpdir})
+ with open(os.path.join(tmpdir, 'STATUS')) as status_file:
+ self.assertIn('MyException:', status_file.read())
+
+ def test_faulthandler_dumps_stack_on_sigsegv(self):
+ return_code, _, _ = self.run_helper(
+ False,
+ expected_stderr_substring='app_test_helper.py", line',
+ arguments=['--faulthandler_sigsegv'])
+ # sigsegv returns 3 on Windows, and -11 on LINUX/macOS.
+ expected_return_code = 3 if os.name == 'nt' else -11
+ self.assertEqual(expected_return_code, return_code)
+
+ def test_top_level_exception(self):
+ self.run_helper(
+ False,
+ arguments=['--raise_exception'],
+ expected_stderr_substring='MyException')
+
+ def test_only_check_args(self):
+ self.run_helper(
+ True,
+ arguments=['--only_check_args', '--raise_exception'])
+
+ def test_only_check_args_failure(self):
+ self.run_helper(
+ False,
+ arguments=['--only_check_args', '--banana'],
+ expected_stderr_substring='FATAL Flags parsing error')
+
+ def test_usage_error(self):
+ exitcode, _, _ = self.run_helper(
+ False,
+ arguments=['--raise_usage_error'],
+ expected_stderr_substring=app_test_helper.__doc__)
+ self.assertEqual(1, exitcode)
+
+ def test_usage_error_exitcode(self):
+ exitcode, _, _ = self.run_helper(
+ False,
+ arguments=['--raise_usage_error', '--usage_error_exitcode=88'],
+ expected_stderr_substring=app_test_helper.__doc__)
+ self.assertEqual(88, exitcode)
+
+ def test_exception_handler(self):
+ exception_handler_messages = (
+ 'MyExceptionHandler: first\nMyExceptionHandler: second\n')
+ self.run_helper(
+ False,
+ arguments=['--raise_exception'],
+ expected_stdout_substring=exception_handler_messages)
+
+ def test_exception_handler_not_called(self):
+ _, _, stdout = self.run_helper(True)
+ self.assertNotIn('MyExceptionHandler', stdout)
+
+ def test_print_init_callbacks(self):
+ _, stdout, _ = self.run_helper(
+ expect_success=True, arguments=['--print_init_callbacks'])
+ self.assertIn('before app.run', stdout)
+ self.assertIn('during real_main', stdout)
+
+
+class FlagDeepCopyTest(absltest.TestCase):
+ """Make sure absl flags are copy.deepcopy() compatible."""
+
+ def test_deepcopyable(self):
+ copy.deepcopy(FLAGS)
+ # Nothing to assert
+
+
+class FlagValuesExternalizationTest(absltest.TestCase):
+ """Test to make sure FLAGS can be serialized out and parsed back in."""
+
+ @flagsaver.flagsaver
+ def test_nohelp_doesnt_show_help(self):
+ with self.assertRaisesWithPredicateMatch(SystemExit,
+ lambda e: e.code == 1):
+ app.run(
+ len,
+ argv=[
+ './program', '--nohelp', '--helpshort=false', '--helpfull=0',
+ '--helpxml=f'
+ ])
+
+ @flagsaver.flagsaver
+ def test_serialize_roundtrip(self):
+ # Use the global 'FLAGS' as the source, to ensure all the framework defined
+ # flags will go through the round trip process.
+ flags.DEFINE_string('testflag', 'testval', 'help', flag_values=FLAGS)
+
+ flags.DEFINE_multi_enum('test_multi_enum_flag',
+ ['x', 'y'], ['x', 'y', 'z'],
+ 'Multi enum help.',
+ flag_values=FLAGS)
+
+ class Fruit(enum.Enum):
+ APPLE = 1
+ ORANGE = 2
+ TOMATO = 3
+ flags.DEFINE_multi_enum_class('test_multi_enum_class_flag',
+ ['APPLE', 'TOMATO'], Fruit,
+ 'Fruit help.',
+ flag_values=FLAGS)
+
+ new_flag_values = flags.FlagValues()
+ new_flag_values.append_flag_values(FLAGS)
+
+ FLAGS.testflag = 'roundtrip_me'
+ FLAGS.test_multi_enum_flag = ['y', 'z']
+ FLAGS.test_multi_enum_class_flag = [Fruit.ORANGE, Fruit.APPLE]
+ argv = ['binary_name'] + FLAGS.flags_into_string().splitlines()
+
+ self.assertNotEqual(new_flag_values['testflag'], FLAGS.testflag)
+ self.assertNotEqual(new_flag_values['test_multi_enum_flag'],
+ FLAGS.test_multi_enum_flag)
+ self.assertNotEqual(new_flag_values['test_multi_enum_class_flag'],
+ FLAGS.test_multi_enum_class_flag)
+ new_flag_values(argv)
+ self.assertEqual(new_flag_values.testflag, FLAGS.testflag)
+ self.assertEqual(new_flag_values.test_multi_enum_flag,
+ FLAGS.test_multi_enum_flag)
+ self.assertEqual(new_flag_values.test_multi_enum_class_flag,
+ FLAGS.test_multi_enum_class_flag)
+ del FLAGS.testflag
+ del FLAGS.test_multi_enum_flag
+ del FLAGS.test_multi_enum_class_flag
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/tests/app_test_helper.py b/absl/tests/app_test_helper.py
new file mode 100644
index 0000000..6bd1a89
--- /dev/null
+++ b/absl/tests/app_test_helper.py
@@ -0,0 +1,151 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper script used by app_test.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+try:
+ import faulthandler
+except ImportError:
+ faulthandler = None
+
+from absl import app
+from absl import flags
+
+FLAGS = flags.FLAGS
+flags.DEFINE_boolean('faulthandler_sigsegv', False, 'raise SIGSEGV')
+flags.DEFINE_boolean('raise_exception', False, 'Raise MyException from main.')
+flags.DEFINE_boolean(
+ 'raise_usage_error', False, 'Raise app.UsageError from main.')
+flags.DEFINE_integer(
+ 'usage_error_exitcode', None, 'The exitcode if app.UsageError if raised.')
+flags.DEFINE_string(
+ 'str_flag_with_unicode_args', u'thumb:\U0001F44D', u'smile:\U0001F604')
+flags.DEFINE_boolean('print_init_callbacks', False,
+ 'print init callbacks and exit')
+
+
+class MyException(Exception):
+ pass
+
+
+class MyExceptionHandler(app.ExceptionHandler):
+
+ def __init__(self, message):
+ self.message = message
+
+ def handle(self, exc):
+ sys.stdout.write('MyExceptionHandler: {}\n'.format(self.message))
+
+
+def real_main(argv):
+ """The main function."""
+ if os.environ.get('APP_TEST_PRINT_ARGV', False):
+ sys.stdout.write('argv: {}\n'.format(' '.join(argv)))
+
+ if FLAGS.raise_exception:
+ raise MyException
+
+ if FLAGS.raise_usage_error:
+ if FLAGS.usage_error_exitcode is not None:
+ raise app.UsageError('Error!', FLAGS.usage_error_exitcode)
+ else:
+ raise app.UsageError('Error!')
+
+ if FLAGS.faulthandler_sigsegv:
+ faulthandler._sigsegv() # pylint: disable=protected-access
+ sys.exit(1) # Should not reach here.
+
+ if FLAGS.print_init_callbacks:
+ app.call_after_init(lambda: _callback_results.append('during real_main'))
+ for value in _callback_results:
+ print('callback: {}'.format(value))
+ sys.exit(0)
+
+ # Ensure that we have a random C++ flag in flags.FLAGS; this shows
+ # us that app.run() did the right thing in conjunction with C++ flags.
+ helper_type = os.environ['APP_TEST_HELPER_TYPE']
+ if helper_type == 'clif':
+ if 'heap_check_before_constructors' in flags.FLAGS:
+ print('PASS: C++ flag present and helper_type is {}'.format(helper_type))
+ sys.exit(0)
+ else:
+ print('FAILED: C++ flag absent but helper_type is {}'.format(helper_type))
+ sys.exit(1)
+ elif helper_type == 'pure_python':
+ if 'heap_check_before_constructors' in flags.FLAGS:
+ print('FAILED: C++ flag present but helper_type is pure_python')
+ sys.exit(1)
+ else:
+ print('PASS: C++ flag absent and helper_type is pure_python')
+ sys.exit(0)
+ else:
+ print('Unexpected helper_type "{}"'.format(helper_type))
+ sys.exit(1)
+
+
+def custom_main(argv):
+ print('Function called: custom_main.')
+ real_main(argv)
+
+
+def main(argv):
+ print('Function called: main.')
+ real_main(argv)
+
+
+flags_parser_argv_sentinel = object()
+
+
+def flags_parser_main(argv):
+ print('Function called: main_with_flags_parser.')
+ if argv is not flags_parser_argv_sentinel:
+ sys.exit(
+ 'FAILED: main function should be called with the return value of '
+ 'flags_parser, but found: {}'.format(argv))
+
+
+def flags_parser(argv):
+ print('Function called: flags_parser.')
+ if os.environ.get('APP_TEST_FLAGS_PARSER_PARSE_FLAGS', None):
+ FLAGS(argv)
+ return flags_parser_argv_sentinel
+
+
+# Holds results from callbacks triggered by `app.run_after_init`.
+_callback_results = []
+
+if __name__ == '__main__':
+ kwargs = {'main': main}
+ main_function_name = os.environ.get('APP_TEST_CUSTOM_MAIN_FUNC', None)
+ if main_function_name:
+ kwargs['main'] = globals()[main_function_name]
+ custom_argv = os.environ.get('APP_TEST_CUSTOM_ARGV', None)
+ if custom_argv:
+ kwargs['argv'] = custom_argv.split(' ')
+ if os.environ.get('APP_TEST_USE_CUSTOM_PARSER', None):
+ kwargs['flags_parser'] = flags_parser
+
+ app.call_after_init(lambda: _callback_results.append('before app.run'))
+ app.install_exception_handler(MyExceptionHandler('first'))
+ app.install_exception_handler(MyExceptionHandler('second'))
+ app.run(**kwargs)
+
+ sys.exit('This is not reachable.')
diff --git a/absl/tests/command_name_test.py b/absl/tests/command_name_test.py
new file mode 100644
index 0000000..2679521
--- /dev/null
+++ b/absl/tests/command_name_test.py
@@ -0,0 +1,108 @@
+# -*- coding=utf-8 -*-
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for absl.command_name."""
+
+import ctypes
+import errno
+import os
+import unittest
+from unittest import mock
+
+from absl import command_name
+from absl.testing import absltest
+
+
+def _get_kernel_process_name():
+ """Returns the Kernel's name for our process or an empty string."""
+ try:
+ with open('/proc/self/status', 'rt') as status_file:
+ for line in status_file:
+ if line.startswith('Name:'):
+ return line.split(':', 2)[1].strip().encode('ascii', 'replace')
+ return b''
+ except IOError:
+ return b''
+
+
+def _is_prctl_syscall_available():
+ try:
+ libc = ctypes.CDLL('libc.so.6', use_errno=True)
+ except OSError:
+ return False
+ zero = ctypes.c_ulong(0)
+ try:
+ status = libc.prctl(zero, zero, zero, zero, zero)
+ except AttributeError:
+ return False
+ if status < 0 and errno.ENOSYS == ctypes.get_errno():
+ return False
+ return True
+
+
+@unittest.skipIf(not _get_kernel_process_name(),
+ '_get_kernel_process_name() fails.')
+class CommandNameTest(absltest.TestCase):
+
+ def assertProcessNameSimilarTo(self, new_name):
+ if not isinstance(new_name, bytes):
+ new_name = new_name.encode('ascii', 'replace')
+ actual_name = _get_kernel_process_name()
+ self.assertTrue(actual_name)
+ self.assertTrue(new_name.startswith(actual_name),
+ msg='set {!r} vs found {!r}'.format(new_name, actual_name))
+
+ @unittest.skipIf(not os.access('/proc/self/comm', os.W_OK),
+ '/proc/self/comm is not writeable.')
+ def test_set_kernel_process_name(self):
+ new_name = u'ProcessNam0123456789abcdefghijklmnöp'
+ command_name.set_kernel_process_name(new_name)
+ self.assertProcessNameSimilarTo(new_name)
+
+ @unittest.skipIf(not _is_prctl_syscall_available(),
+ 'prctl() system call missing from libc.so.6.')
+ def test_set_kernel_process_name_no_proc_file(self):
+ new_name = b'NoProcFile0123456789abcdefghijklmnop'
+ mock_open = mock.mock_open()
+ with mock.patch.object(command_name, 'open', mock_open, create=True):
+ mock_open.side_effect = IOError('mock open that raises.')
+ command_name.set_kernel_process_name(new_name)
+ mock_open.assert_called_with('/proc/self/comm', mock.ANY)
+ self.assertProcessNameSimilarTo(new_name)
+
+ def test_set_kernel_process_name_failure(self):
+ starting_name = _get_kernel_process_name()
+ new_name = b'NameTest'
+ mock_open = mock.mock_open()
+ mock_ctypes_cdll = mock.patch('ctypes.CDLL')
+ with mock.patch.object(command_name, 'open', mock_open, create=True):
+ with mock.patch('ctypes.CDLL') as mock_ctypes_cdll:
+ mock_open.side_effect = IOError('mock open that raises.')
+ mock_libc = mock.Mock(['prctl'])
+ mock_ctypes_cdll.return_value = mock_libc
+ command_name.set_kernel_process_name(new_name)
+ mock_open.assert_called_with('/proc/self/comm', mock.ANY)
+ self.assertEqual(1, mock_libc.prctl.call_count)
+ self.assertEqual(starting_name, _get_kernel_process_name()) # No change.
+
+ def test_make_process_name_useful(self):
+ test_name = 'hello.from.test'
+ with mock.patch('sys.argv', [test_name]):
+ command_name.make_process_name_useful()
+ self.assertProcessNameSimilarTo(test_name)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/absl/tests/python_version_test.py b/absl/tests/python_version_test.py
new file mode 100644
index 0000000..eebfff2
--- /dev/null
+++ b/absl/tests/python_version_test.py
@@ -0,0 +1,40 @@
+# Copyright 2021 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test that verifies the Python version used in bazel is expected."""
+
+import sys
+from absl import flags
+from absl.testing import absltest
+
+_EXPECTED_VERSION = flags.DEFINE_string(
+ 'expected_version',
+ None,
+ 'The expected Python SemVer version, '
+ 'can be major.minor or major.minor.patch.',
+)
+
+
+class PythonVersionTest(absltest.TestCase):
+
+ def test_version(self):
+ version = _EXPECTED_VERSION.value
+ if not version:
+ self.skipTest(
+ 'Skipping version test since --expected_version is not specified')
+ num_parts = len(version.split('.'))
+ self.assertEqual('.'.join(map(str, sys.version_info[:num_parts])), version)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..13b2353
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Abseil setup configuration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+try:
+ import setuptools
+except ImportError:
+ from ez_setup import use_setuptools
+ use_setuptools()
+ import setuptools
+
+if sys.version_info < (3, 6):
+ raise RuntimeError('Python version 3.6+ is required.')
+
+setuptools_version = tuple(
+ int(x) for x in setuptools.__version__.split('.')[:2])
+
+additional_kwargs = {}
+if setuptools_version >= (24, 2):
+ # `python_requires` was added in 24.2, see
+ # https://packaging.python.org/guides/distributing-packages-using-setuptools/#python-requires
+ additional_kwargs['python_requires'] = '>=3.6'
+
+_README_PATH = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), 'README.md')
+with open(_README_PATH, 'rb') as fp:
+ LONG_DESCRIPTION = fp.read().decode('utf-8')
+
+setuptools.setup(
+ name='absl-py',
+ version='1.1.0',
+ description=(
+ 'Abseil Python Common Libraries, '
+ 'see https://github.com/abseil/abseil-py.'),
+ long_description=LONG_DESCRIPTION,
+ long_description_content_type='text/markdown',
+ author='The Abseil Authors',
+ url='https://github.com/abseil/abseil-py',
+ packages=setuptools.find_packages(exclude=[
+ '*.tests', '*.tests.*', 'tests.*', 'tests',
+ ]),
+ include_package_data=True,
+ license='Apache 2.0',
+ classifiers=[
+ 'Programming Language :: Python',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ 'Programming Language :: Python :: 3.9',
+ 'Programming Language :: Python :: 3.10',
+ 'Intended Audience :: Developers',
+ 'Topic :: Software Development :: Libraries :: Python Modules',
+ 'License :: OSI Approved :: Apache Software License',
+ 'Operating System :: OS Independent',
+ ],
+ **additional_kwargs,
+)
diff --git a/smoke_tests/sample_app.py b/smoke_tests/sample_app.py
new file mode 100644
index 0000000..532a11e
--- /dev/null
+++ b/smoke_tests/sample_app.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test helper for smoke_test.sh."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from absl import app
+from absl import flags
+from absl import logging
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('echo', None, 'Text to echo.')
+
+
+def main(argv):
+ del argv # Unused.
+
+ print('Running under Python {0[0]}.{0[1]}.{0[2]}'.format(sys.version_info),
+ file=sys.stderr)
+ logging.info('echo is %s.', FLAGS.echo)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/smoke_tests/sample_test.py b/smoke_tests/sample_test.py
new file mode 100644
index 0000000..713677a
--- /dev/null
+++ b/smoke_tests/sample_test.py
@@ -0,0 +1,33 @@
+# Copyright 2018 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test helper for smoke_test.sh."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+
+
+class SampleTest(absltest.TestCase):
+
+ def test_subtest(self):
+ for i in (1, 2):
+ with self.subTest(i=i):
+ self.assertEqual(i, i)
+ print('msg_for_test')
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/smoke_tests/smoke_test.sh b/smoke_tests/smoke_test.sh
new file mode 100755
index 0000000..99307a4
--- /dev/null
+++ b/smoke_tests/smoke_test.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Smoke test to verify setup.py works as expected.
+# Note on Windows, this must run via msys.
+
+# Fail on any error. Treat unset variables an error. Print commands as executed.
+set -eux
+
+if [[ "$#" -ne "2" ]]; then
+ echo 'Must specify the Python interpreter and virtualenv path.'
+ echo 'Usage:'
+ echo ' smoke_tests/smoke_test.sh [Python interpreter path] [virtualenv Path]'
+ exit 1
+fi
+
+ABSL_PYTHON=$1
+ABSL_VIRTUALENV=$2
+TMP_DIR=$(mktemp -d)
+trap "{ rm -rf ${TMP_DIR}; }" EXIT
+# Do not bootstrap pip/setuptools, they are manually installed with get-pip.py
+# inside the virtualenv.
+if ${ABSL_VIRTUALENV} --help | grep '\--no-site-packages'; then
+ no_site_packages_flag="--no-site-packages"
+else
+ # --no-site-packages becomes the default in version 20 and is no longer a
+ # flag.
+ no_site_packages_flag=""
+fi
+${ABSL_VIRTUALENV} ${no_site_packages_flag} --no-pip --no-setuptools --no-wheel \
+ -p ${ABSL_PYTHON} ${TMP_DIR}
+
+# Temporarily disable unbound variable errors to activate virtualenv.
+set +u
+if [[ $(uname -s) == MSYS* ]]; then
+ source ${TMP_DIR}/Scripts/activate
+else
+ source ${TMP_DIR}/bin/activate
+fi
+set -u
+
+trap 'deactivate' EXIT
+
+# When running macOS <= 10.12, pip 9.0.3 is required to connect to PyPI.
+# So we need to manually use the latest pip to install absl-py. See:
+# https://mail.python.org/pipermail/distutils-sig/2018-April/032114.html
+if [[ "$(python -c "import sys; print(sys.version_info.major, sys.version_info.minor)")" == "3 6" ]]; then
+ # Latest get-pip.py no longer supports Python 3.6.
+ curl https://bootstrap.pypa.io/pip/3.6/get-pip.py | python
+else
+ curl https://bootstrap.pypa.io/get-pip.py | python
+fi
+pip --version
+
+python --version
+python setup.py install
+python smoke_tests/sample_app.py --echo smoke 2>&1 |grep 'echo is smoke.'
+python smoke_tests/sample_test.py 2>&1 | grep 'msg_for_test'