diff options
author | Yifan Hong <elsk@google.com> | 2022-06-16 17:11:11 -0700 |
---|---|---|
committer | Yifan Hong <elsk@google.com> | 2022-06-16 17:13:28 -0700 |
commit | 41ad18dfefd9f807e798ffee08323c84eaa8a9e7 (patch) | |
tree | faf738843747b43d4bc5f9b46ec8ea33436a62ad | |
parent | b6bca275ddb60f54428b9f462c8eec6c714f5543 (diff) | |
parent | 58ead8c22230a2493006fa0ab9f76776b6e7280f (diff) | |
download | absl-py-41ad18dfefd9f807e798ffee08323c84eaa8a9e7.tar.gz |
Merge tag 'v1.1.0' into master
Test: none
Bug: 227705464
Bug: 236118730
Change-Id: I28557fa16390caefaac5d21bbc5c0881ad511797
Signed-off-by: Yifan Hong <elsk@google.com>
89 files changed, 30385 insertions, 0 deletions
@@ -0,0 +1,7 @@ +# This is the list of Abseil authors for copyright purposes. +# +# This does not necessarily list everyone who has contributed code, since in +# some cases, their employer may be the copyright holder. To see the full list +# of contributors, see the revision history in source control. + +Google Inc. diff --git a/BUILD.bazel b/BUILD.bazel new file mode 100644 index 0000000..d72b220 --- /dev/null +++ b/BUILD.bazel @@ -0,0 +1,21 @@ +# Copyright 2021 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +exports_files([ + "LICENSE", +]) diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..3c36751 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,302 @@ +# Python Absl Changelog + +All notable changes to Python Absl are recorded here. + +The format is based on [Keep a Changelog](https://keepachangelog.com). + +## Unreleased + +Nothing notable unreleased. + +## 1.1.0 (2022-06-01) + +* `Flag` instances now raise an error if used in a bool context. This prevents + the occasional mistake of testing an instance for truthiness rather than + testing `flag.value`. +* `absl-py` no longer depends on `six`. + +## 1.0.0 (2021-11-09) + +### Changed + +* `absl-py` no longer supports Python 2.7, 3.4, 3.5. All versions have reached + end-of-life for more than a year now. +* New releases will be tagged as `vX.Y.Z` instead of `pypi-vX.Y.Z` in the git + repo going forward. + +## 0.15.0 (2021-10-19) + +### Changed + +* (testing) #128: When running bazel with its `--test_filter=` flag, it now + treats the filters as `unittest`'s `-k` flag in Python 3.7+. + +## 0.14.1 (2021-09-30) + +### Fixed + +* Top-level `LICENSE` file is now exported in bazel. + +## 0.14.0 (2021-09-21) + +### Fixed + +* #171: Creating `argparse_flags.ArgumentParser` with `argument_default=` no + longer raises an exception when other `absl.flags` flags are defined. +* #173: `absltest` now correctly sets up test filtering and fail fast flags + when an explicit `argv=` parameter is passed to `absltest.main`. + +## 0.13.0 (2021-06-14) + +### Added + +* (app) Type annotations for public `app` interfaces. +* (testing) Added new decorator `@absltest.skipThisClass` to indicate a class + contains shared functionality to be used as a base class for other + TestCases, and therefore should be skipped. + +### Changed + +* (app) Annotated the `flag_parser` paramteter of `run` as keyword-only. This + keyword-only constraint will be enforced at runtime in a future release. +* (app, flags) Flag validations now include all errors from disjoint flag + sets, instead of fail fast upon first error from all validators. Multiple + validators on the same flag still fails fast. + +## 0.12.0 (2021-03-08) + +### Added + +* (flags) Made `EnumClassSerializer` and `EnumClassListSerializer` public. +* (flags) Added a `required: Optional[bool] = False` parameter to `DEFINE_*` + functions. +* (testing) flagsaver overrides can now be specified in terms of FlagHolder. +* (testing) `parameterized.product`: Allows testing a method over cartesian + product of parameters values, specified as a sequences of values for each + parameter or as kwargs-like dicts of parameter values. +* (testing) Added public flag holders for `--test_srcdir` and `--test_tmpdir`. + Users should use `absltest.TEST_SRCDIR.value` and + `absltest.TEST_TMPDIR.value` instead of `FLAGS.test_srcdir` and + `FLAGS.test_tmpdir`. + +### Fixed + +* (flags) Made `CsvListSerializer` respect its delimiter argument. + +## 0.11.0 (2020-10-27) + +### Changed + +* (testing) Surplus entries in AssertionError stack traces from absltest are + now suppressed and no longer reported in the xml_reporter. +* (logging) An exception is now raised instead of `logging.fatal` when logging + directories cannot be found. +* (testing) Multiple flags are now set together before their validators run. + This resolves an issue where multi-flag validators rely on specific flag + combinations. +* (flags) As a deterrent for misuse, FlagHolder objects will now raise a + TypeError exception when used in a conditional statement or equality + expression. + +## 0.10.0 (2020-08-19) + +### Added + +* (testing) `_TempDir` and `_TempFile` now implement `__fspath__` to satisfy + `os.PathLike` +* (logging) `--logger_levels`: allows specifying the log levels of loggers. +* (flags) `FLAGS.validate_all_flags`: a new method that validates all flags + and raises an exception if one fails. +* (flags) `FLAGS.get_flags_for_module`: Allows fetching the flags a module + defines. +* (testing) `parameterized.TestCase`: Supports async test definitions. +* (testing,app) Added `--pdb` flag: When true, uncaught exceptions will be + handled by `pdb.post_mortem`. This is an alias for `--pdb_post_mortem`. + +### Changed + +* (testing) Failed tests output a copy/pastable test id to make it easier to + copy the failing test to the command line. +* (testing) `@parameterized.parameters` now treats a single `abc.Mapping` as a + single test case, consistent with `named_parameters`. Previously the + `abc.Mapping` is treated as if only its keys are passed as a list of test + cases. If you were relying on the old inconsistent behavior, explicitly + convert the `abc.Mapping` to a `list`. +* (flags) `DEFINE_enum_class` and `DEFINE_mutlti_enum_class` accept a + `case_sensitive` argument. When `False` (the default), strings are mapped to + enum member names without case sensitivity, and member names are serialized + in lowercase form. Flag definitions for enums whose members include + duplicates when case is ignored must now explicitly pass + `case_sensitive=True`. + +### Fixed + +* (flags) Defining an alias no longer marks the aliased flag as always present + on the command line. +* (flags) Aliasing a multi flag no longer causes the default value to be + appended to. +* (flags) Alias default values now matched the aliased default value. +* (flags) Alias `present` counter now correctly reflects command line usage. + +## 0.9.0 (2019-12-17) + +### Added + +* (testing) `TestCase.enter_context`: Allows using context managers in setUp + and having them automatically exited when a test finishes. + +### Fixed + +* #126: calling `logging.debug(msg, stack_info=...)` no longer throws an + exception in Python 3.8. + +## 0.8.1 (2019-10-08) + +### Fixed + +* (testing) `absl.testing`'s pretty print reporter no longer buffers + RUN/OK/FAILED messages. +* (testing) `create_tempfile` will overwrite pre-existing read-only files. + +## 0.8.0 (2019-08-26) + +### Added + +* (testing) `absltest.expectedFailureIf`: a variant of + `unittest.expectedFailure` that allows a condition to be given. + +### Changed + +* (bazel) Tests now pass when bazel + `--incompatible_allow_python_version_transitions=true` is set. +* (bazel) Both Python 2 and Python 3 versions of tests are now created. To + only run one major Python version, use `bazel test + --test_tag_filters=-python[23]` to ignore the other version. +* (testing) `assertTotallyOrdered` no longer requires objects to implement + `__hash__`. +* (testing) `absltest` now integrates better with `--pdb_post_mortem`. +* (testing) `xml_reporter` now includes timestamps to testcases, test_suite, + test_suites elements. + +### Fixed + +* #99: `absl.logging` no longer registers itself to `logging.root` at import + time. +* #108: Tests now pass with Bazel 0.28.0 on macOS. + +## 0.7.1 (2019-03-12) + +### Added + +* (flags) `flags.mark_bool_flags_as_mutual_exclusive`: convenience function to + check that only one, or at most one, flag among a set of boolean flags are + True. + +### Changed + +* (bazel) Bazel 0.23+ or 0.22+ is now required for building/testing. + Specifically, a Bazel version that supports + `@bazel_tools//tools/python:python_version` for selecting the Python + version. + +### Fixed + +* #94: LICENSE files are now included in sdist. +* #93: Change log added. + +## 0.7.0 (2019-01-11) + +### Added + +* (bazel) testonly=1 has been removed from the testing libraries, which allows + their use outside of testing contexts. +* (flags) Multi-flags now accept any Iterable type for the default value + instead of only lists. Strings are still special cased as before. This + allows sets, generators, views, etc to be used naturally. +* (flags) DEFINE_multi_enum_class: a multi flag variant of enum_class. +* (testing) Most of absltest is now type-annotated. +* (testing) Made AbslTest.assertRegex available under Python 2. This allows + Python 2 code to write more natural Python 3 compatible code. (Note: this + was actually released in 0.6.1, but unannounced) +* (logging) logging.vlog_is_on: helper to tell if a vlog() call will actually + log anything. This allows avoiding computing expansive inputs to a logging + call when logging isn't enabled for that level. + +### Fixed + +* (flags) Pickling flags now raises an clear error instead of a cryptic one. + Pickling flags isn't supported; instead use flags_into_string to serialize + flags. +* (flags) Flags serialization works better: the resulting serialized value, + when deserialized, won't cause --help to be invoked, thus ending the + process. +* (flags) Several flag fixes to make them behave more like the Absl C++ flags: + empty --flagfile is allowed; --nohelp and --help=false don't display help +* (flags) An empty --flagfile value (e.g. "--flagfile=" or "--flagfile=''" + doesn't raise an error; its not just ignored. This matches Abseil C++ + behavior. +* (bazel) Building with Bazel 0.2.0 works without extra incompatibility disable + build flags. + +### Changed + +* (flags) Flag serialization is now deterministic: this improves Bazel build + caching for tools that are affected by flag serialization. + +## 0.6.0 (2018-10-22) + +### Added + +* Tempfile management APIs for tests: read/write/manage tempfiles for test + purposes easily and correctly. See TestCase.create_temp{file/dir} and the + corresponding commit for more info. + +## 0.5.0 (2018-09-17) + +### Added + +* Flags enum support: flags.DEFINE_enum_class allows using an `Enum` derived + class to define the allowed values for a flag. + +## 0.4.1 (2018-08-28) + +### Fixed + +* Flags no long allow spaces in their names + +### Changed + +* XML test output is written at the end of all test execution. +* If the current user's username can't be gotten, fallback to uid, else fall + back to a generic 'unknown' string. + +## 0.4.0 (2018-08-14) + +### Added + +* argparse integration: absl-registered flags can now be accessed via argparse + using absl.flags.argparse_flags: see that module for more information. +* TestCase.assertSameStructure now allows mixed set types. + +### Changed + +* Test output now includes start/end markers for each test ran. This is to + help distinguish output from tests clearly. + +## 0.3.0 (2018-07-25) + +### Added + +* `app.call_after_init`: Register functions to be called after app.run() is + called. Useful for program-wide initialization that library code may need. +* `logging.log_every_n_seconds`: like log_every_n, but based on elapsed time + between logging calls. +* `absltest.mock`: alias to unittest.mock (PY3) for better unittest drop-in + replacement. For PY2, it will be available if mock is importable. + +### Fixed + +* `ABSLLogger.findCaller()`: allow stack_info arg and return value for PY2 +* Make stopTest locking reentrant: this prevents deadlocks for test frameworks + that customize unittest.TextTestResult.stopTest. +* Make --helpfull work with unicode flag help strings. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..5134aff --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,69 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +NOTE: If you are new to GitHub, please start by reading the [Pull Request +howto](https://help.github.com/articles/about-pull-requests/). + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution, +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to <https://cla.developers.google.com/> to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Coding Style + +To keep the source consistent, readable, diffable and easy to merge, we use a +fairly rigid coding style, as defined by the +[google-styleguide](https://github.com/google/styleguide) project. All patches +will be expected to conform to the Python style outlined +[here](https://google.github.io/styleguide/pyguide.html). + +## Guidelines for Pull Requests + +* Create **small PRs** that are narrowly focused on **addressing a single + concern**. We often receive PRs that are trying to fix several things at a + time, but if only one fix is considered acceptable, nothing gets merged and + both author's & review's time is wasted. Create more PRs to address + different concerns and everyone will be happy. + +* For speculative changes, consider opening an + [issue](https://github.com/abseil/abseil-py/issues) and discussing it first. + +* Provide a good **PR description** as a record of **what** change is being + made and **why** it was made. Link to a GitHub issue if it exists. + +* Don't fix code style and formatting unless you are already changing that + line to address an issue. PRs with irrelevant changes won't be merged. If + you do want to fix formatting or style, do that in a separate PR. + +* Unless your PR is trivial, you should expect there will be reviewer comments + that you'll need to address before merging. We expect you to be reasonably + responsive to those comments, otherwise the PR will be closed after 2-3 + weeks of inactivity. + +* Maintain **clean commit history** and use **meaningful commit messages**. + PRs with messy commit history are difficult to review and won't be merged. + Use `rebase -i upstream/main` to curate your commit history and/or to + bring in latest changes from main (but avoid rebasing in the middle of a + code review). + +* Keep your PR up to date with upstream/main (if there are merge conflicts, + we can't really merge your change). + +* **All tests need to be passing** before your change can be merged. We + recommend you **run tests locally** (see + [Running Tests](README.md#running-tests)). + +* Exceptions to the rules can be made if there's a compelling reason for doing + so. That is - the rules are here to serve us, not the other way around, and + the rules need to be serving their intended purpose to be valuable. + +* All submissions, including submissions by project members, require review. @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..1aba38f --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include LICENSE diff --git a/README.md b/README.md new file mode 100644 index 0000000..5ab2365 --- /dev/null +++ b/README.md @@ -0,0 +1,60 @@ +# Abseil Python Common Libraries + +This repository is a collection of Python library code for building Python +applications. The code is collected from Google's own Python code base, and has +been extensively tested and used in production. + +## Features + +* Simple application startup +* Distributed commandline flags system +* Custom logging module with additional features +* Testing utilities + +## Getting Started + +### Installation + +To install the package, simply run: + +```bash +pip install absl-py +``` + +Or install from source: + +```bash +python setup.py install +``` + +### Running Tests + +To run Abseil tests, you can clone the git repo and run +[bazel](https://bazel.build/): + +```bash +git clone https://github.com/abseil/abseil-py.git +cd abseil-py +bazel test absl/... +``` + +### Example Code + +Please refer to +[smoke_tests/sample_app.py](https://github.com/abseil/abseil-py/blob/main/smoke_tests/sample_app.py) +as an example to get started. + +## Documentation + +See the [Abseil Python Developer Guide](https://abseil.io/docs/python/). + +## Future Releases + +The current repository includes an initial set of libraries for early adoption. +More components and interoperability with Abseil C++ Common Libraries +will come in future releases. + +## License + +The Abseil Python library is licensed under the terms of the Apache +license. See [LICENSE](LICENSE) for more information. diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000..a964e21 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,14 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +workspace(name = "io_abseil_py") diff --git a/absl/BUILD b/absl/BUILD new file mode 100644 index 0000000..4e747ea --- /dev/null +++ b/absl/BUILD @@ -0,0 +1,84 @@ +licenses(["notice"]) + +py_library( + name = "app", + srcs = [ + "app.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":command_name", + "//absl/flags", + "//absl/logging", + ], +) + +py_library( + name = "command_name", + srcs = ["command_name.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) + +py_library( + name = "tests/app_test_helper", + testonly = 1, + srcs = ["tests/app_test_helper.py"], + srcs_version = "PY2AND3", + deps = [ + ":app", + "//absl/flags", + ], +) + +py_binary( + name = "tests/app_test_helper_pure_python", + testonly = 1, + srcs = ["tests/app_test_helper.py"], + main = "tests/app_test_helper.py", + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":app", + "//absl/flags", + ], +) + +py_test( + name = "tests/app_test", + srcs = ["tests/app_test.py"], + data = [":tests/app_test_helper_pure_python"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":app", + ":tests/app_test_helper", + "//absl/flags", + "//absl/testing:_bazelize_command", + "//absl/testing:absltest", + "//absl/testing:flagsaver", + ], +) + +py_test( + name = "tests/command_name_test", + srcs = ["tests/command_name_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":command_name", + "//absl/testing:absltest", + ], +) + +py_test( + name = "tests/python_version_test", + srcs = ["tests/python_version_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + "//absl/flags", + "//absl/testing:absltest", + ], +) diff --git a/absl/__init__.py b/absl/__init__.py new file mode 100644 index 0000000..a3bd1cd --- /dev/null +++ b/absl/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/absl/app.py b/absl/app.py new file mode 100644 index 0000000..037f75c --- /dev/null +++ b/absl/app.py @@ -0,0 +1,484 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic entry point for Abseil Python applications. + +To use this module, define a 'main' function with a single 'argv' argument and +call app.run(main). For example: + + def main(argv): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + if __name__ == '__main__': + app.run(main) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import errno +import os +import pdb +import sys +import textwrap +import traceback + +from absl import command_name +from absl import flags +from absl import logging + +try: + import faulthandler +except ImportError: + faulthandler = None + +FLAGS = flags.FLAGS + +flags.DEFINE_boolean('run_with_pdb', False, 'Set to true for PDB debug mode') +flags.DEFINE_boolean('pdb_post_mortem', False, + 'Set to true to handle uncaught exceptions with PDB ' + 'post mortem.') +flags.DEFINE_alias('pdb', 'pdb_post_mortem') +flags.DEFINE_boolean('run_with_profiling', False, + 'Set to true for profiling the script. ' + 'Execution will be slower, and the output format might ' + 'change over time.') +flags.DEFINE_string('profile_file', None, + 'Dump profile information to a file (for python -m ' + 'pstats). Implies --run_with_profiling.') +flags.DEFINE_boolean('use_cprofile_for_profiling', True, + 'Use cProfile instead of the profile module for ' + 'profiling. This has no effect unless ' + '--run_with_profiling is set.') +flags.DEFINE_boolean('only_check_args', False, + 'Set to true to validate args and exit.', + allow_hide_cpp=True) + + +# If main() exits via an abnormal exception, call into these +# handlers before exiting. +EXCEPTION_HANDLERS = [] + + +class Error(Exception): + pass + + +class UsageError(Error): + """Exception raised when the arguments supplied by the user are invalid. + + Raise this when the arguments supplied are invalid from the point of + view of the application. For example when two mutually exclusive + flags have been supplied or when there are not enough non-flag + arguments. It is distinct from flags.Error which covers the lower + level of parsing and validating individual flags. + """ + + def __init__(self, message, exitcode=1): + super(UsageError, self).__init__(message) + self.exitcode = exitcode + + +class HelpFlag(flags.BooleanFlag): + """Special boolean flag that displays usage and raises SystemExit.""" + NAME = 'help' + SHORT_NAME = '?' + + def __init__(self): + super(HelpFlag, self).__init__( + self.NAME, False, 'show this help', + short_name=self.SHORT_NAME, allow_hide_cpp=True) + + def parse(self, arg): + if self._parse(arg): + usage(shorthelp=True, writeto_stdout=True) + # Advertise --helpfull on stdout, since usage() was on stdout. + print() + print('Try --helpfull to get a list of all flags.') + sys.exit(1) + + +class HelpshortFlag(HelpFlag): + """--helpshort is an alias for --help.""" + NAME = 'helpshort' + SHORT_NAME = None + + +class HelpfullFlag(flags.BooleanFlag): + """Display help for flags in the main module and all dependent modules.""" + + def __init__(self): + super(HelpfullFlag, self).__init__( + 'helpfull', False, 'show full help', allow_hide_cpp=True) + + def parse(self, arg): + if self._parse(arg): + usage(writeto_stdout=True) + sys.exit(1) + + +class HelpXMLFlag(flags.BooleanFlag): + """Similar to HelpfullFlag, but generates output in XML format.""" + + def __init__(self): + super(HelpXMLFlag, self).__init__( + 'helpxml', False, 'like --helpfull, but generates XML output', + allow_hide_cpp=True) + + def parse(self, arg): + if self._parse(arg): + flags.FLAGS.write_help_in_xml_format(sys.stdout) + sys.exit(1) + + +def parse_flags_with_usage(args): + """Tries to parse the flags, print usage, and exit if unparsable. + + Args: + args: [str], a non-empty list of the command line arguments including + program name. + + Returns: + [str], a non-empty list of remaining command line arguments after parsing + flags, including program name. + """ + try: + return FLAGS(args) + except flags.Error as error: + message = str(error) + if '\n' in message: + final_message = 'FATAL Flags parsing error:\n%s\n' % textwrap.indent( + message, ' ') + else: + final_message = 'FATAL Flags parsing error: %s\n' % message + sys.stderr.write(final_message) + sys.stderr.write('Pass --helpshort or --helpfull to see help on flags.\n') + sys.exit(1) + + +_define_help_flags_called = False + + +def define_help_flags(): + """Registers help flags. Idempotent.""" + # Use a global to ensure idempotence. + global _define_help_flags_called + + if not _define_help_flags_called: + flags.DEFINE_flag(HelpFlag()) + flags.DEFINE_flag(HelpshortFlag()) # alias for --help + flags.DEFINE_flag(HelpfullFlag()) + flags.DEFINE_flag(HelpXMLFlag()) + _define_help_flags_called = True + + +def _register_and_parse_flags_with_usage( + argv=None, + flags_parser=parse_flags_with_usage, +): + """Registers help flags, parses arguments and shows usage if appropriate. + + This also calls sys.exit(0) if flag --only_check_args is True. + + Args: + argv: [str], a non-empty list of the command line arguments including + program name, sys.argv is used if None. + flags_parser: Callable[[List[Text]], Any], the function used to parse flags. + The return value of this function is passed to `main` untouched. + It must guarantee FLAGS is parsed after this function is called. + + Returns: + The return value of `flags_parser`. When using the default `flags_parser`, + it returns the following: + [str], a non-empty list of remaining command line arguments after parsing + flags, including program name. + + Raises: + Error: Raised when flags_parser is called, but FLAGS is not parsed. + SystemError: Raised when it's called more than once. + """ + if _register_and_parse_flags_with_usage.done: + raise SystemError('Flag registration can be done only once.') + + define_help_flags() + + original_argv = sys.argv if argv is None else argv + args_to_main = flags_parser(original_argv) + if not FLAGS.is_parsed(): + raise Error('FLAGS must be parsed after flags_parser is called.') + + # Exit when told so. + if FLAGS.only_check_args: + sys.exit(0) + # Immediately after flags are parsed, bump verbosity to INFO if the flag has + # not been set. + if FLAGS['verbosity'].using_default_value: + FLAGS.verbosity = 0 + _register_and_parse_flags_with_usage.done = True + + return args_to_main + +_register_and_parse_flags_with_usage.done = False + + +def _run_main(main, argv): + """Calls main, optionally with pdb or profiler.""" + if FLAGS.run_with_pdb: + sys.exit(pdb.runcall(main, argv)) + elif FLAGS.run_with_profiling or FLAGS.profile_file: + # Avoid import overhead since most apps (including performance-sensitive + # ones) won't be run with profiling. + import atexit + if FLAGS.use_cprofile_for_profiling: + import cProfile as profile + else: + import profile + profiler = profile.Profile() + if FLAGS.profile_file: + atexit.register(profiler.dump_stats, FLAGS.profile_file) + else: + atexit.register(profiler.print_stats) + retval = profiler.runcall(main, argv) + sys.exit(retval) + else: + sys.exit(main(argv)) + + +def _call_exception_handlers(exception): + """Calls any installed exception handlers.""" + for handler in EXCEPTION_HANDLERS: + try: + if handler.wants(exception): + handler.handle(exception) + except: # pylint: disable=bare-except + try: + # We don't want to stop for exceptions in the exception handlers but + # we shouldn't hide them either. + logging.error(traceback.format_exc()) + except: # pylint: disable=bare-except + # In case even the logging statement fails, ignore. + pass + + +def run( + main, + argv=None, + flags_parser=parse_flags_with_usage, +): + """Begins executing the program. + + Args: + main: The main function to execute. It takes an single argument "argv", + which is a list of command line arguments with parsed flags removed. + The return value is passed to `sys.exit`, and so for example + a return value of 0 or None results in a successful termination, whereas + a return value of 1 results in abnormal termination. + For more details, see https://docs.python.org/3/library/sys#sys.exit + argv: A non-empty list of the command line arguments including program name, + sys.argv is used if None. + flags_parser: Callable[[List[Text]], Any], the function used to parse flags. + The return value of this function is passed to `main` untouched. + It must guarantee FLAGS is parsed after this function is called. + Should be passed as a keyword-only arg which will become mandatory in a + future release. + - Parses command line flags with the flag module. + - If there are any errors, prints usage(). + - Calls main() with the remaining arguments. + - If main() raises a UsageError, prints usage and the error message. + """ + try: + args = _run_init( + sys.argv if argv is None else argv, + flags_parser, + ) + while _init_callbacks: + callback = _init_callbacks.popleft() + callback() + try: + _run_main(main, args) + except UsageError as error: + usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode) + except: + exc = sys.exc_info()[1] + # Don't try to post-mortem debug successful SystemExits, since those + # mean there wasn't actually an error. In particular, the test framework + # raises SystemExit(False) even if all tests passed. + if isinstance(exc, SystemExit) and not exc.code: + raise + + # Check the tty so that we don't hang waiting for input in an + # non-interactive scenario. + if FLAGS.pdb_post_mortem and sys.stdout.isatty(): + traceback.print_exc() + print() + print(' *** Entering post-mortem debugging ***') + print() + pdb.post_mortem() + raise + except Exception as e: + _call_exception_handlers(e) + raise + +# Callbacks which have been deferred until after _run_init has been called. +_init_callbacks = collections.deque() + + +def call_after_init(callback): + """Calls the given callback only once ABSL has finished initialization. + + If ABSL has already finished initialization when `call_after_init` is + called then the callback is executed immediately, otherwise `callback` is + stored to be executed after `app.run` has finished initializing (aka. just + before the main function is called). + + If called after `app.run`, this is equivalent to calling `callback()` in the + caller thread. If called before `app.run`, callbacks are run sequentially (in + an undefined order) in the same thread as `app.run`. + + Args: + callback: a callable to be called once ABSL has finished initialization. + This may be immediate if initialization has already finished. It + takes no arguments and returns nothing. + """ + if _run_init.done: + callback() + else: + _init_callbacks.append(callback) + + +def _run_init( + argv, + flags_parser, +): + """Does one-time initialization and re-parses flags on rerun.""" + if _run_init.done: + return flags_parser(argv) + command_name.make_process_name_useful() + # Set up absl logging handler. + logging.use_absl_handler() + args = _register_and_parse_flags_with_usage( + argv=argv, + flags_parser=flags_parser, + ) + if faulthandler: + try: + faulthandler.enable() + except Exception: # pylint: disable=broad-except + # Some tests verify stderr output very closely, so don't print anything. + # Disabled faulthandler is a low-impact error. + pass + _run_init.done = True + return args + + +_run_init.done = False + + +def usage(shorthelp=False, writeto_stdout=False, detailed_error=None, + exitcode=None): + """Writes __main__'s docstring to stderr with some help text. + + Args: + shorthelp: bool, if True, prints only flags from the main module, + rather than all flags. + writeto_stdout: bool, if True, writes help message to stdout, + rather than to stderr. + detailed_error: str, additional detail about why usage info was presented. + exitcode: optional integer, if set, exits with this status code after + writing help. + """ + if writeto_stdout: + stdfile = sys.stdout + else: + stdfile = sys.stderr + + doc = sys.modules['__main__'].__doc__ + if not doc: + doc = '\nUSAGE: %s [flags]\n' % sys.argv[0] + doc = flags.text_wrap(doc, indent=' ', firstline_indent='') + else: + # Replace all '%s' with sys.argv[0], and all '%%' with '%'. + num_specifiers = doc.count('%') - 2 * doc.count('%%') + try: + doc %= (sys.argv[0],) * num_specifiers + except (OverflowError, TypeError, ValueError): + # Just display the docstring as-is. + pass + if shorthelp: + flag_str = FLAGS.main_module_help() + else: + flag_str = FLAGS.get_help() + try: + stdfile.write(doc) + if flag_str: + stdfile.write('\nflags:\n') + stdfile.write(flag_str) + stdfile.write('\n') + if detailed_error is not None: + stdfile.write('\n%s\n' % detailed_error) + except IOError as e: + # We avoid printing a huge backtrace if we get EPIPE, because + # "foo.par --help | less" is a frequent use case. + if e.errno != errno.EPIPE: + raise + if exitcode is not None: + sys.exit(exitcode) + + +class ExceptionHandler(object): + """Base exception handler from which other may inherit.""" + + def wants(self, exc): + """Returns whether this handler wants to handle the exception or not. + + This base class returns True for all exceptions by default. Override in + subclass if it wants to be more selective. + + Args: + exc: Exception, the current exception. + """ + del exc # Unused. + return True + + def handle(self, exc): + """Do something with the current exception. + + Args: + exc: Exception, the current exception + + This method must be overridden. + """ + raise NotImplementedError() + + +def install_exception_handler(handler): + """Installs an exception handler. + + Args: + handler: ExceptionHandler, the exception handler to install. + + Raises: + TypeError: Raised when the handler was not of the correct type. + + All installed exception handlers will be called if main() exits via + an abnormal exception, i.e. not one of SystemExit, KeyboardInterrupt, + FlagsError or UsageError. + """ + if not isinstance(handler, ExceptionHandler): + raise TypeError('handler of type %s does not inherit from ExceptionHandler' + % type(handler)) + EXCEPTION_HANDLERS.append(handler) diff --git a/absl/app.pyi b/absl/app.pyi new file mode 100644 index 0000000..fe5e448 --- /dev/null +++ b/absl/app.pyi @@ -0,0 +1,99 @@ + +from typing import Any, Callable, Collection, Iterable, List, NoReturn, Optional, Text, TypeVar, Union, overload + +from absl.flags import _flag + + +_MainArgs = TypeVar('_MainArgs') +_Exc = TypeVar('_Exc', bound=Exception) + + +class ExceptionHandler(): + + def wants(self, exc: _Exc) -> bool: + ... + + def handle(self, exc: _Exc): + ... + + +EXCEPTION_HANDLERS: List[ExceptionHandler] = ... + + +class HelpFlag(_flag.BooleanFlag): + def __init__(self): + ... + + +class HelpshortFlag(HelpFlag): + ... + + +class HelpfullFlag(_flag.BooleanFlag): + def __init__(self): + ... + + +class HelpXMLFlag(_flag.BooleanFlag): + def __init__(self): + ... + + +def define_help_flags() -> None: + ... + + +@overload +def usage(shorthelp: Union[bool, int] = ..., + writeto_stdout: Union[bool, int] = ..., + detailed_error: Optional[Any] = ..., + exitcode: None = ...) -> None: + ... + + +@overload +def usage(shorthelp: Union[bool, int] = ..., + writeto_stdout: Union[bool, int] = ..., + detailed_error: Optional[Any] = ..., + exitcode: int = ...) -> NoReturn: + ... + + +def install_exception_handler(handler: ExceptionHandler) -> None: + ... + + +class Error(Exception): + ... + + +class UsageError(Error): + exitcode: int + + +def parse_flags_with_usage(args: List[Text]) -> List[Text]: + ... + + +def call_after_init(callback: Callable[[], Any]) -> None: + ... + + +# Without the flag_parser argument, `main` should require a List[Text]. +@overload +def run( + main: Callable[[List[Text]], Any], + argv: Optional[List[Text]] = ..., + *, +) -> NoReturn: + ... + + +@overload +def run( + main: Callable[[_MainArgs], Any], + argv: Optional[List[Text]] = ..., + *, + flags_parser: Callable[[List[Text]], _MainArgs], +) -> NoReturn: + ... diff --git a/absl/command_name.py b/absl/command_name.py new file mode 100644 index 0000000..3bf9fad --- /dev/null +++ b/absl/command_name.py @@ -0,0 +1,67 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A tiny stand alone library to change the kernel process name on Linux.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +# This library must be kept small and stand alone. It is used by small things +# that require no extension modules. + + +def make_process_name_useful(): + """Sets the process name to something better than 'python' if possible.""" + set_kernel_process_name(os.path.basename(sys.argv[0])) + + +def set_kernel_process_name(name): + """Changes the Kernel's /proc/self/status process name on Linux. + + The kernel name is NOT what will be shown by the ps or top command. + It is a 15 character string stored in the kernel's process table that + is included in the kernel log when a process is OOM killed. + The first 15 bytes of name are used. Non-ASCII unicode is replaced with '?'. + + Does nothing if /proc/self/comm cannot be written or prctl() fails. + + Args: + name: bytes|unicode, the Linux kernel's command name to set. + """ + if not isinstance(name, bytes): + name = name.encode('ascii', 'replace') + try: + # This is preferred to using ctypes to try and call prctl() when possible. + with open('/proc/self/comm', 'wb') as proc_comm: + proc_comm.write(name[:15]) + except EnvironmentError: + try: + import ctypes + except ImportError: + return # No ctypes. + try: + libc = ctypes.CDLL('libc.so.6') + except EnvironmentError: + return # No libc.so.6. + pr_set_name = ctypes.c_ulong(15) # linux/prctl.h PR_SET_NAME value. + zero = ctypes.c_ulong(0) + try: + libc.prctl(pr_set_name, name, zero, zero, zero) + # Ignore the prctl return value. Nothing we can do if it errored. + except AttributeError: + return # No prctl. diff --git a/absl/flags/BUILD b/absl/flags/BUILD new file mode 100644 index 0000000..33a7b07 --- /dev/null +++ b/absl/flags/BUILD @@ -0,0 +1,314 @@ +licenses(["notice"]) + +py_library( + name = "flags", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":_argument_parser", + ":_defines", + ":_exceptions", + ":_flag", + ":_flagvalues", + ":_helpers", + ":_validators", + ], +) + +py_library( + name = "argparse_flags", + srcs = ["argparse_flags.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [":flags"], +) + +py_library( + name = "_argument_parser", + srcs = ["_argument_parser.py"], + srcs_version = "PY2AND3", + deps = [ + ":_helpers", + ], +) + +py_library( + name = "_defines", + srcs = ["_defines.py"], + srcs_version = "PY2AND3", + deps = [ + ":_argument_parser", + ":_exceptions", + ":_flag", + ":_flagvalues", + ":_helpers", + ":_validators", + ], +) + +py_library( + name = "_exceptions", + srcs = ["_exceptions.py"], + srcs_version = "PY2AND3", + deps = [ + ":_helpers", + ], +) + +py_library( + name = "_flag", + srcs = ["_flag.py"], + srcs_version = "PY2AND3", + deps = [ + ":_argument_parser", + ":_exceptions", + ":_helpers", + ], +) + +py_library( + name = "_flagvalues", + srcs = ["_flagvalues.py"], + srcs_version = "PY2AND3", + deps = [ + ":_exceptions", + ":_flag", + ":_helpers", + ":_validators_classes", + ], +) + +py_library( + name = "_helpers", + srcs = ["_helpers.py"], + srcs_version = "PY2AND3", +) + +py_library( + name = "_validators", + srcs = [ + "_validators.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":_exceptions", + ":_flagvalues", + ":_validators_classes", + ], +) + +py_library( + name = "_validators_classes", + srcs = [ + "_validators_classes.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":_exceptions", + ], +) + +py_test( + name = "tests/_argument_parser_test", + srcs = ["tests/_argument_parser_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_argument_parser", + "//absl/testing:absltest", + "//absl/testing:parameterized", + ], +) + +py_test( + name = "tests/_flag_test", + srcs = ["tests/_flag_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_argument_parser", + ":_exceptions", + ":_flag", + "//absl/testing:absltest", + "//absl/testing:parameterized", + ], +) + +py_test( + name = "tests/_flagvalues_test", + size = "small", + srcs = ["tests/_flagvalues_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_defines", + ":_exceptions", + ":_flagvalues", + ":_helpers", + ":_validators", + ":tests/module_foo", + "//absl/logging", + "//absl/testing:absltest", + "//absl/testing:parameterized", + ], +) + +py_test( + name = "tests/_helpers_test", + size = "small", + srcs = ["tests/_helpers_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_helpers", + ":tests/module_bar", + ":tests/module_foo", + "//absl/testing:absltest", + ], +) + +py_test( + name = "tests/_validators_test", + size = "small", + srcs = ["tests/_validators_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_defines", + ":_exceptions", + ":_flagvalues", + ":_validators", + "//absl/testing:absltest", + ], +) + +py_test( + name = "tests/argparse_flags_test", + size = "small", + srcs = ["tests/argparse_flags_test.py"], + data = [":tests/argparse_flags_test_helper"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":argparse_flags", + ":flags", + "//absl/logging", + "//absl/testing:_bazelize_command", + "//absl/testing:absltest", + "//absl/testing:parameterized", + ], +) + +py_binary( + name = "tests/argparse_flags_test_helper", + testonly = 1, + srcs = ["tests/argparse_flags_test_helper.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":argparse_flags", + ":flags", + "//absl:app", + ], +) + +py_test( + name = "tests/flags_formatting_test", + size = "small", + srcs = ["tests/flags_formatting_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_helpers", + ":flags", + "//absl/testing:absltest", + ], +) + +py_test( + name = "tests/flags_helpxml_test", + size = "small", + srcs = ["tests/flags_helpxml_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_helpers", + ":flags", + ":tests/module_bar", + "//absl/testing:absltest", + ], +) + +py_test( + name = "tests/flags_numeric_bounds_test", + size = "small", + srcs = ["tests/flags_numeric_bounds_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_validators", + ":flags", + "//absl/testing:absltest", + ], +) + +py_test( + name = "tests/flags_test", + size = "small", + srcs = ["tests/flags_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_exceptions", + ":_helpers", + ":flags", + ":tests/module_bar", + ":tests/module_baz", + ":tests/module_foo", + "//absl/testing:absltest", + ], +) + +py_test( + name = "tests/flags_unicode_literals_test", + size = "small", + srcs = ["tests/flags_unicode_literals_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":flags", + "//absl/testing:absltest", + ], +) + +py_library( + name = "tests/module_bar", + testonly = 1, + srcs = ["tests/module_bar.py"], + srcs_version = "PY2AND3", + deps = [ + ":_helpers", + ":flags", + ], +) + +py_library( + name = "tests/module_baz", + testonly = 1, + srcs = ["tests/module_baz.py"], + srcs_version = "PY2AND3", + deps = [":flags"], +) + +py_library( + name = "tests/module_foo", + testonly = 1, + srcs = ["tests/module_foo.py"], + srcs_version = "PY2AND3", + deps = [ + ":_helpers", + ":flags", + ":tests/module_bar", + ], +) diff --git a/absl/flags/__init__.py b/absl/flags/__init__.py new file mode 100644 index 0000000..e6014a6 --- /dev/null +++ b/absl/flags/__init__.py @@ -0,0 +1,145 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This package is used to define and parse command line flags. + +This package defines a *distributed* flag-definition policy: rather than +an application having to define all flags in or near main(), each Python +module defines flags that are useful to it. When one Python module +imports another, it gains access to the other's flags. (This is +implemented by having all modules share a common, global registry object +containing all the flag information.) + +Flags are defined through the use of one of the DEFINE_xxx functions. +The specific function used determines how the flag is parsed, checked, +and optionally type-converted, when it's seen on the command line. +""" + +import getopt +import os +import re +import sys +import types +import warnings + +from absl.flags import _argument_parser +from absl.flags import _defines +from absl.flags import _exceptions +from absl.flags import _flag +from absl.flags import _flagvalues +from absl.flags import _helpers +from absl.flags import _validators + +# Initialize the FLAGS_MODULE as early as possible. +# It's only used by adopt_module_key_flags to take SPECIAL_FLAGS into account. +_helpers.FLAGS_MODULE = sys.modules[__name__] + +# Add current module to disclaimed module ids. +_helpers.disclaim_module_ids.add(id(sys.modules[__name__])) + +# DEFINE functions. They are explained in more details in the module doc string. +# pylint: disable=invalid-name +DEFINE = _defines.DEFINE +DEFINE_flag = _defines.DEFINE_flag +DEFINE_string = _defines.DEFINE_string +DEFINE_boolean = _defines.DEFINE_boolean +DEFINE_bool = DEFINE_boolean # Match C++ API. +DEFINE_float = _defines.DEFINE_float +DEFINE_integer = _defines.DEFINE_integer +DEFINE_enum = _defines.DEFINE_enum +DEFINE_enum_class = _defines.DEFINE_enum_class +DEFINE_list = _defines.DEFINE_list +DEFINE_spaceseplist = _defines.DEFINE_spaceseplist +DEFINE_multi = _defines.DEFINE_multi +DEFINE_multi_string = _defines.DEFINE_multi_string +DEFINE_multi_integer = _defines.DEFINE_multi_integer +DEFINE_multi_float = _defines.DEFINE_multi_float +DEFINE_multi_enum = _defines.DEFINE_multi_enum +DEFINE_multi_enum_class = _defines.DEFINE_multi_enum_class +DEFINE_alias = _defines.DEFINE_alias +# pylint: enable=invalid-name + +# Flag validators. +register_validator = _validators.register_validator +validator = _validators.validator +register_multi_flags_validator = _validators.register_multi_flags_validator +multi_flags_validator = _validators.multi_flags_validator +mark_flag_as_required = _validators.mark_flag_as_required +mark_flags_as_required = _validators.mark_flags_as_required +mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive +mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive + +# Key flag related functions. +declare_key_flag = _defines.declare_key_flag +adopt_module_key_flags = _defines.adopt_module_key_flags +disclaim_key_flags = _defines.disclaim_key_flags + +# Module exceptions. +# pylint: disable=invalid-name +Error = _exceptions.Error +CantOpenFlagFileError = _exceptions.CantOpenFlagFileError +DuplicateFlagError = _exceptions.DuplicateFlagError +IllegalFlagValueError = _exceptions.IllegalFlagValueError +UnrecognizedFlagError = _exceptions.UnrecognizedFlagError +UnparsedFlagAccessError = _exceptions.UnparsedFlagAccessError +ValidationError = _exceptions.ValidationError +FlagNameConflictsWithMethodError = _exceptions.FlagNameConflictsWithMethodError + +# Public classes. +Flag = _flag.Flag +BooleanFlag = _flag.BooleanFlag +EnumFlag = _flag.EnumFlag +EnumClassFlag = _flag.EnumClassFlag +MultiFlag = _flag.MultiFlag +MultiEnumClassFlag = _flag.MultiEnumClassFlag +FlagHolder = _flagvalues.FlagHolder +FlagValues = _flagvalues.FlagValues +ArgumentParser = _argument_parser.ArgumentParser +BooleanParser = _argument_parser.BooleanParser +EnumParser = _argument_parser.EnumParser +EnumClassParser = _argument_parser.EnumClassParser +ArgumentSerializer = _argument_parser.ArgumentSerializer +FloatParser = _argument_parser.FloatParser +IntegerParser = _argument_parser.IntegerParser +BaseListParser = _argument_parser.BaseListParser +ListParser = _argument_parser.ListParser +ListSerializer = _argument_parser.ListSerializer +EnumClassListSerializer = _argument_parser.EnumClassListSerializer +CsvListSerializer = _argument_parser.CsvListSerializer +WhitespaceSeparatedListParser = _argument_parser.WhitespaceSeparatedListParser +EnumClassSerializer = _argument_parser.EnumClassSerializer +# pylint: enable=invalid-name + +# Helper functions. +get_help_width = _helpers.get_help_width +text_wrap = _helpers.text_wrap +flag_dict_to_args = _helpers.flag_dict_to_args +doc_to_help = _helpers.doc_to_help + +# Special flags. +_helpers.SPECIAL_FLAGS = FlagValues() + +DEFINE_string( + 'flagfile', '', + 'Insert flag definitions from the given file into the command line.', + _helpers.SPECIAL_FLAGS) # pytype: disable=wrong-arg-types + +DEFINE_string('undefok', '', + 'comma-separated list of flag names that it is okay to specify ' + 'on the command line even if the program does not define a flag ' + 'with that name. IMPORTANT: flags in this list that have ' + 'arguments MUST use the --flag=value format.', + _helpers.SPECIAL_FLAGS) # pytype: disable=wrong-arg-types + +# The global FlagValues instance. +FLAGS = _flagvalues.FLAGS diff --git a/absl/flags/__init__.pyi b/absl/flags/__init__.pyi new file mode 100644 index 0000000..4eee59e --- /dev/null +++ b/absl/flags/__init__.pyi @@ -0,0 +1,103 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.flags import _argument_parser +from absl.flags import _defines +from absl.flags import _exceptions +from absl.flags import _flag +from absl.flags import _flagvalues +from absl.flags import _helpers +from absl.flags import _validators + +# DEFINE functions. They are explained in more details in the module doc string. +# pylint: disable=invalid-name +DEFINE = _defines.DEFINE +DEFINE_flag = _defines.DEFINE_flag +DEFINE_string = _defines.DEFINE_string +DEFINE_boolean = _defines.DEFINE_boolean +DEFINE_bool = DEFINE_boolean # Match C++ API. +DEFINE_float = _defines.DEFINE_float +DEFINE_integer = _defines.DEFINE_integer +DEFINE_enum = _defines.DEFINE_enum +DEFINE_enum_class = _defines.DEFINE_enum_class +DEFINE_list = _defines.DEFINE_list +DEFINE_spaceseplist = _defines.DEFINE_spaceseplist +DEFINE_multi = _defines.DEFINE_multi +DEFINE_multi_string = _defines.DEFINE_multi_string +DEFINE_multi_integer = _defines.DEFINE_multi_integer +DEFINE_multi_float = _defines.DEFINE_multi_float +DEFINE_multi_enum = _defines.DEFINE_multi_enum +DEFINE_multi_enum_class = _defines.DEFINE_multi_enum_class +DEFINE_alias = _defines.DEFINE_alias +# pylint: enable=invalid-name + +# Flag validators. +register_validator = _validators.register_validator +validator = _validators.validator +register_multi_flags_validator = _validators.register_multi_flags_validator +multi_flags_validator = _validators.multi_flags_validator +mark_flag_as_required = _validators.mark_flag_as_required +mark_flags_as_required = _validators.mark_flags_as_required +mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive +mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive + +# Key flag related functions. +declare_key_flag = _defines.declare_key_flag +adopt_module_key_flags = _defines.adopt_module_key_flags +disclaim_key_flags = _defines.disclaim_key_flags + +# Module exceptions. +# pylint: disable=invalid-name +Error = _exceptions.Error +CantOpenFlagFileError = _exceptions.CantOpenFlagFileError +DuplicateFlagError = _exceptions.DuplicateFlagError +IllegalFlagValueError = _exceptions.IllegalFlagValueError +UnrecognizedFlagError = _exceptions.UnrecognizedFlagError +UnparsedFlagAccessError = _exceptions.UnparsedFlagAccessError +ValidationError = _exceptions.ValidationError +FlagNameConflictsWithMethodError = _exceptions.FlagNameConflictsWithMethodError + +# Public classes. +Flag = _flag.Flag +BooleanFlag = _flag.BooleanFlag +EnumFlag = _flag.EnumFlag +EnumClassFlag = _flag.EnumClassFlag +MultiFlag = _flag.MultiFlag +MultiEnumClassFlag = _flag.MultiEnumClassFlag +FlagHolder = _flagvalues.FlagHolder +FlagValues = _flagvalues.FlagValues +ArgumentParser = _argument_parser.ArgumentParser +BooleanParser = _argument_parser.BooleanParser +EnumParser = _argument_parser.EnumParser +EnumClassParser = _argument_parser.EnumClassParser +ArgumentSerializer = _argument_parser.ArgumentSerializer +FloatParser = _argument_parser.FloatParser +IntegerParser = _argument_parser.IntegerParser +BaseListParser = _argument_parser.BaseListParser +ListParser = _argument_parser.ListParser +ListSerializer = _argument_parser.ListSerializer +CsvListSerializer = _argument_parser.CsvListSerializer +WhitespaceSeparatedListParser = _argument_parser.WhitespaceSeparatedListParser +EnumClassSerializer = _argument_parser.EnumClassSerializer +# pylint: enable=invalid-name + +# Helper functions. +get_help_width = _helpers.get_help_width +text_wrap = _helpers.text_wrap +flag_dict_to_args = _helpers.flag_dict_to_args +doc_to_help = _helpers.doc_to_help + +# The global FlagValues instance. +FLAGS = _flagvalues.FLAGS + diff --git a/absl/flags/_argument_parser.py b/absl/flags/_argument_parser.py new file mode 100644 index 0000000..9c6c8c6 --- /dev/null +++ b/absl/flags/_argument_parser.py @@ -0,0 +1,629 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains base classes used to parse and convert arguments. + +Do NOT import this module directly. Import the flags package and use the +aliases defined at the package level instead. +""" + +import collections +import csv +import io +import string + +from absl.flags import _helpers + + +def _is_integer_type(instance): + """Returns True if instance is an integer, and not a bool.""" + return (isinstance(instance, int) and + not isinstance(instance, bool)) + + +class _ArgumentParserCache(type): + """Metaclass used to cache and share argument parsers among flags.""" + + _instances = {} + + def __call__(cls, *args, **kwargs): + """Returns an instance of the argument parser cls. + + This method overrides behavior of the __new__ methods in + all subclasses of ArgumentParser (inclusive). If an instance + for cls with the same set of arguments exists, this instance is + returned, otherwise a new instance is created. + + If any keyword arguments are defined, or the values in args + are not hashable, this method always returns a new instance of + cls. + + Args: + *args: Positional initializer arguments. + **kwargs: Initializer keyword arguments. + + Returns: + An instance of cls, shared or new. + """ + if kwargs: + return type.__call__(cls, *args, **kwargs) + else: + instances = cls._instances + key = (cls,) + tuple(args) + try: + return instances[key] + except KeyError: + # No cache entry for key exists, create a new one. + return instances.setdefault(key, type.__call__(cls, *args)) + except TypeError: + # An object in args cannot be hashed, always return + # a new instance. + return type.__call__(cls, *args) + + +# NOTE about Genericity and Metaclass of ArgumentParser. +# (1) In the .py source (this file) +# - is not declared as Generic +# - has _ArgumentParserCache as a metaclass +# (2) In the .pyi source (type stub) +# - is declared as Generic +# - doesn't have a metaclass +# The reason we need this is due to Generic having a different metaclass +# (for python versions <= 3.7) and a class can have only one metaclass. +# +# * Lack of metaclass in .pyi is not a deal breaker, since the metaclass +# doesn't affect any type information. Also type checkers can check the type +# parameters. +# * However, not declaring ArgumentParser as Generic in the source affects +# runtime annotation processing. In particular this means, subclasses should +# inherit from `ArgumentParser` and not `ArgumentParser[SomeType]`. +# The corresponding DEFINE_someType method (the public API) can be annotated +# to return FlagHolder[SomeType]. +class ArgumentParser(metaclass=_ArgumentParserCache): + """Base class used to parse and convert arguments. + + The parse() method checks to make sure that the string argument is a + legal value and convert it to a native type. If the value cannot be + converted, it should throw a 'ValueError' exception with a human + readable explanation of why the value is illegal. + + Subclasses should also define a syntactic_help string which may be + presented to the user to describe the form of the legal values. + + Argument parser classes must be stateless, since instances are cached + and shared between flags. Initializer arguments are allowed, but all + member variables must be derived from initializer arguments only. + """ + + syntactic_help = '' + + def parse(self, argument): + """Parses the string argument and returns the native value. + + By default it returns its argument unmodified. + + Args: + argument: string argument passed in the commandline. + + Raises: + ValueError: Raised when it fails to parse the argument. + TypeError: Raised when the argument has the wrong type. + + Returns: + The parsed value in native type. + """ + if not isinstance(argument, str): + raise TypeError('flag value must be a string, found "{}"'.format( + type(argument))) + return argument + + def flag_type(self): + """Returns a string representing the type of the flag.""" + return 'string' + + def _custom_xml_dom_elements(self, doc): + """Returns a list of minidom.Element to add additional flag information. + + Args: + doc: minidom.Document, the DOM document it should create nodes from. + """ + del doc # Unused. + return [] + + +class ArgumentSerializer(object): + """Base class for generating string representations of a flag value.""" + + def serialize(self, value): + """Returns a serialized string of the value.""" + return _helpers.str_or_unicode(value) + + +class NumericParser(ArgumentParser): + """Parser of numeric values. + + Parsed value may be bounded to a given upper and lower bound. + """ + + def is_outside_bounds(self, val): + """Returns whether the value is outside the bounds or not.""" + return ((self.lower_bound is not None and val < self.lower_bound) or + (self.upper_bound is not None and val > self.upper_bound)) + + def parse(self, argument): + """See base class.""" + val = self.convert(argument) + if self.is_outside_bounds(val): + raise ValueError('%s is not %s' % (val, self.syntactic_help)) + return val + + def _custom_xml_dom_elements(self, doc): + elements = [] + if self.lower_bound is not None: + elements.append(_helpers.create_xml_dom_element( + doc, 'lower_bound', self.lower_bound)) + if self.upper_bound is not None: + elements.append(_helpers.create_xml_dom_element( + doc, 'upper_bound', self.upper_bound)) + return elements + + def convert(self, argument): + """Returns the correct numeric value of argument. + + Subclass must implement this method, and raise TypeError if argument is not + string or has the right numeric type. + + Args: + argument: string argument passed in the commandline, or the numeric type. + + Raises: + TypeError: Raised when argument is not a string or the right numeric type. + ValueError: Raised when failed to convert argument to the numeric value. + """ + raise NotImplementedError + + +class FloatParser(NumericParser): + """Parser of floating point values. + + Parsed value may be bounded to a given upper and lower bound. + """ + number_article = 'a' + number_name = 'number' + syntactic_help = ' '.join((number_article, number_name)) + + def __init__(self, lower_bound=None, upper_bound=None): + super(FloatParser, self).__init__() + self.lower_bound = lower_bound + self.upper_bound = upper_bound + sh = self.syntactic_help + if lower_bound is not None and upper_bound is not None: + sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound)) + elif lower_bound == 0: + sh = 'a non-negative %s' % self.number_name + elif upper_bound == 0: + sh = 'a non-positive %s' % self.number_name + elif upper_bound is not None: + sh = '%s <= %s' % (self.number_name, upper_bound) + elif lower_bound is not None: + sh = '%s >= %s' % (self.number_name, lower_bound) + self.syntactic_help = sh + + def convert(self, argument): + """Returns the float value of argument.""" + if (_is_integer_type(argument) or isinstance(argument, float) or + isinstance(argument, str)): + return float(argument) + else: + raise TypeError( + 'Expect argument to be a string, int, or float, found {}'.format( + type(argument))) + + def flag_type(self): + """See base class.""" + return 'float' + + +class IntegerParser(NumericParser): + """Parser of an integer value. + + Parsed value may be bounded to a given upper and lower bound. + """ + number_article = 'an' + number_name = 'integer' + syntactic_help = ' '.join((number_article, number_name)) + + def __init__(self, lower_bound=None, upper_bound=None): + super(IntegerParser, self).__init__() + self.lower_bound = lower_bound + self.upper_bound = upper_bound + sh = self.syntactic_help + if lower_bound is not None and upper_bound is not None: + sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound)) + elif lower_bound == 1: + sh = 'a positive %s' % self.number_name + elif upper_bound == -1: + sh = 'a negative %s' % self.number_name + elif lower_bound == 0: + sh = 'a non-negative %s' % self.number_name + elif upper_bound == 0: + sh = 'a non-positive %s' % self.number_name + elif upper_bound is not None: + sh = '%s <= %s' % (self.number_name, upper_bound) + elif lower_bound is not None: + sh = '%s >= %s' % (self.number_name, lower_bound) + self.syntactic_help = sh + + def convert(self, argument): + """Returns the int value of argument.""" + if _is_integer_type(argument): + return argument + elif isinstance(argument, str): + base = 10 + if len(argument) > 2 and argument[0] == '0': + if argument[1] == 'o': + base = 8 + elif argument[1] == 'x': + base = 16 + return int(argument, base) + else: + raise TypeError('Expect argument to be a string or int, found {}'.format( + type(argument))) + + def flag_type(self): + """See base class.""" + return 'int' + + +class BooleanParser(ArgumentParser): + """Parser of boolean values.""" + + def parse(self, argument): + """See base class.""" + if isinstance(argument, str): + if argument.lower() in ('true', 't', '1'): + return True + elif argument.lower() in ('false', 'f', '0'): + return False + else: + raise ValueError('Non-boolean argument to boolean flag', argument) + elif isinstance(argument, int): + # Only allow bool or integer 0, 1. + # Note that float 1.0 == True, 0.0 == False. + bool_value = bool(argument) + if argument == bool_value: + return bool_value + else: + raise ValueError('Non-boolean argument to boolean flag', argument) + + raise TypeError('Non-boolean argument to boolean flag', argument) + + def flag_type(self): + """See base class.""" + return 'bool' + + +class EnumParser(ArgumentParser): + """Parser of a string enum value (a string value from a given set).""" + + def __init__(self, enum_values, case_sensitive=True): + """Initializes EnumParser. + + Args: + enum_values: [str], a non-empty list of string values in the enum. + case_sensitive: bool, whether or not the enum is to be case-sensitive. + + Raises: + ValueError: When enum_values is empty. + """ + if not enum_values: + raise ValueError( + 'enum_values cannot be empty, found "{}"'.format(enum_values)) + super(EnumParser, self).__init__() + self.enum_values = enum_values + self.case_sensitive = case_sensitive + + def parse(self, argument): + """Determines validity of argument and returns the correct element of enum. + + Args: + argument: str, the supplied flag value. + + Returns: + The first matching element from enum_values. + + Raises: + ValueError: Raised when argument didn't match anything in enum. + """ + if self.case_sensitive: + if argument not in self.enum_values: + raise ValueError('value should be one of <%s>' % + '|'.join(self.enum_values)) + else: + return argument + else: + if argument.upper() not in [value.upper() for value in self.enum_values]: + raise ValueError('value should be one of <%s>' % + '|'.join(self.enum_values)) + else: + return [value for value in self.enum_values + if value.upper() == argument.upper()][0] + + def flag_type(self): + """See base class.""" + return 'string enum' + + +class EnumClassParser(ArgumentParser): + """Parser of an Enum class member.""" + + def __init__(self, enum_class, case_sensitive=True): + """Initializes EnumParser. + + Args: + enum_class: class, the Enum class with all possible flag values. + case_sensitive: bool, whether or not the enum is to be case-sensitive. If + False, all member names must be unique when case is ignored. + + Raises: + TypeError: When enum_class is not a subclass of Enum. + ValueError: When enum_class is empty. + """ + # Users must have an Enum class defined before using EnumClass flag. + # Therefore this dependency is guaranteed. + import enum + + if not issubclass(enum_class, enum.Enum): + raise TypeError('{} is not a subclass of Enum.'.format(enum_class)) + if not enum_class.__members__: + raise ValueError('enum_class cannot be empty, but "{}" is empty.' + .format(enum_class)) + if not case_sensitive: + members = collections.Counter( + name.lower() for name in enum_class.__members__) + duplicate_keys = { + member for member, count in members.items() if count > 1 + } + if duplicate_keys: + raise ValueError( + 'Duplicate enum values for {} using case_sensitive=False'.format( + duplicate_keys)) + + super(EnumClassParser, self).__init__() + self.enum_class = enum_class + self._case_sensitive = case_sensitive + if case_sensitive: + self._member_names = tuple(enum_class.__members__) + else: + self._member_names = tuple( + name.lower() for name in enum_class.__members__) + + @property + def member_names(self): + """The accepted enum names, in lowercase if not case sensitive.""" + return self._member_names + + def parse(self, argument): + """Determines validity of argument and returns the correct element of enum. + + Args: + argument: str or Enum class member, the supplied flag value. + + Returns: + The first matching Enum class member in Enum class. + + Raises: + ValueError: Raised when argument didn't match anything in enum. + """ + if isinstance(argument, self.enum_class): + return argument + elif not isinstance(argument, str): + raise ValueError( + '{} is not an enum member or a name of a member in {}'.format( + argument, self.enum_class)) + key = EnumParser( + self._member_names, case_sensitive=self._case_sensitive).parse(argument) + if self._case_sensitive: + return self.enum_class[key] + else: + # If EnumParser.parse() return a value, we're guaranteed to find it + # as a member of the class + return next(value for name, value in self.enum_class.__members__.items() + if name.lower() == key.lower()) + + def flag_type(self): + """See base class.""" + return 'enum class' + + +class ListSerializer(ArgumentSerializer): + + def __init__(self, list_sep): + self.list_sep = list_sep + + def serialize(self, value): + """See base class.""" + return self.list_sep.join([_helpers.str_or_unicode(x) for x in value]) + + +class EnumClassListSerializer(ListSerializer): + """A serializer for MultiEnumClass flags. + + This serializer simply joins the output of `EnumClassSerializer` using a + provided separator. + """ + + def __init__(self, list_sep, **kwargs): + """Initializes EnumClassListSerializer. + + Args: + list_sep: String to be used as a separator when serializing + **kwargs: Keyword arguments to the `EnumClassSerializer` used to serialize + individual values. + """ + super(EnumClassListSerializer, self).__init__(list_sep) + self._element_serializer = EnumClassSerializer(**kwargs) + + def serialize(self, value): + """See base class.""" + if isinstance(value, list): + return self.list_sep.join( + self._element_serializer.serialize(x) for x in value) + else: + return self._element_serializer.serialize(value) + + +class CsvListSerializer(ArgumentSerializer): + + def __init__(self, list_sep): + self.list_sep = list_sep + + def serialize(self, value): + """Serializes a list as a CSV string or unicode.""" + output = io.StringIO() + writer = csv.writer(output, delimiter=self.list_sep) + writer.writerow([str(x) for x in value]) + serialized_value = output.getvalue().strip() + + # We need the returned value to be pure ascii or Unicodes so that + # when the xml help is generated they are usefully encodable. + return _helpers.str_or_unicode(serialized_value) + + +class EnumClassSerializer(ArgumentSerializer): + """Class for generating string representations of an enum class flag value.""" + + def __init__(self, lowercase): + """Initializes EnumClassSerializer. + + Args: + lowercase: If True, enum member names are lowercased during serialization. + """ + self._lowercase = lowercase + + def serialize(self, value): + """Returns a serialized string of the Enum class value.""" + as_string = _helpers.str_or_unicode(value.name) + return as_string.lower() if self._lowercase else as_string + + +class BaseListParser(ArgumentParser): + """Base class for a parser of lists of strings. + + To extend, inherit from this class; from the subclass __init__, call + + BaseListParser.__init__(self, token, name) + + where token is a character used to tokenize, and name is a description + of the separator. + """ + + def __init__(self, token=None, name=None): + assert name + super(BaseListParser, self).__init__() + self._token = token + self._name = name + self.syntactic_help = 'a %s separated list' % self._name + + def parse(self, argument): + """See base class.""" + if isinstance(argument, list): + return argument + elif not argument: + return [] + else: + return [s.strip() for s in argument.split(self._token)] + + def flag_type(self): + """See base class.""" + return '%s separated list of strings' % self._name + + +class ListParser(BaseListParser): + """Parser for a comma-separated list of strings.""" + + def __init__(self): + super(ListParser, self).__init__(',', 'comma') + + def parse(self, argument): + """Parses argument as comma-separated list of strings.""" + if isinstance(argument, list): + return argument + elif not argument: + return [] + else: + try: + return [s.strip() for s in list(csv.reader([argument], strict=True))[0]] + except csv.Error as e: + # Provide a helpful report for case like + # --listflag="$(printf 'hello,\nworld')" + # IOW, list flag values containing naked newlines. This error + # was previously "reported" by allowing csv.Error to + # propagate. + raise ValueError('Unable to parse the value %r as a %s: %s' + % (argument, self.flag_type(), e)) + + def _custom_xml_dom_elements(self, doc): + elements = super(ListParser, self)._custom_xml_dom_elements(doc) + elements.append(_helpers.create_xml_dom_element( + doc, 'list_separator', repr(','))) + return elements + + +class WhitespaceSeparatedListParser(BaseListParser): + """Parser for a whitespace-separated list of strings.""" + + def __init__(self, comma_compat=False): + """Initializer. + + Args: + comma_compat: bool, whether to support comma as an additional separator. + If False then only whitespace is supported. This is intended only for + backwards compatibility with flags that used to be comma-separated. + """ + self._comma_compat = comma_compat + name = 'whitespace or comma' if self._comma_compat else 'whitespace' + super(WhitespaceSeparatedListParser, self).__init__(None, name) + + def parse(self, argument): + """Parses argument as whitespace-separated list of strings. + + It also parses argument as comma-separated list of strings if requested. + + Args: + argument: string argument passed in the commandline. + + Returns: + [str], the parsed flag value. + """ + if isinstance(argument, list): + return argument + elif not argument: + return [] + else: + if self._comma_compat: + argument = argument.replace(',', ' ') + return argument.split() + + def _custom_xml_dom_elements(self, doc): + elements = super(WhitespaceSeparatedListParser, self + )._custom_xml_dom_elements(doc) + separators = list(string.whitespace) + if self._comma_compat: + separators.append(',') + separators.sort() + for sep_char in separators: + elements.append(_helpers.create_xml_dom_element( + doc, 'list_separator', repr(sep_char))) + return elements diff --git a/absl/flags/_argument_parser.pyi b/absl/flags/_argument_parser.pyi new file mode 100644 index 0000000..7e78d7d --- /dev/null +++ b/absl/flags/_argument_parser.pyi @@ -0,0 +1,127 @@ +# Copyright 2020 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains type annotations for _argument_parser.py.""" + + +from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Sequence, Any + +import enum + +_T = TypeVar('_T') +_ET = TypeVar('_ET', bound=enum.Enum) + + +class ArgumentSerializer(Generic[_T]): + def serialize(self, value: _T) -> Text: ... + + +# The metaclass of ArgumentParser is not reflected here, because it does not +# affect the provided API. +class ArgumentParser(Generic[_T]): + + syntactic_help: Text + + def parse(self, argument: Text) -> Optional[_T]: ... + + def flag_type(self) -> Text: ... + + +# Using bound=numbers.Number results in an error: b/153268436 +_N = TypeVar('_N', int, float) + + +class NumericParser(ArgumentParser[_N]): + + def is_outside_bounds(self, val: _N) -> bool: ... + + def parse(self, argument: Text) -> _N: ... + + def convert(self, argument: Text) -> _N: ... + + +class FloatParser(NumericParser[float]): + + def __init__(self, lower_bound:Optional[float]=None, + upper_bound:Optional[float]=None) -> None: + ... + + +class IntegerParser(NumericParser[int]): + + def __init__(self, lower_bound:Optional[int]=None, + upper_bound:Optional[int]=None) -> None: + ... + + +class BooleanParser(ArgumentParser[bool]): + ... + + +class EnumParser(ArgumentParser[Text]): + def __init__(self, enum_values: Sequence[Text], case_sensitive: bool=...) -> None: + ... + + + +class EnumClassParser(ArgumentParser[_ET]): + + def __init__(self, enum_class: Type[_ET], case_sensitive: bool=...) -> None: + ... + + @property + def member_names(self) -> Sequence[Text]: ... + + +class BaseListParser(ArgumentParser[List[Text]]): + def __init__(self, token: Text, name:Text) -> None: ... + + # Unlike baseclass BaseListParser never returns None. + def parse(self, argument: Text) -> List[Text]: ... + + + +class ListParser(BaseListParser): + def __init__(self) -> None: + ... + + + +class WhitespaceSeparatedListParser(BaseListParser): + def __init__(self, comma_compat: bool=False) -> None: + ... + + + +class ListSerializer(ArgumentSerializer[List[Text]]): + list_sep = ... # type: Text + + def __init__(self, list_sep: Text) -> None: + ... + + +class EnumClassListSerializer(ArgumentSerializer[List[Text]]): + def __init__(self, list_sep: Text, **kwargs: Any) -> None: + ... + + +class CsvListSerializer(ArgumentSerializer[List[Any]]): + + def __init__(self, list_sep: Text) -> None: + ... + + +class EnumClassSerializer(ArgumentSerializer[_ET]): + def __init__(self, lowercase: bool) -> None: + ... diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py new file mode 100644 index 0000000..4494c3b --- /dev/null +++ b/absl/flags/_defines.py @@ -0,0 +1,912 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This modules contains flags DEFINE functions. + +Do NOT import this module directly. Import the flags package and use the +aliases defined at the package level instead. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import types + +from absl.flags import _argument_parser +from absl.flags import _exceptions +from absl.flags import _flag +from absl.flags import _flagvalues +from absl.flags import _helpers +from absl.flags import _validators + +# pylint: disable=unused-import +try: + from typing import Text, List, Any +except ImportError: + pass + +try: + import enum +except ImportError: + pass +# pylint: enable=unused-import + +_helpers.disclaim_module_ids.add(id(sys.modules[__name__])) + + +def _register_bounds_validator_if_needed(parser, name, flag_values): + """Enforces lower and upper bounds for numeric flags. + + Args: + parser: NumericParser (either FloatParser or IntegerParser), provides lower + and upper bounds, and help text to display. + name: str, name of the flag + flag_values: FlagValues. + """ + if parser.lower_bound is not None or parser.upper_bound is not None: + + def checker(value): + if value is not None and parser.is_outside_bounds(value): + message = '%s is not %s' % (value, parser.syntactic_help) + raise _exceptions.ValidationError(message) + return True + + _validators.register_validator(name, checker, flag_values=flag_values) + + +def DEFINE( # pylint: disable=invalid-name + parser, + name, + default, + help, # pylint: disable=redefined-builtin + flag_values=_flagvalues.FLAGS, + serializer=None, + module_name=None, + required=False, + **args): + """Registers a generic Flag object. + + NOTE: in the docstrings of all DEFINE* functions, "registers" is short + for "creates a new flag and registers it". + + Auxiliary function: clients should use the specialized DEFINE_<type> + function instead. + + Args: + parser: ArgumentParser, used to parse the flag arguments. + name: str, the flag name. + default: The default value of the flag. + help: str, the help message. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + serializer: ArgumentSerializer, the flag serializer instance. + module_name: str, the name of the Python module declaring this flag. If not + provided, it will be computed using the stack trace of this call. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: dict, the extra keyword args that are passed to Flag __init__. + + Returns: + a handle to defined flag. + """ + return DEFINE_flag( + _flag.Flag(parser, serializer, name, default, help, **args), flag_values, + module_name, required) + + +def DEFINE_flag( # pylint: disable=invalid-name + flag, + flag_values=_flagvalues.FLAGS, + module_name=None, + required=False): + """Registers a 'Flag' object with a 'FlagValues' object. + + By default, the global FLAGS 'FlagValue' object is used. + + Typical users will use one of the more specialized DEFINE_xxx + functions, such as DEFINE_string or DEFINE_integer. But developers + who need to create Flag objects themselves should use this function + to register their flags. + + Args: + flag: Flag, a flag that is key to the module. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: str, the name of the Python module declaring this flag. If not + provided, it will be computed using the stack trace of this call. + required: bool, is this a required flag. This must be used as a keyword + argument. + + Returns: + a handle to defined flag. + """ + if required and flag.default is not None: + raise ValueError('Required flag --%s cannot have a non-None default' % + flag.name) + # Copying the reference to flag_values prevents pychecker warnings. + fv = flag_values + fv[flag.name] = flag + # Tell flag_values who's defining the flag. + if module_name: + module = sys.modules.get(module_name) + else: + module, module_name = _helpers.get_calling_module_object_and_name() + flag_values.register_flag_by_module(module_name, flag) + flag_values.register_flag_by_module_id(id(module), flag) + if required: + _validators.mark_flag_as_required(flag.name, fv) + ensure_non_none_value = (flag.default is not None) or required + return _flagvalues.FlagHolder( + fv, flag, ensure_non_none_value=ensure_non_none_value) + + +def _internal_declare_key_flags(flag_names, + flag_values=_flagvalues.FLAGS, + key_flag_values=None): + """Declares a flag as key for the calling module. + + Internal function. User code should call declare_key_flag or + adopt_module_key_flags instead. + + Args: + flag_names: [str], a list of strings that are names of already-registered + Flag objects. + flag_values: FlagValues, the FlagValues instance with which the flags listed + in flag_names have registered (the value of the flag_values argument from + the DEFINE_* calls that defined those flags). This should almost never + need to be overridden. + key_flag_values: FlagValues, the FlagValues instance that (among possibly + many other things) keeps track of the key flags for each module. Default + None means "same as flag_values". This should almost never need to be + overridden. + + Raises: + UnrecognizedFlagError: Raised when the flag is not defined. + """ + key_flag_values = key_flag_values or flag_values + + module = _helpers.get_calling_module() + + for flag_name in flag_names: + flag = flag_values[flag_name] + key_flag_values.register_key_flag_for_module(module, flag) + + +def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS): + """Declares one flag as key to the current module. + + Key flags are flags that are deemed really important for a module. + They are important when listing help messages; e.g., if the + --helpshort command-line flag is used, then only the key flags of the + main module are listed (instead of all flags, as in the case of + --helpfull). + + Sample usage: + + flags.declare_key_flag('flag_1') + + Args: + flag_name: str, the name of an already declared flag. (Redeclaring flags as + key, including flags implicitly key because they were declared in this + module, is a no-op.) + flag_values: FlagValues, the FlagValues instance in which the flag will be + declared as a key flag. This should almost never need to be overridden. + + Raises: + ValueError: Raised if flag_name not defined as a Python flag. + """ + if flag_name in _helpers.SPECIAL_FLAGS: + # Take care of the special flags, e.g., --flagfile, --undefok. + # These flags are defined in SPECIAL_FLAGS, and are treated + # specially during flag parsing, taking precedence over the + # user-defined flags. + _internal_declare_key_flags([flag_name], + flag_values=_helpers.SPECIAL_FLAGS, + key_flag_values=flag_values) + return + try: + _internal_declare_key_flags([flag_name], flag_values=flag_values) + except KeyError: + raise ValueError('Flag --%s is undefined. To set a flag as a key flag ' + 'first define it in Python.' % flag_name) + + +def adopt_module_key_flags(module, flag_values=_flagvalues.FLAGS): + """Declares that all flags key to a module are key to the current module. + + Args: + module: module, the module object from which all key flags will be declared + as key flags to the current module. + flag_values: FlagValues, the FlagValues instance in which the flags will be + declared as key flags. This should almost never need to be overridden. + + Raises: + Error: Raised when given an argument that is a module name (a string), + instead of a module object. + """ + if not isinstance(module, types.ModuleType): + raise _exceptions.Error('Expected a module object, not %r.' % (module,)) + _internal_declare_key_flags( + [f.name for f in flag_values.get_key_flags_for_module(module.__name__)], + flag_values=flag_values) + # If module is this flag module, take _helpers.SPECIAL_FLAGS into account. + if module == _helpers.FLAGS_MODULE: + _internal_declare_key_flags( + # As we associate flags with get_calling_module_object_and_name(), the + # special flags defined in this module are incorrectly registered with + # a different module. So, we can't use get_key_flags_for_module. + # Instead, we take all flags from _helpers.SPECIAL_FLAGS (a private + # FlagValues, where no other module should register flags). + [_helpers.SPECIAL_FLAGS[name].name for name in _helpers.SPECIAL_FLAGS], + flag_values=_helpers.SPECIAL_FLAGS, + key_flag_values=flag_values) + + +def disclaim_key_flags(): + """Declares that the current module will not define any more key flags. + + Normally, the module that calls the DEFINE_xxx functions claims the + flag to be its key flag. This is undesirable for modules that + define additional DEFINE_yyy functions with its own flag parsers and + serializers, since that module will accidentally claim flags defined + by DEFINE_yyy as its key flags. After calling this function, the + module disclaims flag definitions thereafter, so the key flags will + be correctly attributed to the caller of DEFINE_yyy. + + After calling this function, the module will not be able to define + any more flags. This function will affect all FlagValues objects. + """ + globals_for_caller = sys._getframe(1).f_globals # pylint: disable=protected-access + module, _ = _helpers.get_module_object_and_name(globals_for_caller) + _helpers.disclaim_module_ids.add(id(module)) + + +def DEFINE_string( # pylint: disable=invalid-name,redefined-builtin + name, + default, + help, + flag_values=_flagvalues.FLAGS, + required=False, + **args): + """Registers a flag whose value can be any string.""" + parser = _argument_parser.ArgumentParser() + serializer = _argument_parser.ArgumentSerializer() + return DEFINE( + parser, + name, + default, + help, + flag_values, + serializer, + required=required, + **args) + + +def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin + name, + default, + help, + flag_values=_flagvalues.FLAGS, + module_name=None, + required=False, + **args): + """Registers a boolean flag. + + Such a boolean flag does not take an argument. If a user wants to + specify a false value explicitly, the long option beginning with 'no' + must be used: i.e. --noflag + + This flag will have a value of None, True or False. None is possible + if default=None and the user does not specify the flag on the command + line. + + Args: + name: str, the flag name. + default: bool|str|None, the default value of the flag. + help: str, the help message. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: str, the name of the Python module declaring this flag. If not + provided, it will be computed using the stack trace of this call. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: dict, the extra keyword args that are passed to Flag __init__. + + Returns: + a handle to defined flag. + """ + return DEFINE_flag( + _flag.BooleanFlag(name, default, help, **args), flag_values, module_name, + required) + + +def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin + name, + default, + help, + lower_bound=None, + upper_bound=None, + flag_values=_flagvalues.FLAGS, + required=False, + **args): + """Registers a flag whose value must be a float. + + If lower_bound or upper_bound are set, then this flag must be + within the given range. + + Args: + name: str, the flag name. + default: float|str|None, the default value of the flag. + help: str, the help message. + lower_bound: float, min value of the flag. + upper_bound: float, max value of the flag. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: dict, the extra keyword args that are passed to DEFINE. + + Returns: + a handle to defined flag. + """ + parser = _argument_parser.FloatParser(lower_bound, upper_bound) + serializer = _argument_parser.ArgumentSerializer() + result = DEFINE( + parser, + name, + default, + help, + flag_values, + serializer, + required=required, + **args) + _register_bounds_validator_if_needed(parser, name, flag_values=flag_values) + return result + + +def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin + name, + default, + help, + lower_bound=None, + upper_bound=None, + flag_values=_flagvalues.FLAGS, + required=False, + **args): + """Registers a flag whose value must be an integer. + + If lower_bound, or upper_bound are set, then this flag must be + within the given range. + + Args: + name: str, the flag name. + default: int|str|None, the default value of the flag. + help: str, the help message. + lower_bound: int, min value of the flag. + upper_bound: int, max value of the flag. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: dict, the extra keyword args that are passed to DEFINE. + + Returns: + a handle to defined flag. + """ + parser = _argument_parser.IntegerParser(lower_bound, upper_bound) + serializer = _argument_parser.ArgumentSerializer() + result = DEFINE( + parser, + name, + default, + help, + flag_values, + serializer, + required=required, + **args) + _register_bounds_validator_if_needed(parser, name, flag_values=flag_values) + return result + + +def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin + name, + default, + enum_values, + help, + flag_values=_flagvalues.FLAGS, + module_name=None, + required=False, + **args): + """Registers a flag whose value can be any string from enum_values. + + Instead of a string enum, prefer `DEFINE_enum_class`, which allows + defining enums from an `enum.Enum` class. + + Args: + name: str, the flag name. + default: str|None, the default value of the flag. + enum_values: [str], a non-empty list of strings with the possible values for + the flag. + help: str, the help message. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: str, the name of the Python module declaring this flag. If not + provided, it will be computed using the stack trace of this call. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: dict, the extra keyword args that are passed to Flag __init__. + + Returns: + a handle to defined flag. + """ + return DEFINE_flag( + _flag.EnumFlag(name, default, help, enum_values, **args), flag_values, + module_name, required) + + +def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin + name, + default, + enum_class, + help, + flag_values=_flagvalues.FLAGS, + module_name=None, + case_sensitive=False, + required=False, + **args): + """Registers a flag whose value can be the name of enum members. + + Args: + name: str, the flag name. + default: Enum|str|None, the default value of the flag. + enum_class: class, the Enum class with all the possible values for the flag. + help: str, the help message. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: str, the name of the Python module declaring this flag. If not + provided, it will be computed using the stack trace of this call. + case_sensitive: bool, whether to map strings to members of the enum_class + without considering case. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: dict, the extra keyword args that are passed to Flag __init__. + + Returns: + a handle to defined flag. + """ + return DEFINE_flag( + _flag.EnumClassFlag( + name, + default, + help, + enum_class, + case_sensitive=case_sensitive, + **args), flag_values, module_name, required) + + +def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin + name, + default, + help, + flag_values=_flagvalues.FLAGS, + required=False, + **args): + """Registers a flag whose value is a comma-separated list of strings. + + The flag value is parsed with a CSV parser. + + Args: + name: str, the flag name. + default: list|str|None, the default value of the flag. + help: str, the help message. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. + """ + parser = _argument_parser.ListParser() + serializer = _argument_parser.CsvListSerializer(',') + return DEFINE( + parser, + name, + default, + help, + flag_values, + serializer, + required=required, + **args) + + +def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin + name, + default, + help, + comma_compat=False, + flag_values=_flagvalues.FLAGS, + required=False, + **args): + """Registers a flag whose value is a whitespace-separated list of strings. + + Any whitespace can be used as a separator. + + Args: + name: str, the flag name. + default: list|str|None, the default value of the flag. + help: str, the help message. + comma_compat: bool - Whether to support comma as an additional separator. If + false then only whitespace is supported. This is intended only for + backwards compatibility with flags that used to be comma-separated. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. + """ + parser = _argument_parser.WhitespaceSeparatedListParser( + comma_compat=comma_compat) + serializer = _argument_parser.ListSerializer(' ') + return DEFINE( + parser, + name, + default, + help, + flag_values, + serializer, + required=required, + **args) + + +def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin + parser, + serializer, + name, + default, + help, + flag_values=_flagvalues.FLAGS, + module_name=None, + required=False, + **args): + """Registers a generic MultiFlag that parses its args with a given parser. + + Auxiliary function. Normal users should NOT use it directly. + + Developers who need to create their own 'Parser' classes for options + which can appear multiple times can call this module function to + register their flags. + + Args: + parser: ArgumentParser, used to parse the flag arguments. + serializer: ArgumentSerializer, the flag serializer instance. + name: str, the flag name. + default: Union[Iterable[T], Text, None], the default value of the flag. If + the value is text, it will be parsed as if it was provided from the + command line. If the value is a non-string iterable, it will be iterated + over to create a shallow copy of the values. If it is None, it is left + as-is. + help: str, the help message. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: A string, the name of the Python module declaring this flag. If + not provided, it will be computed using the stack trace of this call. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. + """ + return DEFINE_flag( + _flag.MultiFlag(parser, serializer, name, default, help, **args), + flag_values, module_name, required) + + +def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin + name, + default, + help, + flag_values=_flagvalues.FLAGS, + required=False, + **args): + """Registers a flag whose value can be a list of any strings. + + Use the flag on the command line multiple times to place multiple + string values into the list. The 'default' may be a single string + (which will be converted into a single-element list) or a list of + strings. + + + Args: + name: str, the flag name. + default: Union[Iterable[Text], Text, None], the default value of the flag; + see `DEFINE_multi`. + help: str, the help message. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. + """ + parser = _argument_parser.ArgumentParser() + serializer = _argument_parser.ArgumentSerializer() + return DEFINE_multi( + parser, + serializer, + name, + default, + help, + flag_values, + required=required, + **args) + + +def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin + name, + default, + help, + lower_bound=None, + upper_bound=None, + flag_values=_flagvalues.FLAGS, + required=False, + **args): + """Registers a flag whose value can be a list of arbitrary integers. + + Use the flag on the command line multiple times to place multiple + integer values into the list. The 'default' may be a single integer + (which will be converted into a single-element list) or a list of + integers. + + Args: + name: str, the flag name. + default: Union[Iterable[int], Text, None], the default value of the flag; + see `DEFINE_multi`. + help: str, the help message. + lower_bound: int, min values of the flag. + upper_bound: int, max values of the flag. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. + """ + parser = _argument_parser.IntegerParser(lower_bound, upper_bound) + serializer = _argument_parser.ArgumentSerializer() + return DEFINE_multi( + parser, + serializer, + name, + default, + help, + flag_values, + required=required, + **args) + + +def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin + name, + default, + help, + lower_bound=None, + upper_bound=None, + flag_values=_flagvalues.FLAGS, + required=False, + **args): + """Registers a flag whose value can be a list of arbitrary floats. + + Use the flag on the command line multiple times to place multiple + float values into the list. The 'default' may be a single float + (which will be converted into a single-element list) or a list of + floats. + + Args: + name: str, the flag name. + default: Union[Iterable[float], Text, None], the default value of the flag; + see `DEFINE_multi`. + help: str, the help message. + lower_bound: float, min values of the flag. + upper_bound: float, max values of the flag. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. + """ + parser = _argument_parser.FloatParser(lower_bound, upper_bound) + serializer = _argument_parser.ArgumentSerializer() + return DEFINE_multi( + parser, + serializer, + name, + default, + help, + flag_values, + required=required, + **args) + + +def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin + name, + default, + enum_values, + help, + flag_values=_flagvalues.FLAGS, + case_sensitive=True, + required=False, + **args): + """Registers a flag whose value can be a list strings from enum_values. + + Use the flag on the command line multiple times to place multiple + enum values into the list. The 'default' may be a single string + (which will be converted into a single-element list) or a list of + strings. + + Args: + name: str, the flag name. + default: Union[Iterable[Text], Text, None], the default value of the flag; + see `DEFINE_multi`. + enum_values: [str], a non-empty list of strings with the possible values for + the flag. + help: str, the help message. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + case_sensitive: Whether or not the enum is to be case-sensitive. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. + """ + parser = _argument_parser.EnumParser(enum_values, case_sensitive) + serializer = _argument_parser.ArgumentSerializer() + return DEFINE_multi( + parser, + serializer, + name, + default, + '<%s>: %s' % ('|'.join(enum_values), help), + flag_values, + required=required, + **args) + + +def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin + name, + default, + enum_class, + help, + flag_values=_flagvalues.FLAGS, + module_name=None, + case_sensitive=False, + required=False, + **args): + """Registers a flag whose value can be a list of enum members. + + Use the flag on the command line multiple times to place multiple + enum values into the list. + + Args: + name: str, the flag name. + default: Union[Iterable[Enum], Iterable[Text], Enum, Text, None], the + default value of the flag; see `DEFINE_multi`; only differences are + documented here. If the value is a single Enum, it is treated as a + single-item list of that Enum value. If it is an iterable, text values + within the iterable will be converted to the equivalent Enum objects. + enum_class: class, the Enum class with all the possible values for the flag. + help: str, the help message. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: A string, the name of the Python module declaring this flag. If + not provided, it will be computed using the stack trace of this call. + case_sensitive: bool, whether to map strings to members of the enum_class + without considering case. + required: bool, is this a required flag. This must be used as a keyword + argument. + **args: Dictionary with extra keyword args that are passed to the Flag + __init__. + + Returns: + a handle to defined flag. + """ + return DEFINE_flag( + _flag.MultiEnumClassFlag( + name, default, help, enum_class, case_sensitive=case_sensitive), + flag_values, + module_name, + required=required, + **args) + + +def DEFINE_alias( # pylint: disable=invalid-name + name, + original_name, + flag_values=_flagvalues.FLAGS, + module_name=None): + """Defines an alias flag for an existing one. + + Args: + name: str, the flag name. + original_name: str, the original flag name. + flag_values: FlagValues, the FlagValues instance with which the flag will be + registered. This should almost never need to be overridden. + module_name: A string, the name of the module that defines this flag. + + Returns: + a handle to defined flag. + + Raises: + flags.FlagError: + UnrecognizedFlagError: if the referenced flag doesn't exist. + DuplicateFlagError: if the alias name has been used by some existing flag. + """ + if original_name not in flag_values: + raise _exceptions.UnrecognizedFlagError(original_name) + flag = flag_values[original_name] + + class _FlagAlias(_flag.Flag): + """Overrides Flag class so alias value is copy of original flag value.""" + + def parse(self, argument): + flag.parse(argument) + self.present += 1 + + def _parse_from_default(self, value): + # The value was already parsed by the aliased flag, so there is no + # need to call the parser on it a second time. + # Additionally, because of how MultiFlag parses and merges values, + # it isn't possible to delegate to the aliased flag and still get + # the correct values. + return value + + @property + def value(self): + return flag.value + + @value.setter + def value(self, value): + flag.value = value + + help_msg = 'Alias for --%s.' % flag.name + # If alias_name has been used, flags.DuplicatedFlag will be raised. + return DEFINE_flag( + _FlagAlias( + flag.parser, + flag.serializer, + name, + flag.default, + help_msg, + boolean=flag.boolean), flag_values, module_name) diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi new file mode 100644 index 0000000..1f482a2 --- /dev/null +++ b/absl/flags/_defines.pyi @@ -0,0 +1,637 @@ +# Copyright 2020 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This modules contains type annotated stubs for DEFINE functions.""" + + +from absl.flags import _argument_parser +from absl.flags import _flag +from absl.flags import _flagvalues + +import enum + +from typing import Text, List, Any, TypeVar, Optional, Union, Type, Iterable, overload, Literal + +_T = TypeVar('_T') +_ET = TypeVar('_ET', bound=enum.Enum) + + +@overload +def DEFINE( + parser: _argument_parser.ArgumentParser[_T], + name: Text, + default: Any, + help: Optional[Text], + flag_values : _flagvalues.FlagValues = ..., + serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = ..., + module_name: Optional[Text] = ..., + required: Literal[True] = ..., + **args: Any) -> _flagvalues.FlagHolder[_T]: + ... + + +@overload +def DEFINE( + parser: _argument_parser.ArgumentParser[_T], + name: Text, + default: Any, + help: Optional[Text], + flag_values : _flagvalues.FlagValues = ..., + serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[_T]]: + ... + + +@overload +def DEFINE_flag( + flag: _flag.Flag[_T], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: Literal[True] = ... +) -> _flagvalues.FlagHolder[_T]: + ... + +@overload +def DEFINE_flag( + flag: _flag.Flag[_T], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ...) -> _flagvalues.FlagHolder[Optional[_T]]: + ... + +# typing overloads for DEFINE_* methods... +# +# - DEFINE_* method return FlagHolder[Optional[T]] or FlagHolder[T] depending +# on the arguments. +# - If the flag value is guaranteed to be not None, the return type is +# FlagHolder[T]. +# - If the flag is required OR has a non-None default, the flag value i +# guaranteed to be not None after flag parsing has finished. +# The information above is captured with three overloads as follows. +# +# (if required=True and passed in as a keyword argument, +# return type is FlagHolder[Y]) +# @overload +# def DEFINE_xxx( +# ... arguments... +# default: Union[None, X] = ..., +# *, +# required: Literal[True]) -> _flagvalues.FlagHolder[Y]: +# ... +# +# (if default=None, return type is FlagHolder[Optional[Y]]) +# @overload +# def DEFINE_xxx( +# ... arguments... +# default: None, +# required: bool = ...) -> _flagvalues.FlagHolder[Optional[Y]]: +# ... +# +# (if default!=None, return type is FlagHolder[Y]): +# @overload +# def DEFINE_xxx( +# ... arguments... +# default: X, +# required: bool = ...) -> _flagvalues.FlagHolder[Y]: +# ... +# +# where X = type of non-None default values for the flag +# and Y = non-None type for flag value + +@overload +def DEFINE_string( + name: Text, + default: Optional[Text], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[Text]: + ... + +@overload +def DEFINE_string( + name: Text, + default: None, + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[Text]]: + ... + +@overload +def DEFINE_string( + name: Text, + default: Text, + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Text]: + ... + +@overload +def DEFINE_boolean( + name : Text, + default: Union[None, Text, bool, int], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[bool]: + ... + +@overload +def DEFINE_boolean( + name : Text, + default: None, + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[bool]]: + ... + +@overload +def DEFINE_boolean( + name : Text, + default: Union[Text, bool, int], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[bool]: + ... + +@overload +def DEFINE_float( + name: Text, + default: Union[None, float, Text], + help: Optional[Text], + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[float]: + ... + +@overload +def DEFINE_float( + name: Text, + default: None, + help: Optional[Text], + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[float]]: + ... + +@overload +def DEFINE_float( + name: Text, + default: Union[float, Text], + help: Optional[Text], + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[float]: + ... + + +@overload +def DEFINE_integer( + name: Text, + default: Union[None, int, Text], + help: Optional[Text], + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[int]: + ... + +@overload +def DEFINE_integer( + name: Text, + default: None, + help: Optional[Text], + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[int]]: + ... + +@overload +def DEFINE_integer( + name: Text, + default: Union[int, Text], + help: Optional[Text], + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[int]: + ... + +@overload +def DEFINE_enum( + name : Text, + default: Optional[Text], + enum_values: Iterable[Text], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[Text]: + ... + +@overload +def DEFINE_enum( + name : Text, + default: None, + enum_values: Iterable[Text], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[Text]]: + ... + +@overload +def DEFINE_enum( + name : Text, + default: Text, + enum_values: Iterable[Text], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Text]: + ... + +@overload +def DEFINE_enum_class( + name: Text, + default: Union[None, _ET, Text], + enum_class: Type[_ET], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + case_sensitive: bool = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[_ET]: + ... + +@overload +def DEFINE_enum_class( + name: Text, + default: None, + enum_class: Type[_ET], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + case_sensitive: bool = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[_ET]]: + ... + +@overload +def DEFINE_enum_class( + name: Text, + default: Union[_ET, Text], + enum_class: Type[_ET], + help: Optional[Text], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + case_sensitive: bool = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[_ET]: + ... + + +@overload +def DEFINE_list( + name: Text, + default: Union[None, Iterable[Text], Text], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_list( + name: Text, + default: None, + help: Text, + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: + ... + +@overload +def DEFINE_list( + name: Text, + default: Union[Iterable[Text], Text], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_spaceseplist( + name: Text, + default: Union[None, Iterable[Text], Text], + help: Text, + comma_compat: bool = ..., + flag_values: _flagvalues.FlagValues = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_spaceseplist( + name: Text, + default: None, + help: Text, + comma_compat: bool = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: + ... + +@overload +def DEFINE_spaceseplist( + name: Text, + default: Union[Iterable[Text], Text], + help: Text, + comma_compat: bool = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_multi( + parser : _argument_parser.ArgumentParser[_T], + serializer: _argument_parser.ArgumentSerializer[_T], + name: Text, + default: Union[None, Iterable[_T], _T, Text], + help: Text, + flag_values:_flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[_T]]: + ... + +@overload +def DEFINE_multi( + parser : _argument_parser.ArgumentParser[_T], + serializer: _argument_parser.ArgumentSerializer[_T], + name: Text, + default: None, + help: Text, + flag_values:_flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[List[_T]]]: + ... + +@overload +def DEFINE_multi( + parser : _argument_parser.ArgumentParser[_T], + serializer: _argument_parser.ArgumentSerializer[_T], + name: Text, + default: Union[Iterable[_T], _T, Text], + help: Text, + flag_values:_flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[List[_T]]: + ... + +@overload +def DEFINE_multi_string( + name: Text, + default: Union[None, Iterable[Text], Text], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_multi_string( + name: Text, + default: None, + help: Text, + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: + ... + +@overload +def DEFINE_multi_string( + name: Text, + default: Union[Iterable[Text], Text], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_multi_integer( + name: Text, + default: Union[None, Iterable[int], int, Text], + help: Text, + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[int]]: + ... + +@overload +def DEFINE_multi_integer( + name: Text, + default: None, + help: Text, + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[List[int]]]: + ... + +@overload +def DEFINE_multi_integer( + name: Text, + default: Union[Iterable[int], int, Text], + help: Text, + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[List[int]]: + ... + +@overload +def DEFINE_multi_float( + name: Text, + default: Union[None, Iterable[float], float, Text], + help: Text, + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[float]]: + ... + +@overload +def DEFINE_multi_float( + name: Text, + default: None, + help: Text, + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[List[float]]]: + ... + +@overload +def DEFINE_multi_float( + name: Text, + default: Union[Iterable[float], float, Text], + help: Text, + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[List[float]]: + ... + + +@overload +def DEFINE_multi_enum( + name: Text, + default: Union[None, Iterable[Text], Text], + enum_values: Iterable[Text], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_multi_enum( + name: Text, + default: None, + enum_values: Iterable[Text], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: + ... + +@overload +def DEFINE_multi_enum( + name: Text, + default: Union[Iterable[Text], Text], + enum_values: Iterable[Text], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[List[Text]]: + ... + +@overload +def DEFINE_multi_enum_class( + name: Text, + default: Union[None, Iterable[_ET], _ET, Text], + enum_class: Type[_ET], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + *, + required: Literal[True], + **args: Any) -> _flagvalues.FlagHolder[List[_ET]]: + ... + +@overload +def DEFINE_multi_enum_class( + name: Text, + default: None, + enum_class: Type[_ET], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[Optional[List[_ET]]]: + ... + +@overload +def DEFINE_multi_enum_class( + name: Text, + default: Union[Iterable[_ET], _ET, Text], + enum_class: Type[_ET], + help: Text, + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any) -> _flagvalues.FlagHolder[List[_ET]]: + ... + + + +def DEFINE_alias( + name: Text, + original_name: Text, + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ...) -> _flagvalues.FlagHolder[Any]: + ... + + + +def declare_key_flag(flag_name: Text, + flag_values: _flagvalues.FlagValues = ...) -> None: + ... + + + +def adopt_module_key_flags(module: Any, + flag_values: _flagvalues.FlagValues = ...) -> None: + ... + + + +def disclaim_key_flags() -> None: + ... diff --git a/absl/flags/_exceptions.py b/absl/flags/_exceptions.py new file mode 100644 index 0000000..254eb9b --- /dev/null +++ b/absl/flags/_exceptions.py @@ -0,0 +1,112 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exception classes in ABSL flags library. + +Do NOT import this module directly. Import the flags package and use the +aliases defined at the package level instead. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from absl.flags import _helpers + + +_helpers.disclaim_module_ids.add(id(sys.modules[__name__])) + + +class Error(Exception): + """The base class for all flags errors.""" + + +class CantOpenFlagFileError(Error): + """Raised when flagfile fails to open. + + E.g. the file doesn't exist, or has wrong permissions. + """ + + +class DuplicateFlagError(Error): + """Raised if there is a flag naming conflict.""" + + @classmethod + def from_flag(cls, flagname, flag_values, other_flag_values=None): + """Creates a DuplicateFlagError by providing flag name and values. + + Args: + flagname: str, the name of the flag being redefined. + flag_values: FlagValues, the FlagValues instance containing the first + definition of flagname. + other_flag_values: FlagValues, if it is not None, it should be the + FlagValues object where the second definition of flagname occurs. + If it is None, we assume that we're being called when attempting + to create the flag a second time, and we use the module calling + this one as the source of the second definition. + + Returns: + An instance of DuplicateFlagError. + """ + first_module = flag_values.find_module_defining_flag( + flagname, default='<unknown>') + if other_flag_values is None: + second_module = _helpers.get_calling_module() + else: + second_module = other_flag_values.find_module_defining_flag( + flagname, default='<unknown>') + flag_summary = flag_values[flagname].help + msg = ("The flag '%s' is defined twice. First from %s, Second from %s. " + "Description from first occurrence: %s") % ( + flagname, first_module, second_module, flag_summary) + return cls(msg) + + +class IllegalFlagValueError(Error): + """Raised when the flag command line argument is illegal.""" + + +class UnrecognizedFlagError(Error): + """Raised when a flag is unrecognized. + + Attributes: + flagname: str, the name of the unrecognized flag. + flagvalue: The value of the flag, empty if the flag is not defined. + """ + + def __init__(self, flagname, flagvalue='', suggestions=None): + self.flagname = flagname + self.flagvalue = flagvalue + if suggestions: + # Space before the question mark is intentional to not include it in the + # selection when copy-pasting the suggestion from (some) terminals. + tip = '. Did you mean: %s ?' % ', '.join(suggestions) + else: + tip = '' + super(UnrecognizedFlagError, self).__init__( + 'Unknown command line flag \'%s\'%s' % (flagname, tip)) + + +class UnparsedFlagAccessError(Error): + """Raised when accessing the flag value from unparsed FlagValues.""" + + +class ValidationError(Error): + """Raised when flag validator constraint is not satisfied.""" + + +class FlagNameConflictsWithMethodError(Error): + """Raised when a flag name conflicts with FlagValues methods.""" diff --git a/absl/flags/_flag.py b/absl/flags/_flag.py new file mode 100644 index 0000000..d7ad944 --- /dev/null +++ b/absl/flags/_flag.py @@ -0,0 +1,485 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains Flag class - information about single command-line flag. + +Do NOT import this module directly. Import the flags package and use the +aliases defined at the package level instead. +""" + +from collections import abc +import copy +import functools + +from absl.flags import _argument_parser +from absl.flags import _exceptions +from absl.flags import _helpers + + +@functools.total_ordering +class Flag(object): + """Information about a command-line flag. + + 'Flag' objects define the following fields: + .name - the name for this flag; + .default - the default value for this flag; + .default_unparsed - the unparsed default value for this flag. + .default_as_str - default value as repr'd string, e.g., "'true'" (or None); + .value - the most recent parsed value of this flag; set by parse(); + .help - a help string or None if no help is available; + .short_name - the single letter alias for this flag (or None); + .boolean - if 'true', this flag does not accept arguments; + .present - true if this flag was parsed from command line flags; + .parser - an ArgumentParser object; + .serializer - an ArgumentSerializer object; + .allow_override - the flag may be redefined without raising an error, and + newly defined flag overrides the old one. + .allow_override_cpp - use the flag from C++ if available; the flag + definition is replaced by the C++ flag after init; + .allow_hide_cpp - use the Python flag despite having a C++ flag with + the same name (ignore the C++ flag); + .using_default_value - the flag value has not been set by user; + .allow_overwrite - the flag may be parsed more than once without raising + an error, the last set value will be used; + .allow_using_method_names - whether this flag can be defined even if it has + a name that conflicts with a FlagValues method. + + The only public method of a 'Flag' object is parse(), but it is + typically only called by a 'FlagValues' object. The parse() method is + a thin wrapper around the 'ArgumentParser' parse() method. The parsed + value is saved in .value, and the .present attribute is updated. If + this flag was already present, an Error is raised. + + parse() is also called during __init__ to parse the default value and + initialize the .value attribute. This enables other python modules to + safely use flags even if the __main__ module neglects to parse the + command line arguments. The .present attribute is cleared after + __init__ parsing. If the default value is set to None, then the + __init__ parsing step is skipped and the .value attribute is + initialized to None. + + Note: The default value is also presented to the user in the help + string, so it is important that it be a legal value for this flag. + """ + + def __init__(self, parser, serializer, name, default, help_string, + short_name=None, boolean=False, allow_override=False, + allow_override_cpp=False, allow_hide_cpp=False, + allow_overwrite=True, allow_using_method_names=False): + self.name = name + + if not help_string: + help_string = '(no help available)' + + self.help = help_string + self.short_name = short_name + self.boolean = boolean + self.present = 0 + self.parser = parser + self.serializer = serializer + self.allow_override = allow_override + self.allow_override_cpp = allow_override_cpp + self.allow_hide_cpp = allow_hide_cpp + self.allow_overwrite = allow_overwrite + self.allow_using_method_names = allow_using_method_names + + self.using_default_value = True + self._value = None + self.validators = [] + if self.allow_hide_cpp and self.allow_override_cpp: + raise _exceptions.Error( + "Can't have both allow_hide_cpp (means use Python flag) and " + 'allow_override_cpp (means use C++ flag after InitGoogle)') + + self._set_default(default) + + @property + def value(self): + return self._value + + @value.setter + def value(self, value): + self._value = value + + def __hash__(self): + return hash(id(self)) + + def __eq__(self, other): + return self is other + + def __lt__(self, other): + if isinstance(other, Flag): + return id(self) < id(other) + return NotImplemented + + def __bool__(self): + raise TypeError('A Flag instance would always be True. ' + 'Did you mean to test the `.value` attribute?') + + def __getstate__(self): + raise TypeError("can't pickle Flag objects") + + def __copy__(self): + raise TypeError('%s does not support shallow copies. ' + 'Use copy.deepcopy instead.' % type(self).__name__) + + def __deepcopy__(self, memo): + result = object.__new__(type(self)) + result.__dict__ = copy.deepcopy(self.__dict__, memo) + return result + + def _get_parsed_value_as_string(self, value): + """Returns parsed flag value as string.""" + if value is None: + return None + if self.serializer: + return repr(self.serializer.serialize(value)) + if self.boolean: + if value: + return repr('true') + else: + return repr('false') + return repr(_helpers.str_or_unicode(value)) + + def parse(self, argument): + """Parses string and sets flag value. + + Args: + argument: str or the correct flag value type, argument to be parsed. + """ + if self.present and not self.allow_overwrite: + raise _exceptions.IllegalFlagValueError( + 'flag --%s=%s: already defined as %s' % ( + self.name, argument, self.value)) + self.value = self._parse(argument) + self.present += 1 + + def _parse(self, argument): + """Internal parse function. + + It returns the parsed value, and does not modify class states. + + Args: + argument: str or the correct flag value type, argument to be parsed. + + Returns: + The parsed value. + """ + try: + return self.parser.parse(argument) + except (TypeError, ValueError) as e: # Recast as IllegalFlagValueError. + raise _exceptions.IllegalFlagValueError( + 'flag --%s=%s: %s' % (self.name, argument, e)) + + def unparse(self): + self.value = self.default + self.using_default_value = True + self.present = 0 + + def serialize(self): + """Serializes the flag.""" + return self._serialize(self.value) + + def _serialize(self, value): + """Internal serialize function.""" + if value is None: + return '' + if self.boolean: + if value: + return '--%s' % self.name + else: + return '--no%s' % self.name + else: + if not self.serializer: + raise _exceptions.Error( + 'Serializer not present for flag %s' % self.name) + return '--%s=%s' % (self.name, self.serializer.serialize(value)) + + def _set_default(self, value): + """Changes the default value (and current value too) for this Flag.""" + self.default_unparsed = value + if value is None: + self.default = None + else: + self.default = self._parse_from_default(value) + self.default_as_str = self._get_parsed_value_as_string(self.default) + if self.using_default_value: + self.value = self.default + + # This is split out so that aliases can skip regular parsing of the default + # value. + def _parse_from_default(self, value): + return self._parse(value) + + def flag_type(self): + """Returns a str that describes the type of the flag. + + NOTE: we use strings, and not the types.*Type constants because + our flags can have more exotic types, e.g., 'comma separated list + of strings', 'whitespace separated list of strings', etc. + """ + return self.parser.flag_type() + + def _create_xml_dom_element(self, doc, module_name, is_key=False): + """Returns an XML element that contains this flag's information. + + This is information that is relevant to all flags (e.g., name, + meaning, etc.). If you defined a flag that has some other pieces of + info, then please override _ExtraXMLInfo. + + Please do NOT override this method. + + Args: + doc: minidom.Document, the DOM document it should create nodes from. + module_name: str,, the name of the module that defines this flag. + is_key: boolean, True iff this flag is key for main module. + + Returns: + A minidom.Element instance. + """ + element = doc.createElement('flag') + if is_key: + element.appendChild(_helpers.create_xml_dom_element(doc, 'key', 'yes')) + element.appendChild(_helpers.create_xml_dom_element( + doc, 'file', module_name)) + # Adds flag features that are relevant for all flags. + element.appendChild(_helpers.create_xml_dom_element(doc, 'name', self.name)) + if self.short_name: + element.appendChild(_helpers.create_xml_dom_element( + doc, 'short_name', self.short_name)) + if self.help: + element.appendChild(_helpers.create_xml_dom_element( + doc, 'meaning', self.help)) + # The default flag value can either be represented as a string like on the + # command line, or as a Python object. We serialize this value in the + # latter case in order to remain consistent. + if self.serializer and not isinstance(self.default, str): + if self.default is not None: + default_serialized = self.serializer.serialize(self.default) + else: + default_serialized = '' + else: + default_serialized = self.default + element.appendChild(_helpers.create_xml_dom_element( + doc, 'default', default_serialized)) + value_serialized = self._serialize_value_for_xml(self.value) + element.appendChild(_helpers.create_xml_dom_element( + doc, 'current', value_serialized)) + element.appendChild(_helpers.create_xml_dom_element( + doc, 'type', self.flag_type())) + # Adds extra flag features this flag may have. + for e in self._extra_xml_dom_elements(doc): + element.appendChild(e) + return element + + def _serialize_value_for_xml(self, value): + """Returns the serialized value, for use in an XML help text.""" + return value + + def _extra_xml_dom_elements(self, doc): + """Returns extra info about this flag in XML. + + "Extra" means "not already included by _create_xml_dom_element above." + + Args: + doc: minidom.Document, the DOM document it should create nodes from. + + Returns: + A list of minidom.Element. + """ + # Usually, the parser knows the extra details about the flag, so + # we just forward the call to it. + return self.parser._custom_xml_dom_elements(doc) # pylint: disable=protected-access + + +class BooleanFlag(Flag): + """Basic boolean flag. + + Boolean flags do not take any arguments, and their value is either + True (1) or False (0). The false value is specified on the command + line by prepending the word 'no' to either the long or the short flag + name. + + For example, if a Boolean flag was created whose long name was + 'update' and whose short name was 'x', then this flag could be + explicitly unset through either --noupdate or --nox. + """ + + def __init__(self, name, default, help, short_name=None, **args): # pylint: disable=redefined-builtin + p = _argument_parser.BooleanParser() + super(BooleanFlag, self).__init__( + p, None, name, default, help, short_name, 1, **args) + + +class EnumFlag(Flag): + """Basic enum flag; its value can be any string from list of enum_values.""" + + def __init__(self, name, default, help, enum_values, # pylint: disable=redefined-builtin + short_name=None, case_sensitive=True, **args): + p = _argument_parser.EnumParser(enum_values, case_sensitive) + g = _argument_parser.ArgumentSerializer() + super(EnumFlag, self).__init__( + p, g, name, default, help, short_name, **args) + self.help = '<%s>: %s' % ('|'.join(enum_values), self.help) + + def _extra_xml_dom_elements(self, doc): + elements = [] + for enum_value in self.parser.enum_values: + elements.append(_helpers.create_xml_dom_element( + doc, 'enum_value', enum_value)) + return elements + + +class EnumClassFlag(Flag): + """Basic enum flag; its value is an enum class's member.""" + + def __init__( + self, + name, + default, + help, # pylint: disable=redefined-builtin + enum_class, + short_name=None, + case_sensitive=False, + **args): + p = _argument_parser.EnumClassParser( + enum_class, case_sensitive=case_sensitive) + g = _argument_parser.EnumClassSerializer(lowercase=not case_sensitive) + super(EnumClassFlag, self).__init__( + p, g, name, default, help, short_name, **args) + self.help = '<%s>: %s' % ('|'.join(p.member_names), self.help) + + def _extra_xml_dom_elements(self, doc): + elements = [] + for enum_value in self.parser.enum_class.__members__.keys(): + elements.append(_helpers.create_xml_dom_element( + doc, 'enum_value', enum_value)) + return elements + + +class MultiFlag(Flag): + """A flag that can appear multiple time on the command-line. + + The value of such a flag is a list that contains the individual values + from all the appearances of that flag on the command-line. + + See the __doc__ for Flag for most behavior of this class. Only + differences in behavior are described here: + + * The default value may be either a single value or an iterable of values. + A single value is transformed into a single-item list of that value. + + * The value of the flag is always a list, even if the option was + only supplied once, and even if the default value is a single + value + """ + + def __init__(self, *args, **kwargs): + super(MultiFlag, self).__init__(*args, **kwargs) + self.help += ';\n repeat this option to specify a list of values' + + def parse(self, arguments): + """Parses one or more arguments with the installed parser. + + Args: + arguments: a single argument or a list of arguments (typically a + list of default values); a single argument is converted + internally into a list containing one item. + """ + new_values = self._parse(arguments) + if self.present: + self.value.extend(new_values) + else: + self.value = new_values + self.present += len(new_values) + + def _parse(self, arguments): + if (isinstance(arguments, abc.Iterable) and + not isinstance(arguments, str)): + arguments = list(arguments) + + if not isinstance(arguments, list): + # Default value may be a list of values. Most other arguments + # will not be, so convert them into a single-item list to make + # processing simpler below. + arguments = [arguments] + + return [super(MultiFlag, self)._parse(item) for item in arguments] + + def _serialize(self, value): + """See base class.""" + if not self.serializer: + raise _exceptions.Error( + 'Serializer not present for flag %s' % self.name) + if value is None: + return '' + + serialized_items = [ + super(MultiFlag, self)._serialize(value_item) for value_item in value + ] + + return '\n'.join(serialized_items) + + def flag_type(self): + """See base class.""" + return 'multi ' + self.parser.flag_type() + + def _extra_xml_dom_elements(self, doc): + elements = [] + if hasattr(self.parser, 'enum_values'): + for enum_value in self.parser.enum_values: + elements.append(_helpers.create_xml_dom_element( + doc, 'enum_value', enum_value)) + return elements + + +class MultiEnumClassFlag(MultiFlag): + """A multi_enum_class flag. + + See the __doc__ for MultiFlag for most behaviors of this class. In addition, + this class knows how to handle enum.Enum instances as values for this flag + type. + """ + + def __init__(self, + name, + default, + help_string, + enum_class, + case_sensitive=False, + **args): + p = _argument_parser.EnumClassParser( + enum_class, case_sensitive=case_sensitive) + g = _argument_parser.EnumClassListSerializer( + list_sep=',', lowercase=not case_sensitive) + super(MultiEnumClassFlag, self).__init__( + p, g, name, default, help_string, **args) + self.help = ( + '<%s>: %s;\n repeat this option to specify a list of values' % + ('|'.join(p.member_names), help_string or '(no help available)')) + + def _extra_xml_dom_elements(self, doc): + elements = [] + for enum_value in self.parser.enum_class.__members__.keys(): + elements.append(_helpers.create_xml_dom_element( + doc, 'enum_value', enum_value)) + return elements + + def _serialize_value_for_xml(self, value): + """See base class.""" + if value is not None: + value_serialized = self.serializer.serialize(value) + else: + value_serialized = '' + return value_serialized diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi new file mode 100644 index 0000000..9b4a3d3 --- /dev/null +++ b/absl/flags/_flag.pyi @@ -0,0 +1,133 @@ +# Copyright 2020 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains type annotations for Flag class.""" + +import copy +import functools + +from absl.flags import _argument_parser +import enum + +from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence + +_T = TypeVar('_T') +_ET = TypeVar('_ET', bound=enum.Enum) + + +class Flag(Generic[_T]): + + name = ... # type: Text + default = ... # type: Any + default_unparsed = ... # type: Any + default_as_str = ... # type: Optional[Text] + help = ... # type: Text + short_name = ... # type: Text + boolean = ... # type: bool + present = ... # type: bool + parser = ... # type: _argument_parser.ArgumentParser[_T] + serializer = ... # type: _argument_parser.ArgumentSerializer[_T] + allow_override = ... # type: bool + allow_override_cpp = ... # type: bool + allow_hide_cpp = ... # type: bool + using_default_value = ... # type: bool + allow_overwrite = ... # type: bool + allow_using_method_names = ... # type: bool + + def __init__(self, + parser: _argument_parser.ArgumentParser[_T], + serializer: Optional[_argument_parser.ArgumentSerializer[_T]], + name: Text, + default: Any, + help_string: Optional[Text], + short_name: Optional[Text] = ..., + boolean: bool = ..., + allow_override: bool = ..., + allow_override_cpp: bool = ..., + allow_hide_cpp: bool = ..., + allow_overwrite: bool = ..., + allow_using_method_names: bool = ...) -> None: + ... + + + @property + def value(self) -> Optional[_T]: ... + + def parse(self, argument: Union[_T, Text, None]) -> None: ... + + def unparse(self) -> None: ... + + def _parse(self, argument: Any) -> Any: ... + + def __deepcopy__(self, memo: dict) -> Flag: ... + + def _get_parsed_value_as_string(self, value: Optional[_T]) -> Optional[Text]: + ... + + def serialize(self) -> Text: ... + + def flag_type(self) -> Text: ... + + +class BooleanFlag(Flag[bool]): + def __init__(self, + name: Text, + default: Any, + help: Optional[Text], + short_name: Optional[Text]=None, + **args: Any) -> None: + ... + + + +class EnumFlag(Flag[Text]): + def __init__(self, + name: Text, + default: Any, + help: Optional[Text], + enum_values: Sequence[Text], + short_name: Optional[Text] = ..., + case_sensitive: bool = ..., + **args: Any): + ... + + + +class EnumClassFlag(Flag[_ET]): + + def __init__(self, + name: Text, + default: Any, + help: Optional[Text], + enum_class: Type[_ET], + short_name: Optional[Text]=None, + **args: Any): + ... + + + +class MultiFlag(Flag[List[_T]]): + ... + + +class MultiEnumClassFlag(MultiFlag[_ET]): + def __init__(self, + name: Text, + default: Any, + help_string: Optional[Text], + enum_class: Type[_ET], + **args: Any): + ... + + diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py new file mode 100644 index 0000000..1b54fb3 --- /dev/null +++ b/absl/flags/_flagvalues.py @@ -0,0 +1,1387 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the FlagValues class - registry of 'Flag' objects. + +Do NOT import this module directly. Import the flags package and use the +aliases defined at the package level instead. +""" + +import copy +import itertools +import logging +import os +import sys +from typing import Generic, TypeVar +from xml.dom import minidom + +from absl.flags import _exceptions +from absl.flags import _flag +from absl.flags import _helpers +from absl.flags import _validators_classes + +# Add flagvalues module to disclaimed module ids. +_helpers.disclaim_module_ids.add(id(sys.modules[__name__])) + +_T = TypeVar('_T') + + +class FlagValues: + """Registry of 'Flag' objects. + + A 'FlagValues' can then scan command line arguments, passing flag + arguments through to the 'Flag' objects that it owns. It also + provides easy access to the flag values. Typically only one + 'FlagValues' object is needed by an application: flags.FLAGS + + This class is heavily overloaded: + + 'Flag' objects are registered via __setitem__: + FLAGS['longname'] = x # register a new flag + + The .value attribute of the registered 'Flag' objects can be accessed + as attributes of this 'FlagValues' object, through __getattr__. Both + the long and short name of the original 'Flag' objects can be used to + access its value: + FLAGS.longname # parsed flag value + FLAGS.x # parsed flag value (short name) + + Command line arguments are scanned and passed to the registered 'Flag' + objects through the __call__ method. Unparsed arguments, including + argv[0] (e.g. the program name) are returned. + argv = FLAGS(sys.argv) # scan command line arguments + + The original registered Flag objects can be retrieved through the use + of the dictionary-like operator, __getitem__: + x = FLAGS['longname'] # access the registered Flag object + + The str() operator of a 'FlagValues' object provides help for all of + the registered 'Flag' objects. + """ + + # A note on collections.abc.Mapping: + # FlagValues defines __getitem__, __iter__, and __len__. It makes perfect + # sense to let it be a collections.abc.Mapping class. However, we are not + # able to do so. The mixin methods, e.g. keys, values, are not uncommon flag + # names. Those flag values would not be accessible via the FLAGS.xxx form. + + def __init__(self): + # Since everything in this class is so heavily overloaded, the only + # way of defining and using fields is to access __dict__ directly. + + # Dictionary: flag name (string) -> Flag object. + self.__dict__['__flags'] = {} + + # Set: name of hidden flag (string). + # Holds flags that should not be directly accessible from Python. + self.__dict__['__hiddenflags'] = set() + + # Dictionary: module name (string) -> list of Flag objects that are defined + # by that module. + self.__dict__['__flags_by_module'] = {} + # Dictionary: module id (int) -> list of Flag objects that are defined by + # that module. + self.__dict__['__flags_by_module_id'] = {} + # Dictionary: module name (string) -> list of Flag objects that are + # key for that module. + self.__dict__['__key_flags_by_module'] = {} + + # Bool: True if flags were parsed. + self.__dict__['__flags_parsed'] = False + + # Bool: True if unparse_flags() was called. + self.__dict__['__unparse_flags_called'] = False + + # None or Method(name, value) to call from __setattr__ for an unknown flag. + self.__dict__['__set_unknown'] = None + + # A set of banned flag names. This is to prevent users from accidentally + # defining a flag that has the same name as a method on this class. + # Users can still allow defining the flag by passing + # allow_using_method_names=True in DEFINE_xxx functions. + self.__dict__['__banned_flag_names'] = frozenset(dir(FlagValues)) + + # Bool: Whether to use GNU style scanning. + self.__dict__['__use_gnu_getopt'] = True + + # Bool: Whether use_gnu_getopt has been explicitly set by the user. + self.__dict__['__use_gnu_getopt_explicitly_set'] = False + + # Function: Takes a flag name as parameter, returns a tuple + # (is_retired, type_is_bool). + self.__dict__['__is_retired_flag_func'] = None + + def set_gnu_getopt(self, gnu_getopt=True): + """Sets whether or not to use GNU style scanning. + + GNU style allows mixing of flag and non-flag arguments. See + http://docs.python.org/library/getopt.html#getopt.gnu_getopt + + Args: + gnu_getopt: bool, whether or not to use GNU style scanning. + """ + self.__dict__['__use_gnu_getopt'] = gnu_getopt + self.__dict__['__use_gnu_getopt_explicitly_set'] = True + + def is_gnu_getopt(self): + return self.__dict__['__use_gnu_getopt'] + + def _flags(self): + return self.__dict__['__flags'] + + def flags_by_module_dict(self): + """Returns the dictionary of module_name -> list of defined flags. + + Returns: + A dictionary. Its keys are module names (strings). Its values + are lists of Flag objects. + """ + return self.__dict__['__flags_by_module'] + + def flags_by_module_id_dict(self): + """Returns the dictionary of module_id -> list of defined flags. + + Returns: + A dictionary. Its keys are module IDs (ints). Its values + are lists of Flag objects. + """ + return self.__dict__['__flags_by_module_id'] + + def key_flags_by_module_dict(self): + """Returns the dictionary of module_name -> list of key flags. + + Returns: + A dictionary. Its keys are module names (strings). Its values + are lists of Flag objects. + """ + return self.__dict__['__key_flags_by_module'] + + def register_flag_by_module(self, module_name, flag): + """Records the module that defines a specific flag. + + We keep track of which flag is defined by which module so that we + can later sort the flags by module. + + Args: + module_name: str, the name of a Python module. + flag: Flag, the Flag instance that is key to the module. + """ + flags_by_module = self.flags_by_module_dict() + flags_by_module.setdefault(module_name, []).append(flag) + + def register_flag_by_module_id(self, module_id, flag): + """Records the module that defines a specific flag. + + Args: + module_id: int, the ID of the Python module. + flag: Flag, the Flag instance that is key to the module. + """ + flags_by_module_id = self.flags_by_module_id_dict() + flags_by_module_id.setdefault(module_id, []).append(flag) + + def register_key_flag_for_module(self, module_name, flag): + """Specifies that a flag is a key flag for a module. + + Args: + module_name: str, the name of a Python module. + flag: Flag, the Flag instance that is key to the module. + """ + key_flags_by_module = self.key_flags_by_module_dict() + # The list of key flags for the module named module_name. + key_flags = key_flags_by_module.setdefault(module_name, []) + # Add flag, but avoid duplicates. + if flag not in key_flags: + key_flags.append(flag) + + def _flag_is_registered(self, flag_obj): + """Checks whether a Flag object is registered under long name or short name. + + Args: + flag_obj: Flag, the Flag instance to check for. + + Returns: + bool, True iff flag_obj is registered under long name or short name. + """ + flag_dict = self._flags() + # Check whether flag_obj is registered under its long name. + name = flag_obj.name + if flag_dict.get(name, None) == flag_obj: + return True + # Check whether flag_obj is registered under its short name. + short_name = flag_obj.short_name + if (short_name is not None and flag_dict.get(short_name, None) == flag_obj): + return True + return False + + def _cleanup_unregistered_flag_from_module_dicts(self, flag_obj): + """Cleans up unregistered flags from all module -> [flags] dictionaries. + + If flag_obj is registered under either its long name or short name, it + won't be removed from the dictionaries. + + Args: + flag_obj: Flag, the Flag instance to clean up for. + """ + if self._flag_is_registered(flag_obj): + return + for flags_by_module_dict in (self.flags_by_module_dict(), + self.flags_by_module_id_dict(), + self.key_flags_by_module_dict()): + for flags_in_module in flags_by_module_dict.values(): + # While (as opposed to if) takes care of multiple occurrences of a + # flag in the list for the same module. + while flag_obj in flags_in_module: + flags_in_module.remove(flag_obj) + + def get_flags_for_module(self, module): + """Returns the list of flags defined by a module. + + Args: + module: module|str, the module to get flags from. + + Returns: + [Flag], a new list of Flag instances. Caller may update this list as + desired: none of those changes will affect the internals of this + FlagValue instance. + """ + if not isinstance(module, str): + module = module.__name__ + if module == '__main__': + module = sys.argv[0] + + return list(self.flags_by_module_dict().get(module, [])) + + def get_key_flags_for_module(self, module): + """Returns the list of key flags for a module. + + Args: + module: module|str, the module to get key flags from. + + Returns: + [Flag], a new list of Flag instances. Caller may update this list as + desired: none of those changes will affect the internals of this + FlagValue instance. + """ + if not isinstance(module, str): + module = module.__name__ + if module == '__main__': + module = sys.argv[0] + + # Any flag is a key flag for the module that defined it. NOTE: + # key_flags is a fresh list: we can update it without affecting the + # internals of this FlagValues object. + key_flags = self.get_flags_for_module(module) + + # Take into account flags explicitly declared as key for a module. + for flag in self.key_flags_by_module_dict().get(module, []): + if flag not in key_flags: + key_flags.append(flag) + return key_flags + + def find_module_defining_flag(self, flagname, default=None): + """Return the name of the module defining this flag, or default. + + Args: + flagname: str, name of the flag to lookup. + default: Value to return if flagname is not defined. Defaults to None. + + Returns: + The name of the module which registered the flag with this name. + If no such module exists (i.e. no flag with this name exists), + we return default. + """ + registered_flag = self._flags().get(flagname) + if registered_flag is None: + return default + for module, flags in self.flags_by_module_dict().items(): + for flag in flags: + # It must compare the flag with the one in _flags. This is because a + # flag might be overridden only for its long name (or short name), + # and only its short name (or long name) is considered registered. + if (flag.name == registered_flag.name and + flag.short_name == registered_flag.short_name): + return module + return default + + def find_module_id_defining_flag(self, flagname, default=None): + """Return the ID of the module defining this flag, or default. + + Args: + flagname: str, name of the flag to lookup. + default: Value to return if flagname is not defined. Defaults to None. + + Returns: + The ID of the module which registered the flag with this name. + If no such module exists (i.e. no flag with this name exists), + we return default. + """ + registered_flag = self._flags().get(flagname) + if registered_flag is None: + return default + for module_id, flags in self.flags_by_module_id_dict().items(): + for flag in flags: + # It must compare the flag with the one in _flags. This is because a + # flag might be overridden only for its long name (or short name), + # and only its short name (or long name) is considered registered. + if (flag.name == registered_flag.name and + flag.short_name == registered_flag.short_name): + return module_id + return default + + def _register_unknown_flag_setter(self, setter): + """Allow set default values for undefined flags. + + Args: + setter: Method(name, value) to call to __setattr__ an unknown flag. Must + raise NameError or ValueError for invalid name/value. + """ + self.__dict__['__set_unknown'] = setter + + def _set_unknown_flag(self, name, value): + """Returns value if setting flag |name| to |value| returned True. + + Args: + name: str, name of the flag to set. + value: Value to set. + + Returns: + Flag value on successful call. + + Raises: + UnrecognizedFlagError + IllegalFlagValueError + """ + setter = self.__dict__['__set_unknown'] + if setter: + try: + setter(name, value) + return value + except (TypeError, ValueError): # Flag value is not valid. + raise _exceptions.IllegalFlagValueError( + '"{1}" is not valid for --{0}'.format(name, value)) + except NameError: # Flag name is not valid. + pass + raise _exceptions.UnrecognizedFlagError(name, value) + + def append_flag_values(self, flag_values): + """Appends flags registered in another FlagValues instance. + + Args: + flag_values: FlagValues, the FlagValues instance from which to copy flags. + """ + for flag_name, flag in flag_values._flags().items(): # pylint: disable=protected-access + # Each flags with short_name appears here twice (once under its + # normal name, and again with its short name). To prevent + # problems (DuplicateFlagError) with double flag registration, we + # perform a check to make sure that the entry we're looking at is + # for its normal name. + if flag_name == flag.name: + try: + self[flag_name] = flag + except _exceptions.DuplicateFlagError: + raise _exceptions.DuplicateFlagError.from_flag( + flag_name, self, other_flag_values=flag_values) + + def remove_flag_values(self, flag_values): + """Remove flags that were previously appended from another FlagValues. + + Args: + flag_values: FlagValues, the FlagValues instance containing flags to + remove. + """ + for flag_name in flag_values: + self.__delattr__(flag_name) + + def __setitem__(self, name, flag): + """Registers a new flag variable.""" + fl = self._flags() + if not isinstance(flag, _flag.Flag): + raise _exceptions.IllegalFlagValueError(flag) + if str is bytes and isinstance(name, unicode): + # When using Python 2 with unicode_literals, allow it but encode it + # into the bytes type we require. + name = name.encode('utf-8') + if not isinstance(name, type('')): + raise _exceptions.Error('Flag name must be a string') + if not name: + raise _exceptions.Error('Flag name cannot be empty') + if ' ' in name: + raise _exceptions.Error('Flag name cannot contain a space') + self._check_method_name_conflicts(name, flag) + if name in fl and not flag.allow_override and not fl[name].allow_override: + module, module_name = _helpers.get_calling_module_object_and_name() + if (self.find_module_defining_flag(name) == module_name and + id(module) != self.find_module_id_defining_flag(name)): + # If the flag has already been defined by a module with the same name, + # but a different ID, we can stop here because it indicates that the + # module is simply being imported a subsequent time. + return + raise _exceptions.DuplicateFlagError.from_flag(name, self) + short_name = flag.short_name + # If a new flag overrides an old one, we need to cleanup the old flag's + # modules if it's not registered. + flags_to_cleanup = set() + if short_name is not None: + if (short_name in fl and not flag.allow_override and + not fl[short_name].allow_override): + raise _exceptions.DuplicateFlagError.from_flag(short_name, self) + if short_name in fl and fl[short_name] != flag: + flags_to_cleanup.add(fl[short_name]) + fl[short_name] = flag + if (name not in fl # new flag + or fl[name].using_default_value or not flag.using_default_value): + if name in fl and fl[name] != flag: + flags_to_cleanup.add(fl[name]) + fl[name] = flag + for f in flags_to_cleanup: + self._cleanup_unregistered_flag_from_module_dicts(f) + + def __dir__(self): + """Returns list of names of all defined flags. + + Useful for TAB-completion in ipython. + + Returns: + [str], a list of names of all defined flags. + """ + return sorted(self.__dict__['__flags']) + + def __getitem__(self, name): + """Returns the Flag object for the flag --name.""" + return self._flags()[name] + + def _hide_flag(self, name): + """Marks the flag --name as hidden.""" + self.__dict__['__hiddenflags'].add(name) + + def __getattr__(self, name): + """Retrieves the 'value' attribute of the flag --name.""" + fl = self._flags() + if name not in fl: + raise AttributeError(name) + if name in self.__dict__['__hiddenflags']: + raise AttributeError(name) + + if self.__dict__['__flags_parsed'] or fl[name].present: + return fl[name].value + else: + raise _exceptions.UnparsedFlagAccessError( + 'Trying to access flag --%s before flags were parsed.' % name) + + def __setattr__(self, name, value): + """Sets the 'value' attribute of the flag --name.""" + self._set_attributes(**{name: value}) + return value + + def _set_attributes(self, **attributes): + """Sets multiple flag values together, triggers validators afterwards.""" + fl = self._flags() + known_flags = set() + for name, value in attributes.items(): + if name in self.__dict__['__hiddenflags']: + raise AttributeError(name) + if name in fl: + fl[name].value = value + known_flags.add(name) + else: + self._set_unknown_flag(name, value) + for name in known_flags: + self._assert_validators(fl[name].validators) + fl[name].using_default_value = False + + def validate_all_flags(self): + """Verifies whether all flags pass validation. + + Raises: + AttributeError: Raised if validators work with a non-existing flag. + IllegalFlagValueError: Raised if validation fails for at least one + validator. + """ + all_validators = set() + for flag in self._flags().values(): + all_validators.update(flag.validators) + self._assert_validators(all_validators) + + def _assert_validators(self, validators): + """Asserts if all validators in the list are satisfied. + + It asserts validators in the order they were created. + + Args: + validators: Iterable(validators.Validator), validators to be verified. + + Raises: + AttributeError: Raised if validators work with a non-existing flag. + IllegalFlagValueError: Raised if validation fails for at least one + validator. + """ + messages = [] + bad_flags = set() + for validator in sorted( + validators, key=lambda validator: validator.insertion_index): + try: + if isinstance(validator, _validators_classes.SingleFlagValidator): + if validator.flag_name in bad_flags: + continue + elif isinstance(validator, _validators_classes.MultiFlagsValidator): + if bad_flags & set(validator.flag_names): + continue + validator.verify(self) + except _exceptions.ValidationError as e: + if isinstance(validator, _validators_classes.SingleFlagValidator): + bad_flags.add(validator.flag_name) + elif isinstance(validator, _validators_classes.MultiFlagsValidator): + bad_flags.update(set(validator.flag_names)) + message = validator.print_flags_with_values(self) + messages.append('%s: %s' % (message, str(e))) + if messages: + raise _exceptions.IllegalFlagValueError('\n'.join(messages)) + + def __delattr__(self, flag_name): + """Deletes a previously-defined flag from a flag object. + + This method makes sure we can delete a flag by using + + del FLAGS.<flag_name> + + E.g., + + flags.DEFINE_integer('foo', 1, 'Integer flag.') + del flags.FLAGS.foo + + If a flag is also registered by its the other name (long name or short + name), the other name won't be deleted. + + Args: + flag_name: str, the name of the flag to be deleted. + + Raises: + AttributeError: Raised when there is no registered flag named flag_name. + """ + fl = self._flags() + if flag_name not in fl: + raise AttributeError(flag_name) + + flag_obj = fl[flag_name] + del fl[flag_name] + + self._cleanup_unregistered_flag_from_module_dicts(flag_obj) + + def set_default(self, name, value): + """Changes the default value of the named flag object. + + The flag's current value is also updated if the flag is currently using + the default value, i.e. not specified in the command line, and not set + by FLAGS.name = value. + + Args: + name: str, the name of the flag to modify. + value: The new default value. + + Raises: + UnrecognizedFlagError: Raised when there is no registered flag named name. + IllegalFlagValueError: Raised when value is not valid. + """ + fl = self._flags() + if name not in fl: + self._set_unknown_flag(name, value) + return + fl[name]._set_default(value) # pylint: disable=protected-access + self._assert_validators(fl[name].validators) + + def __contains__(self, name): + """Returns True if name is a value (flag) in the dict.""" + return name in self._flags() + + def __len__(self): + return len(self.__dict__['__flags']) + + def __iter__(self): + return iter(self._flags()) + + def __call__(self, argv, known_only=False): + """Parses flags from argv; stores parsed flags into this FlagValues object. + + All unparsed arguments are returned. + + Args: + argv: a tuple/list of strings. + known_only: bool, if True, parse and remove known flags; return the rest + untouched. Unknown flags specified by --undefok are not returned. + + Returns: + The list of arguments not parsed as options, including argv[0]. + + Raises: + Error: Raised on any parsing error. + TypeError: Raised on passing wrong type of arguments. + ValueError: Raised on flag value parsing error. + """ + if _helpers.is_bytes_or_string(argv): + raise TypeError( + 'argv should be a tuple/list of strings, not bytes or string.') + if not argv: + raise ValueError( + 'argv cannot be an empty list, and must contain the program name as ' + 'the first element.') + + # This pre parses the argv list for --flagfile=<> options. + program_name = argv[0] + args = self.read_flags_from_files(argv[1:], force_gnu=False) + + # Parse the arguments. + unknown_flags, unparsed_args = self._parse_args(args, known_only) + + # Handle unknown flags by raising UnrecognizedFlagError. + # Note some users depend on us raising this particular error. + for name, value in unknown_flags: + suggestions = _helpers.get_flag_suggestions(name, list(self)) + raise _exceptions.UnrecognizedFlagError( + name, value, suggestions=suggestions) + + self.mark_as_parsed() + self.validate_all_flags() + return [program_name] + unparsed_args + + def __getstate__(self): + raise TypeError("can't pickle FlagValues") + + def __copy__(self): + raise TypeError('FlagValues does not support shallow copies. ' + 'Use absl.testing.flagsaver or copy.deepcopy instead.') + + def __deepcopy__(self, memo): + result = object.__new__(type(self)) + result.__dict__.update(copy.deepcopy(self.__dict__, memo)) + return result + + def _set_is_retired_flag_func(self, is_retired_flag_func): + """Sets a function for checking retired flags. + + Do not use it. This is a private absl API used to check retired flags + registered by the absl C++ flags library. + + Args: + is_retired_flag_func: Callable(str) -> (bool, bool), a function takes flag + name as parameter, returns a tuple (is_retired, type_is_bool). + """ + self.__dict__['__is_retired_flag_func'] = is_retired_flag_func + + def _parse_args(self, args, known_only): + """Helper function to do the main argument parsing. + + This function goes through args and does the bulk of the flag parsing. + It will find the corresponding flag in our flag dictionary, and call its + .parse() method on the flag value. + + Args: + args: [str], a list of strings with the arguments to parse. + known_only: bool, if True, parse and remove known flags; return the rest + untouched. Unknown flags specified by --undefok are not returned. + + Returns: + A tuple with the following: + unknown_flags: List of (flag name, arg) for flags we don't know about. + unparsed_args: List of arguments we did not parse. + + Raises: + Error: Raised on any parsing error. + ValueError: Raised on flag value parsing error. + """ + unparsed_names_and_args = [] # A list of (flag name or None, arg). + undefok = set() + retired_flag_func = self.__dict__['__is_retired_flag_func'] + + flag_dict = self._flags() + args = iter(args) + for arg in args: + value = None + + def get_value(): + # pylint: disable=cell-var-from-loop + try: + return next(args) if value is None else value + except StopIteration: + raise _exceptions.Error('Missing value for flag ' + arg) # pylint: disable=undefined-loop-variable + + if not arg.startswith('-'): + # A non-argument: default is break, GNU is skip. + unparsed_names_and_args.append((None, arg)) + if self.is_gnu_getopt(): + continue + else: + break + + if arg == '--': + if known_only: + unparsed_names_and_args.append((None, arg)) + break + + # At this point, arg must start with '-'. + if arg.startswith('--'): + arg_without_dashes = arg[2:] + else: + arg_without_dashes = arg[1:] + + if '=' in arg_without_dashes: + name, value = arg_without_dashes.split('=', 1) + else: + name, value = arg_without_dashes, None + + if not name: + # The argument is all dashes (including one dash). + unparsed_names_and_args.append((None, arg)) + if self.is_gnu_getopt(): + continue + else: + break + + # --undefok is a special case. + if name == 'undefok': + value = get_value() + undefok.update(v.strip() for v in value.split(',')) + undefok.update('no' + v.strip() for v in value.split(',')) + continue + + flag = flag_dict.get(name) + if flag is not None: + if flag.boolean and value is None: + value = 'true' + else: + value = get_value() + elif name.startswith('no') and len(name) > 2: + # Boolean flags can take the form of --noflag, with no value. + noflag = flag_dict.get(name[2:]) + if noflag is not None and noflag.boolean: + if value is not None: + raise ValueError(arg + ' does not take an argument') + flag = noflag + value = 'false' + + if retired_flag_func and flag is None: + is_retired, is_bool = retired_flag_func(name) + + # If we didn't recognize that flag, but it starts with + # "no" then maybe it was a boolean flag specified in the + # --nofoo form. + if not is_retired and name.startswith('no'): + is_retired, is_bool = retired_flag_func(name[2:]) + is_retired = is_retired and is_bool + + if is_retired: + if not is_bool and value is None: + # This happens when a non-bool retired flag is specified + # in format of "--flag value". + get_value() + logging.error( + 'Flag "%s" is retired and should no longer ' + 'be specified. See go/totw/90.', name) + continue + + if flag is not None: + flag.parse(value) + flag.using_default_value = False + else: + unparsed_names_and_args.append((name, arg)) + + unknown_flags = [] + unparsed_args = [] + for name, arg in unparsed_names_and_args: + if name is None: + # Positional arguments. + unparsed_args.append(arg) + elif name in undefok: + # Remove undefok flags. + continue + else: + # This is an unknown flag. + if known_only: + unparsed_args.append(arg) + else: + unknown_flags.append((name, arg)) + + unparsed_args.extend(list(args)) + return unknown_flags, unparsed_args + + def is_parsed(self): + """Returns whether flags were parsed.""" + return self.__dict__['__flags_parsed'] + + def mark_as_parsed(self): + """Explicitly marks flags as parsed. + + Use this when the caller knows that this FlagValues has been parsed as if + a __call__() invocation has happened. This is only a public method for + use by things like appcommands which do additional command like parsing. + """ + self.__dict__['__flags_parsed'] = True + + def unparse_flags(self): + """Unparses all flags to the point before any FLAGS(argv) was called.""" + for f in self._flags().values(): + f.unparse() + # We log this message before marking flags as unparsed to avoid a + # problem when the logging library causes flags access. + logging.info('unparse_flags() called; flags access will now raise errors.') + self.__dict__['__flags_parsed'] = False + self.__dict__['__unparse_flags_called'] = True + + def flag_values_dict(self): + """Returns a dictionary that maps flag names to flag values.""" + return {name: flag.value for name, flag in self._flags().items()} + + def __str__(self): + """Returns a help string for all known flags.""" + return self.get_help() + + def get_help(self, prefix='', include_special_flags=True): + """Returns a help string for all known flags. + + Args: + prefix: str, per-line output prefix. + include_special_flags: bool, whether to include description of + SPECIAL_FLAGS, i.e. --flagfile and --undefok. + + Returns: + str, formatted help message. + """ + flags_by_module = self.flags_by_module_dict() + if flags_by_module: + modules = sorted(flags_by_module) + # Print the help for the main module first, if possible. + main_module = sys.argv[0] + if main_module in modules: + modules.remove(main_module) + modules = [main_module] + modules + return self._get_help_for_modules(modules, prefix, include_special_flags) + else: + output_lines = [] + # Just print one long list of flags. + values = self._flags().values() + if include_special_flags: + values = itertools.chain( + values, _helpers.SPECIAL_FLAGS._flags().values()) # pylint: disable=protected-access + self._render_flag_list(values, output_lines, prefix) + return '\n'.join(output_lines) + + def _get_help_for_modules(self, modules, prefix, include_special_flags): + """Returns the help string for a list of modules. + + Private to absl.flags package. + + Args: + modules: List[str], a list of modules to get the help string for. + prefix: str, a string that is prepended to each generated help line. + include_special_flags: bool, whether to include description of + SPECIAL_FLAGS, i.e. --flagfile and --undefok. + """ + output_lines = [] + for module in modules: + self._render_our_module_flags(module, output_lines, prefix) + if include_special_flags: + self._render_module_flags( + 'absl.flags', + _helpers.SPECIAL_FLAGS._flags().values(), # pylint: disable=protected-access + output_lines, + prefix) + return '\n'.join(output_lines) + + def _render_module_flags(self, module, flags, output_lines, prefix=''): + """Returns a help string for a given module.""" + if not isinstance(module, str): + module = module.__name__ + output_lines.append('\n%s%s:' % (prefix, module)) + self._render_flag_list(flags, output_lines, prefix + ' ') + + def _render_our_module_flags(self, module, output_lines, prefix=''): + """Returns a help string for a given module.""" + flags = self.get_flags_for_module(module) + if flags: + self._render_module_flags(module, flags, output_lines, prefix) + + def _render_our_module_key_flags(self, module, output_lines, prefix=''): + """Returns a help string for the key flags of a given module. + + Args: + module: module|str, the module to render key flags for. + output_lines: [str], a list of strings. The generated help message lines + will be appended to this list. + prefix: str, a string that is prepended to each generated help line. + """ + key_flags = self.get_key_flags_for_module(module) + if key_flags: + self._render_module_flags(module, key_flags, output_lines, prefix) + + def module_help(self, module): + """Describes the key flags of a module. + + Args: + module: module|str, the module to describe the key flags for. + + Returns: + str, describing the key flags of a module. + """ + helplist = [] + self._render_our_module_key_flags(module, helplist) + return '\n'.join(helplist) + + def main_module_help(self): + """Describes the key flags of the main module. + + Returns: + str, describing the key flags of the main module. + """ + return self.module_help(sys.argv[0]) + + def _render_flag_list(self, flaglist, output_lines, prefix=' '): + fl = self._flags() + special_fl = _helpers.SPECIAL_FLAGS._flags() # pylint: disable=protected-access + flaglist = [(flag.name, flag) for flag in flaglist] + flaglist.sort() + flagset = {} + for (name, flag) in flaglist: + # It's possible this flag got deleted or overridden since being + # registered in the per-module flaglist. Check now against the + # canonical source of current flag information, the _flags. + if fl.get(name, None) != flag and special_fl.get(name, None) != flag: + # a different flag is using this name now + continue + # only print help once + if flag in flagset: + continue + flagset[flag] = 1 + flaghelp = '' + if flag.short_name: + flaghelp += '-%s,' % flag.short_name + if flag.boolean: + flaghelp += '--[no]%s:' % flag.name + else: + flaghelp += '--%s:' % flag.name + flaghelp += ' ' + if flag.help: + flaghelp += flag.help + flaghelp = _helpers.text_wrap( + flaghelp, indent=prefix + ' ', firstline_indent=prefix) + if flag.default_as_str: + flaghelp += '\n' + flaghelp += _helpers.text_wrap( + '(default: %s)' % flag.default_as_str, indent=prefix + ' ') + if flag.parser.syntactic_help: + flaghelp += '\n' + flaghelp += _helpers.text_wrap( + '(%s)' % flag.parser.syntactic_help, indent=prefix + ' ') + output_lines.append(flaghelp) + + def get_flag_value(self, name, default): # pylint: disable=invalid-name + """Returns the value of a flag (if not None) or a default value. + + Args: + name: str, the name of a flag. + default: Default value to use if the flag value is None. + + Returns: + Requested flag value or default. + """ + + value = self.__getattr__(name) + if value is not None: # Can't do if not value, b/c value might be '0' or "" + return value + else: + return default + + def _is_flag_file_directive(self, flag_string): + """Checks whether flag_string contain a --flagfile=<foo> directive.""" + if isinstance(flag_string, type('')): + if flag_string.startswith('--flagfile='): + return 1 + elif flag_string == '--flagfile': + return 1 + elif flag_string.startswith('-flagfile='): + return 1 + elif flag_string == '-flagfile': + return 1 + else: + return 0 + return 0 + + def _extract_filename(self, flagfile_str): + """Returns filename from a flagfile_str of form -[-]flagfile=filename. + + The cases of --flagfile foo and -flagfile foo shouldn't be hitting + this function, as they are dealt with in the level above this + function. + + Args: + flagfile_str: str, the flagfile string. + + Returns: + str, the filename from a flagfile_str of form -[-]flagfile=filename. + + Raises: + Error: Raised when illegal --flagfile is provided. + """ + if flagfile_str.startswith('--flagfile='): + return os.path.expanduser((flagfile_str[(len('--flagfile=')):]).strip()) + elif flagfile_str.startswith('-flagfile='): + return os.path.expanduser((flagfile_str[(len('-flagfile=')):]).strip()) + else: + raise _exceptions.Error('Hit illegal --flagfile type: %s' % flagfile_str) + + def _get_flag_file_lines(self, filename, parsed_file_stack=None): + """Returns the useful (!=comments, etc) lines from a file with flags. + + Args: + filename: str, the name of the flag file. + parsed_file_stack: [str], a list of the names of the files that we have + recursively encountered at the current depth. MUTATED BY THIS FUNCTION + (but the original value is preserved upon successfully returning from + function call). + + Returns: + List of strings. See the note below. + + NOTE(springer): This function checks for a nested --flagfile=<foo> + tag and handles the lower file recursively. It returns a list of + all the lines that _could_ contain command flags. This is + EVERYTHING except whitespace lines and comments (lines starting + with '#' or '//'). + """ + # For consistency with the cpp version, ignore empty values. + if not filename: + return [] + if parsed_file_stack is None: + parsed_file_stack = [] + # We do a little safety check for reparsing a file we've already encountered + # at a previous depth. + if filename in parsed_file_stack: + sys.stderr.write('Warning: Hit circular flagfile dependency. Ignoring' + ' flagfile: %s\n' % (filename,)) + return [] + else: + parsed_file_stack.append(filename) + + line_list = [] # All line from flagfile. + flag_line_list = [] # Subset of lines w/o comments, blanks, flagfile= tags. + try: + file_obj = open(filename, 'r') + except IOError as e_msg: + raise _exceptions.CantOpenFlagFileError( + 'ERROR:: Unable to open flagfile: %s' % e_msg) + + with file_obj: + line_list = file_obj.readlines() + + # This is where we check each line in the file we just read. + for line in line_list: + if line.isspace(): + pass + # Checks for comment (a line that starts with '#'). + elif line.startswith('#') or line.startswith('//'): + pass + # Checks for a nested "--flagfile=<bar>" flag in the current file. + # If we find one, recursively parse down into that file. + elif self._is_flag_file_directive(line): + sub_filename = self._extract_filename(line) + included_flags = self._get_flag_file_lines( + sub_filename, parsed_file_stack=parsed_file_stack) + flag_line_list.extend(included_flags) + else: + # Any line that's not a comment or a nested flagfile should get + # copied into 2nd position. This leaves earlier arguments + # further back in the list, thus giving them higher priority. + flag_line_list.append(line.strip()) + + parsed_file_stack.pop() + return flag_line_list + + def read_flags_from_files(self, argv, force_gnu=True): + """Processes command line args, but also allow args to be read from file. + + Args: + argv: [str], a list of strings, usually sys.argv[1:], which may contain + one or more flagfile directives of the form --flagfile="./filename". + Note that the name of the program (sys.argv[0]) should be omitted. + force_gnu: bool, if False, --flagfile parsing obeys the + FLAGS.is_gnu_getopt() value. If True, ignore the value and always follow + gnu_getopt semantics. + + Returns: + A new list which has the original list combined with what we read + from any flagfile(s). + + Raises: + IllegalFlagValueError: Raised when --flagfile is provided with no + argument. + + This function is called by FLAGS(argv). + It scans the input list for a flag that looks like: + --flagfile=<somefile>. Then it opens <somefile>, reads all valid key + and value pairs and inserts them into the input list in exactly the + place where the --flagfile arg is found. + + Note that your application's flags are still defined the usual way + using absl.flags DEFINE_flag() type functions. + + Notes (assuming we're getting a commandline of some sort as our input): + --> For duplicate flags, the last one we hit should "win". + --> Since flags that appear later win, a flagfile's settings can be "weak" + if the --flagfile comes at the beginning of the argument sequence, + and it can be "strong" if the --flagfile comes at the end. + --> A further "--flagfile=<otherfile.cfg>" CAN be nested in a flagfile. + It will be expanded in exactly the spot where it is found. + --> In a flagfile, a line beginning with # or // is a comment. + --> Entirely blank lines _should_ be ignored. + """ + rest_of_args = argv + new_argv = [] + while rest_of_args: + current_arg = rest_of_args[0] + rest_of_args = rest_of_args[1:] + if self._is_flag_file_directive(current_arg): + # This handles the case of -(-)flagfile foo. In this case the + # next arg really is part of this one. + if current_arg == '--flagfile' or current_arg == '-flagfile': + if not rest_of_args: + raise _exceptions.IllegalFlagValueError( + '--flagfile with no argument') + flag_filename = os.path.expanduser(rest_of_args[0]) + rest_of_args = rest_of_args[1:] + else: + # This handles the case of (-)-flagfile=foo. + flag_filename = self._extract_filename(current_arg) + new_argv.extend(self._get_flag_file_lines(flag_filename)) + else: + new_argv.append(current_arg) + # Stop parsing after '--', like getopt and gnu_getopt. + if current_arg == '--': + break + # Stop parsing after a non-flag, like getopt. + if not current_arg.startswith('-'): + if not force_gnu and not self.__dict__['__use_gnu_getopt']: + break + else: + if ('=' not in current_arg and rest_of_args and + not rest_of_args[0].startswith('-')): + # If this is an occurrence of a legitimate --x y, skip the value + # so that it won't be mistaken for a standalone arg. + fl = self._flags() + name = current_arg.lstrip('-') + if name in fl and not fl[name].boolean: + current_arg = rest_of_args[0] + rest_of_args = rest_of_args[1:] + new_argv.append(current_arg) + + if rest_of_args: + new_argv.extend(rest_of_args) + + return new_argv + + def flags_into_string(self): + """Returns a string with the flags assignments from this FlagValues object. + + This function ignores flags whose value is None. Each flag + assignment is separated by a newline. + + NOTE: MUST mirror the behavior of the C++ CommandlineFlagsIntoString + from https://github.com/gflags/gflags. + + Returns: + str, the string with the flags assignments from this FlagValues object. + The flags are ordered by (module_name, flag_name). + """ + module_flags = sorted(self.flags_by_module_dict().items()) + s = '' + for unused_module_name, flags in module_flags: + flags = sorted(flags, key=lambda f: f.name) + for flag in flags: + if flag.value is not None: + s += flag.serialize() + '\n' + return s + + def append_flags_into_file(self, filename): + """Appends all flags assignments from this FlagInfo object to a file. + + Output will be in the format of a flagfile. + + NOTE: MUST mirror the behavior of the C++ AppendFlagsIntoFile + from https://github.com/gflags/gflags. + + Args: + filename: str, name of the file. + """ + with open(filename, 'a') as out_file: + out_file.write(self.flags_into_string()) + + def write_help_in_xml_format(self, outfile=None): + """Outputs flag documentation in XML format. + + NOTE: We use element names that are consistent with those used by + the C++ command-line flag library, from + https://github.com/gflags/gflags. + We also use a few new elements (e.g., <key>), but we do not + interfere / overlap with existing XML elements used by the C++ + library. Please maintain this consistency. + + Args: + outfile: File object we write to. Default None means sys.stdout. + """ + doc = minidom.Document() + all_flag = doc.createElement('AllFlags') + doc.appendChild(all_flag) + + all_flag.appendChild( + _helpers.create_xml_dom_element(doc, 'program', + os.path.basename(sys.argv[0]))) + + usage_doc = sys.modules['__main__'].__doc__ + if not usage_doc: + usage_doc = '\nUSAGE: %s [flags]\n' % sys.argv[0] + else: + usage_doc = usage_doc.replace('%s', sys.argv[0]) + all_flag.appendChild( + _helpers.create_xml_dom_element(doc, 'usage', usage_doc)) + + # Get list of key flags for the main module. + key_flags = self.get_key_flags_for_module(sys.argv[0]) + + # Sort flags by declaring module name and next by flag name. + flags_by_module = self.flags_by_module_dict() + all_module_names = list(flags_by_module.keys()) + all_module_names.sort() + for module_name in all_module_names: + flag_list = [(f.name, f) for f in flags_by_module[module_name]] + flag_list.sort() + for unused_flag_name, flag in flag_list: + is_key = flag in key_flags + all_flag.appendChild( + flag._create_xml_dom_element( # pylint: disable=protected-access + doc, + module_name, + is_key=is_key)) + + outfile = outfile or sys.stdout + outfile.write( + doc.toprettyxml(indent=' ', encoding='utf-8').decode('utf-8')) + outfile.flush() + + def _check_method_name_conflicts(self, name, flag): + if flag.allow_using_method_names: + return + short_name = flag.short_name + flag_names = {name} if short_name is None else {name, short_name} + for flag_name in flag_names: + if flag_name in self.__dict__['__banned_flag_names']: + raise _exceptions.FlagNameConflictsWithMethodError( + 'Cannot define a flag named "{name}". It conflicts with a method ' + 'on class "{class_name}". To allow defining it, use ' + 'allow_using_method_names and access the flag value with ' + "FLAGS['{name}'].value. FLAGS.{name} returns the method, " + 'not the flag value.'.format( + name=flag_name, class_name=type(self).__name__)) + + +FLAGS = FlagValues() + + +class FlagHolder(Generic[_T]): + """Holds a defined flag. + + This facilitates a cleaner api around global state. Instead of + + ``` + flags.DEFINE_integer('foo', ...) + flags.DEFINE_integer('bar', ...) + ... + def method(): + # prints parsed value of 'bar' flag + print(flags.FLAGS.foo) + # runtime error due to typo or possibly bad coding style. + print(flags.FLAGS.baz) + ``` + + it encourages code like + + ``` + FOO_FLAG = flags.DEFINE_integer('foo', ...) + BAR_FLAG = flags.DEFINE_integer('bar', ...) + ... + def method(): + print(FOO_FLAG.value) + print(BAR_FLAG.value) + ``` + + since the name of the flag appears only once in the source code. + """ + + def __init__(self, flag_values, flag, ensure_non_none_value=False): + """Constructs a FlagHolder instance providing typesafe access to flag. + + Args: + flag_values: The container the flag is registered to. + flag: The flag object for this flag. + ensure_non_none_value: Is the value of the flag allowed to be None. + """ + self._flagvalues = flag_values + # We take the entire flag object, but only keep the name. Why? + # - We want FlagHolder[T] to be generic container + # - flag_values contains all flags, so has no reference to T. + # - typecheckers don't like to see a generic class where none of the ctor + # arguments refer to the generic type. + self._name = flag.name + # We intentionally do NOT check if the default value is None. + # This allows future use of this for "required flags with None default" + self._ensure_non_none_value = ensure_non_none_value + + def __eq__(self, other): + raise TypeError( + "unsupported operand type(s) for ==: '{0}' and '{1}' " + "(did you mean to use '{0}.value' instead?)".format( + type(self).__name__, type(other).__name__)) + + def __bool__(self): + raise TypeError( + "bool() not supported for instances of type '{0}' " + "(did you mean to use '{0}.value' instead?)".format( + type(self).__name__)) + + __nonzero__ = __bool__ + + @property + def name(self): + return self._name + + @property + def value(self): + """Returns the value of the flag. + + If _ensure_non_none_value is True, then return value is not None. + + Raises: + UnparsedFlagAccessError: if flag parsing has not finished. + IllegalFlagValueError: if value is None unexpectedly. + """ + val = getattr(self._flagvalues, self._name) + if self._ensure_non_none_value and val is None: + raise _exceptions.IllegalFlagValueError( + 'Unexpected None value for flag %s' % self._name) + return val + + @property + def default(self): + """Returns the default value of the flag.""" + return self._flagvalues[self._name].default + + @property + def present(self): + """Returns True if the flag was parsed from command-line flags.""" + return bool(self._flagvalues[self._name].present) diff --git a/absl/flags/_flagvalues.pyi b/absl/flags/_flagvalues.pyi new file mode 100644 index 0000000..e25c6dd --- /dev/null +++ b/absl/flags/_flagvalues.pyi @@ -0,0 +1,148 @@ +# Copyright 2020 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines type annotations for _flagvalues.""" + + +from absl.flags import _flag + +from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, Sequence, Text, Type, TypeVar + + +class FlagValues: + + def __getitem__(self, name: Text) -> _flag.Flag: ... + + def __setitem__(self, name: Text, flag: _flag.Flag) -> None: ... + + def __getattr__(self, name: Text) -> Any: ... + + def __setattr__(self, name: Text, value: Any) -> Any: ... + + def __call__( + self, + argv: Sequence[Text], + known_only: bool = ..., + ) -> List[Text]: ... + + def __contains__(self, name: Text) -> bool: ... + + def __copy__(self) -> Any: ... + + def __deepcopy__(self, memo) -> Any: ... + + def __delattr__(self, flag_name: Text) -> None: ... + + def __dir__(self) -> List[Text]: ... + + def __getstate__(self) -> Any: ... + + def __iter__(self) -> Iterator[Text]: ... + + def __len__(self) -> int: ... + + def get_help(self, + prefix: Text = ..., + include_special_flags: bool = ...) -> Text: + ... + + + def set_gnu_getopt(self, gnu_getopt: bool = ...) -> None: ... + + def is_gnu_getopt(self) -> bool: ... + + def flags_by_module_dict(self) -> Dict[Text, List[_flag.Flag]]: ... + + def flags_by_module_id_dict(self) -> Dict[Text, List[_flag.Flag]]: ... + + def key_flags_by_module_dict(self) -> Dict[Text, List[_flag.Flag]]: ... + + def register_flag_by_module( + self, module_name: Text, flag: _flag.Flag) -> None: ... + + def register_flag_by_module_id( + self, module_id: Text, flag: _flag.Flag) -> None: ... + + def register_key_flag_for_module( + self, module_name: Text, flag: _flag.Flag) -> None: ... + + def get_key_flags_for_module(self, module: Any) -> List[_flag.Flag]: ... + + def find_module_defining_flag( + self, flagname: Text, default: Any = ...) -> Any: + ... + + def find_module_id_defining_flag( + self, flagname: Text, default: Any = ...) -> Any: + ... + + def append_flag_values(self, flag_values: Any) -> None: ... + + def remove_flag_values(self, flag_values: Any) -> None: ... + + def validate_all_flags(self) -> None: ... + + def set_default(self, name: Text, value: Any) -> None: ... + + def is_parsed(self) -> bool: ... + + def mark_as_parsed(self) -> None: ... + + def unparse_flags(self) -> None: ... + + def flag_values_dict(self) -> Dict[Text, Any]: ... + + def module_help(self, module: Any) -> Text: ... + + def main_module_help(self) -> Text: ... + + def get_flag_value(self, name: Text, default: Any) -> Any: ... + + def read_flags_from_files( + self, argv: List[Text], force_gnu: bool = ...) -> List[Text]: ... + + def flags_into_string(self) -> Text: ... + + def append_flags_into_file(self, filename: Text) -> None:... + + # outfile is Optional[fileobject] + def write_help_in_xml_format(self, outfile: Any = ...) -> None: ... + + +FLAGS = ... # type: FlagValues + + +_T = TypeVar('_T') # The type of parsed default value of the flag. + +# We assume that default and value are guaranteed to have the same type. +class FlagHolder(Generic[_T]): + def __init__( + self, + flag_values: FlagValues, + # NOTE: Use Flag instead of Flag[T] is used to work around some superficial + # differences between Flag and FlagHolder typing. + flag: _flag.Flag, + ensure_non_none_value: bool=False) -> None: ... + + @property + def name(self) -> Text: ... + + @property + def value(self) -> _T: ... + + @property + def default(self) -> _T: ... + + @property + def present(self) -> bool: ... diff --git a/absl/flags/_helpers.py b/absl/flags/_helpers.py new file mode 100644 index 0000000..37ae360 --- /dev/null +++ b/absl/flags/_helpers.py @@ -0,0 +1,433 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal helper functions for Abseil Python flags library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import re +import struct +import sys +import textwrap +try: + import fcntl +except ImportError: + fcntl = None +try: + # Importing termios will fail on non-unix platforms. + import termios +except ImportError: + termios = None + + +_DEFAULT_HELP_WIDTH = 80 # Default width of help output. +_MIN_HELP_WIDTH = 40 # Minimal "sane" width of help output. We assume that any + # value below 40 is unreasonable. + +# Define the allowed error rate in an input string to get suggestions. +# +# We lean towards a high threshold because we tend to be matching a phrase, +# and the simple algorithm used here is geared towards correcting word +# spellings. +# +# For manual testing, consider "<command> --list" which produced a large number +# of spurious suggestions when we used "least_errors > 0.5" instead of +# "least_erros >= 0.5". +_SUGGESTION_ERROR_RATE_THRESHOLD = 0.50 + +# Characters that cannot appear or are highly discouraged in an XML 1.0 +# document. (See http://www.w3.org/TR/REC-xml/#charsets or +# https://en.wikipedia.org/wiki/Valid_characters_in_XML#XML_1.0) +_ILLEGAL_XML_CHARS_REGEX = re.compile( + u'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x84\x86-\x9f\ud800-\udfff\ufffe\uffff]') + +# This is a set of module ids for the modules that disclaim key flags. +# This module is explicitly added to this set so that we never consider it to +# define key flag. +disclaim_module_ids = set([id(sys.modules[__name__])]) + + +# Define special flags here so that help may be generated for them. +# NOTE: Please do NOT use SPECIAL_FLAGS from outside flags module. +# Initialized inside flagvalues.py. +SPECIAL_FLAGS = None + + +# This points to the flags module, initialized in flags/__init__.py. +# This should only be used in adopt_module_key_flags to take SPECIAL_FLAGS into +# account. +FLAGS_MODULE = None + + +class _ModuleObjectAndName( + collections.namedtuple('_ModuleObjectAndName', 'module module_name')): + """Module object and name. + + Fields: + - module: object, module object. + - module_name: str, module name. + """ + + +def get_module_object_and_name(globals_dict): + """Returns the module that defines a global environment, and its name. + + Args: + globals_dict: A dictionary that should correspond to an environment + providing the values of the globals. + + Returns: + _ModuleObjectAndName - pair of module object & module name. + Returns (None, None) if the module could not be identified. + """ + name = globals_dict.get('__name__', None) + module = sys.modules.get(name, None) + # Pick a more informative name for the main module. + return _ModuleObjectAndName(module, + (sys.argv[0] if name == '__main__' else name)) + + +def get_calling_module_object_and_name(): + """Returns the module that's calling into this module. + + We generally use this function to get the name of the module calling a + DEFINE_foo... function. + + Returns: + The module object that called into this one. + + Raises: + AssertionError: Raised when no calling module could be identified. + """ + for depth in range(1, sys.getrecursionlimit()): + # sys._getframe is the right thing to use here, as it's the best + # way to walk up the call stack. + globals_for_frame = sys._getframe(depth).f_globals # pylint: disable=protected-access + module, module_name = get_module_object_and_name(globals_for_frame) + if id(module) not in disclaim_module_ids and module_name is not None: + return _ModuleObjectAndName(module, module_name) + raise AssertionError('No module was found') + + +def get_calling_module(): + """Returns the name of the module that's calling into this module.""" + return get_calling_module_object_and_name().module_name + + +def str_or_unicode(value): + """Converts a value to a python string. + + Behavior of this function is intentionally different in Python2/3. + + In Python2, the given value is attempted to convert to a str (byte string). + If it contains non-ASCII characters, it is converted to a unicode instead. + + In Python3, the given value is always converted to a str (unicode string). + + This behavior reflects the (bad) practice in Python2 to try to represent + a string as str as long as it contains ASCII characters only. + + Args: + value: An object to be converted to a string. + + Returns: + A string representation of the given value. See the description above + for its type. + """ + try: + return str(value) + except UnicodeEncodeError: + return unicode(value) # Python3 should never come here + + +def create_xml_dom_element(doc, name, value): + """Returns an XML DOM element with name and text value. + + Args: + doc: minidom.Document, the DOM document it should create nodes from. + name: str, the tag of XML element. + value: object, whose string representation will be used + as the value of the XML element. Illegal or highly discouraged xml 1.0 + characters are stripped. + + Returns: + An instance of minidom.Element. + """ + s = str_or_unicode(value) + if isinstance(value, bool): + # Display boolean values as the C++ flag library does: no caps. + s = s.lower() + # Remove illegal xml characters. + s = _ILLEGAL_XML_CHARS_REGEX.sub(u'', s) + + e = doc.createElement(name) + e.appendChild(doc.createTextNode(s)) + return e + + +def get_help_width(): + """Returns the integer width of help lines that is used in TextWrap.""" + if not sys.stdout.isatty() or termios is None or fcntl is None: + return _DEFAULT_HELP_WIDTH + try: + data = fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, '1234') + columns = struct.unpack('hh', data)[1] + # Emacs mode returns 0. + # Here we assume that any value below 40 is unreasonable. + if columns >= _MIN_HELP_WIDTH: + return columns + # Returning an int as default is fine, int(int) just return the int. + return int(os.getenv('COLUMNS', _DEFAULT_HELP_WIDTH)) + + except (TypeError, IOError, struct.error): + return _DEFAULT_HELP_WIDTH + + +def get_flag_suggestions(attempt, longopt_list): + """Returns helpful similar matches for an invalid flag.""" + # Don't suggest on very short strings, or if no longopts are specified. + if len(attempt) <= 2 or not longopt_list: + return [] + + option_names = [v.split('=')[0] for v in longopt_list] + + # Find close approximations in flag prefixes. + # This also handles the case where the flag is spelled right but ambiguous. + distances = [(_damerau_levenshtein(attempt, option[0:len(attempt)]), option) + for option in option_names] + # t[0] is distance, and sorting by t[1] allows us to have stable output. + distances.sort() + + least_errors, _ = distances[0] + # Don't suggest excessively bad matches. + if least_errors >= _SUGGESTION_ERROR_RATE_THRESHOLD * len(attempt): + return [] + + suggestions = [] + for errors, name in distances: + if errors == least_errors: + suggestions.append(name) + else: + break + return suggestions + + +def _damerau_levenshtein(a, b): + """Returns Damerau-Levenshtein edit distance from a to b.""" + memo = {} + + def distance(x, y): + """Recursively defined string distance with memoization.""" + if (x, y) in memo: + return memo[x, y] + if not x: + d = len(y) + elif not y: + d = len(x) + else: + d = min( + distance(x[1:], y) + 1, # correct an insertion error + distance(x, y[1:]) + 1, # correct a deletion error + distance(x[1:], y[1:]) + (x[0] != y[0])) # correct a wrong character + if len(x) >= 2 and len(y) >= 2 and x[0] == y[1] and x[1] == y[0]: + # Correct a transposition. + t = distance(x[2:], y[2:]) + 1 + if d > t: + d = t + + memo[x, y] = d + return d + return distance(a, b) + + +def text_wrap(text, length=None, indent='', firstline_indent=None): + """Wraps a given text to a maximum line length and returns it. + + It turns lines that only contain whitespace into empty lines, keeps new lines, + and expands tabs using 4 spaces. + + Args: + text: str, text to wrap. + length: int, maximum length of a line, includes indentation. + If this is None then use get_help_width() + indent: str, indent for all but first line. + firstline_indent: str, indent for first line; if None, fall back to indent. + + Returns: + str, the wrapped text. + + Raises: + ValueError: Raised if indent or firstline_indent not shorter than length. + """ + # Get defaults where callee used None + if length is None: + length = get_help_width() + if indent is None: + indent = '' + if firstline_indent is None: + firstline_indent = indent + + if len(indent) >= length: + raise ValueError('Length of indent exceeds length') + if len(firstline_indent) >= length: + raise ValueError('Length of first line indent exceeds length') + + text = text.expandtabs(4) + + result = [] + # Create one wrapper for the first paragraph and one for subsequent + # paragraphs that does not have the initial wrapping. + wrapper = textwrap.TextWrapper( + width=length, initial_indent=firstline_indent, subsequent_indent=indent) + subsequent_wrapper = textwrap.TextWrapper( + width=length, initial_indent=indent, subsequent_indent=indent) + + # textwrap does not have any special treatment for newlines. From the docs: + # "...newlines may appear in the middle of a line and cause strange output. + # For this reason, text should be split into paragraphs (using + # str.splitlines() or similar) which are wrapped separately." + for paragraph in (p.strip() for p in text.splitlines()): + if paragraph: + result.extend(wrapper.wrap(paragraph)) + else: + result.append('') # Keep empty lines. + # Replace initial wrapper with wrapper for subsequent paragraphs. + wrapper = subsequent_wrapper + + return '\n'.join(result) + + +def flag_dict_to_args(flag_map, multi_flags=None): + """Convert a dict of values into process call parameters. + + This method is used to convert a dictionary into a sequence of parameters + for a binary that parses arguments using this module. + + Args: + flag_map: dict, a mapping where the keys are flag names (strings). + values are treated according to their type: + * If value is None, then only the name is emitted. + * If value is True, then only the name is emitted. + * If value is False, then only the name prepended with 'no' is emitted. + * If value is a string then --name=value is emitted. + * If value is a collection, this will emit --name=value1,value2,value3, + unless the flag name is in multi_flags, in which case this will emit + --name=value1 --name=value2 --name=value3. + * Everything else is converted to string an passed as such. + multi_flags: set, names (strings) of flags that should be treated as + multi-flags. + Yields: + sequence of string suitable for a subprocess execution. + """ + for key, value in flag_map.items(): + if value is None: + yield '--%s' % key + elif isinstance(value, bool): + if value: + yield '--%s' % key + else: + yield '--no%s' % key + elif isinstance(value, (bytes, type(u''))): + # We don't want strings to be handled like python collections. + yield '--%s=%s' % (key, value) + else: + # Now we attempt to deal with collections. + try: + if multi_flags and key in multi_flags: + for item in value: + yield '--%s=%s' % (key, str(item)) + else: + yield '--%s=%s' % (key, ','.join(str(item) for item in value)) + except TypeError: + # Default case. + yield '--%s=%s' % (key, value) + + +def trim_docstring(docstring): + """Removes indentation from triple-quoted strings. + + This is the function specified in PEP 257 to handle docstrings: + https://www.python.org/dev/peps/pep-0257/. + + Args: + docstring: str, a python docstring. + + Returns: + str, docstring with indentation removed. + """ + if not docstring: + return '' + + # If you've got a line longer than this you have other problems... + max_indent = 1 << 29 + + # Convert tabs to spaces (following the normal Python rules) + # and split into a list of lines: + lines = docstring.expandtabs().splitlines() + + # Determine minimum indentation (first line doesn't count): + indent = max_indent + for line in lines[1:]: + stripped = line.lstrip() + if stripped: + indent = min(indent, len(line) - len(stripped)) + # Remove indentation (first line is special): + trimmed = [lines[0].strip()] + if indent < max_indent: + for line in lines[1:]: + trimmed.append(line[indent:].rstrip()) + # Strip off trailing and leading blank lines: + while trimmed and not trimmed[-1]: + trimmed.pop() + while trimmed and not trimmed[0]: + trimmed.pop(0) + # Return a single string: + return '\n'.join(trimmed) + + +def doc_to_help(doc): + """Takes a __doc__ string and reformats it as help.""" + + # Get rid of starting and ending white space. Using lstrip() or even + # strip() could drop more than maximum of first line and right space + # of last line. + doc = doc.strip() + + # Get rid of all empty lines. + whitespace_only_line = re.compile('^[ \t]+$', re.M) + doc = whitespace_only_line.sub('', doc) + + # Cut out common space at line beginnings. + doc = trim_docstring(doc) + + # Just like this module's comment, comments tend to be aligned somehow. + # In other words they all start with the same amount of white space. + # 1) keep double new lines; + # 2) keep ws after new lines if not empty line; + # 3) all other new lines shall be changed to a space; + # Solution: Match new lines between non white space and replace with space. + doc = re.sub(r'(?<=\S)\n(?=\S)', ' ', doc, flags=re.M) + + return doc + + +def is_bytes_or_string(maybe_string): + if str is bytes: + return isinstance(maybe_string, basestring) + else: + return isinstance(maybe_string, (str, bytes)) diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py new file mode 100644 index 0000000..af66050 --- /dev/null +++ b/absl/flags/_validators.py @@ -0,0 +1,313 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module to enforce different constraints on flags. + +Flags validators can be registered using following functions / decorators: + flags.register_validator + @flags.validator + flags.register_multi_flags_validator + @flags.multi_flags_validator + +Three convenience functions are also provided for common flag constraints: + flags.mark_flag_as_required + flags.mark_flags_as_required + flags.mark_flags_as_mutual_exclusive + flags.mark_bool_flags_as_mutual_exclusive + +See their docstring in this module for a usage manual. + +Do NOT import this module directly. Import the flags package and use the +aliases defined at the package level instead. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +from absl.flags import _exceptions +from absl.flags import _flagvalues +from absl.flags import _validators_classes + + +def register_validator(flag_name, + checker, + message='Flag validation failed', + flag_values=_flagvalues.FLAGS): + """Adds a constraint, which will be enforced during program execution. + + The constraint is validated when flags are initially parsed, and after each + change of the corresponding flag's value. + Args: + flag_name: str, name of the flag to be checked. + checker: callable, a function to validate the flag. + input - A single positional argument: The value of the corresponding + flag (string, boolean, etc. This value will be passed to checker + by the library). + output - bool, True if validator constraint is satisfied. + If constraint is not satisfied, it should either return False or + raise flags.ValidationError(desired_error_message). + message: str, error text to be shown to the user if checker returns False. + If checker raises flags.ValidationError, message from the raised + error will be shown. + flag_values: flags.FlagValues, optional FlagValues instance to validate + against. + Raises: + AttributeError: Raised when flag_name is not registered as a valid flag + name. + """ + v = _validators_classes.SingleFlagValidator(flag_name, checker, message) + _add_validator(flag_values, v) + + +def validator(flag_name, message='Flag validation failed', + flag_values=_flagvalues.FLAGS): + """A function decorator for defining a flag validator. + + Registers the decorated function as a validator for flag_name, e.g. + + @flags.validator('foo') + def _CheckFoo(foo): + ... + + See register_validator() for the specification of checker function. + + Args: + flag_name: str, name of the flag to be checked. + message: str, error text to be shown to the user if checker returns False. + If checker raises flags.ValidationError, message from the raised + error will be shown. + flag_values: flags.FlagValues, optional FlagValues instance to validate + against. + Returns: + A function decorator that registers its function argument as a validator. + Raises: + AttributeError: Raised when flag_name is not registered as a valid flag + name. + """ + + def decorate(function): + register_validator(flag_name, function, + message=message, + flag_values=flag_values) + return function + return decorate + + +def register_multi_flags_validator(flag_names, + multi_flags_checker, + message='Flags validation failed', + flag_values=_flagvalues.FLAGS): + """Adds a constraint to multiple flags. + + The constraint is validated when flags are initially parsed, and after each + change of the corresponding flag's value. + + Args: + flag_names: [str], a list of the flag names to be checked. + multi_flags_checker: callable, a function to validate the flag. + input - dict, with keys() being flag_names, and value for each key + being the value of the corresponding flag (string, boolean, etc). + output - bool, True if validator constraint is satisfied. + If constraint is not satisfied, it should either return False or + raise flags.ValidationError. + message: str, error text to be shown to the user if checker returns False. + If checker raises flags.ValidationError, message from the raised + error will be shown. + flag_values: flags.FlagValues, optional FlagValues instance to validate + against. + + Raises: + AttributeError: Raised when a flag is not registered as a valid flag name. + """ + v = _validators_classes.MultiFlagsValidator( + flag_names, multi_flags_checker, message) + _add_validator(flag_values, v) + + +def multi_flags_validator(flag_names, + message='Flag validation failed', + flag_values=_flagvalues.FLAGS): + """A function decorator for defining a multi-flag validator. + + Registers the decorated function as a validator for flag_names, e.g. + + @flags.multi_flags_validator(['foo', 'bar']) + def _CheckFooBar(flags_dict): + ... + + See register_multi_flags_validator() for the specification of checker + function. + + Args: + flag_names: [str], a list of the flag names to be checked. + message: str, error text to be shown to the user if checker returns False. + If checker raises flags.ValidationError, message from the raised + error will be shown. + flag_values: flags.FlagValues, optional FlagValues instance to validate + against. + + Returns: + A function decorator that registers its function argument as a validator. + + Raises: + AttributeError: Raised when a flag is not registered as a valid flag name. + """ + + def decorate(function): + register_multi_flags_validator(flag_names, + function, + message=message, + flag_values=flag_values) + return function + + return decorate + + +def mark_flag_as_required(flag_name, flag_values=_flagvalues.FLAGS): + """Ensures that flag is not None during program execution. + + Registers a flag validator, which will follow usual validator rules. + Important note: validator will pass for any non-None value, such as False, + 0 (zero), '' (empty string) and so on. + + If your module might be imported by others, and you only wish to make the flag + required when the module is directly executed, call this method like this: + + if __name__ == '__main__': + flags.mark_flag_as_required('your_flag_name') + app.run() + + Args: + flag_name: str, name of the flag + flag_values: flags.FlagValues, optional FlagValues instance where the flag + is defined. + Raises: + AttributeError: Raised when flag_name is not registered as a valid flag + name. + """ + if flag_values[flag_name].default is not None: + warnings.warn( + 'Flag --%s has a non-None default value; therefore, ' + 'mark_flag_as_required will pass even if flag is not specified in the ' + 'command line!' % flag_name, + stacklevel=2) + register_validator( + flag_name, + lambda value: value is not None, + message='Flag --{} must have a value other than None.'.format(flag_name), + flag_values=flag_values) + + +def mark_flags_as_required(flag_names, flag_values=_flagvalues.FLAGS): + """Ensures that flags are not None during program execution. + + If your module might be imported by others, and you only wish to make the flag + required when the module is directly executed, call this method like this: + + if __name__ == '__main__': + flags.mark_flags_as_required(['flag1', 'flag2', 'flag3']) + app.run() + + Args: + flag_names: Sequence[str], names of the flags. + flag_values: flags.FlagValues, optional FlagValues instance where the flags + are defined. + Raises: + AttributeError: If any of flag name has not already been defined as a flag. + """ + for flag_name in flag_names: + mark_flag_as_required(flag_name, flag_values) + + +def mark_flags_as_mutual_exclusive(flag_names, required=False, + flag_values=_flagvalues.FLAGS): + """Ensures that only one flag among flag_names is not None. + + Important note: This validator checks if flag values are None, and it does not + distinguish between default and explicit values. Therefore, this validator + does not make sense when applied to flags with default values other than None, + including other false values (e.g. False, 0, '', []). That includes multi + flags with a default value of [] instead of None. + + Args: + flag_names: [str], names of the flags. + required: bool. If true, exactly one of the flags must have a value other + than None. Otherwise, at most one of the flags can have a value other + than None, and it is valid for all of the flags to be None. + flag_values: flags.FlagValues, optional FlagValues instance where the flags + are defined. + """ + for flag_name in flag_names: + if flag_values[flag_name].default is not None: + warnings.warn( + 'Flag --{} has a non-None default value. That does not make sense ' + 'with mark_flags_as_mutual_exclusive, which checks whether the ' + 'listed flags have a value other than None.'.format(flag_name), + stacklevel=2) + + def validate_mutual_exclusion(flags_dict): + flag_count = sum(1 for val in flags_dict.values() if val is not None) + if flag_count == 1 or (not required and flag_count == 0): + return True + raise _exceptions.ValidationError( + '{} one of ({}) must have a value other than None.'.format( + 'Exactly' if required else 'At most', ', '.join(flag_names))) + + register_multi_flags_validator( + flag_names, validate_mutual_exclusion, flag_values=flag_values) + + +def mark_bool_flags_as_mutual_exclusive(flag_names, required=False, + flag_values=_flagvalues.FLAGS): + """Ensures that only one flag among flag_names is True. + + Args: + flag_names: [str], names of the flags. + required: bool. If true, exactly one flag must be True. Otherwise, at most + one flag can be True, and it is valid for all flags to be False. + flag_values: flags.FlagValues, optional FlagValues instance where the flags + are defined. + """ + for flag_name in flag_names: + if not flag_values[flag_name].boolean: + raise _exceptions.ValidationError( + 'Flag --{} is not Boolean, which is required for flags used in ' + 'mark_bool_flags_as_mutual_exclusive.'.format(flag_name)) + + def validate_boolean_mutual_exclusion(flags_dict): + flag_count = sum(bool(val) for val in flags_dict.values()) + if flag_count == 1 or (not required and flag_count == 0): + return True + raise _exceptions.ValidationError( + '{} one of ({}) must be True.'.format( + 'Exactly' if required else 'At most', ', '.join(flag_names))) + + register_multi_flags_validator( + flag_names, validate_boolean_mutual_exclusion, flag_values=flag_values) + + +def _add_validator(fv, validator_instance): + """Register new flags validator to be checked. + + Args: + fv: flags.FlagValues, the FlagValues instance to add the validator. + validator_instance: validators.Validator, the validator to add. + Raises: + KeyError: Raised when validators work with a non-existing flag. + """ + for flag_name in validator_instance.get_flags_names(): + fv[flag_name].validators.append(validator_instance) diff --git a/absl/flags/_validators_classes.py b/absl/flags/_validators_classes.py new file mode 100644 index 0000000..d8996e0 --- /dev/null +++ b/absl/flags/_validators_classes.py @@ -0,0 +1,176 @@ +# Copyright 2021 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines *private* classes used for flag validators. + +Do NOT import this module. DO NOT use anything from this module. They are +private APIs. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.flags import _exceptions + + +class Validator(object): + """Base class for flags validators. + + Users should NOT overload these classes, and use flags.Register... + methods instead. + """ + + # Used to assign each validator an unique insertion_index + validators_count = 0 + + def __init__(self, checker, message): + """Constructor to create all validators. + + Args: + checker: function to verify the constraint. + Input of this method varies, see SingleFlagValidator and + multi_flags_validator for a detailed description. + message: str, error message to be shown to the user. + """ + self.checker = checker + self.message = message + Validator.validators_count += 1 + # Used to assert validators in the order they were registered. + self.insertion_index = Validator.validators_count + + def verify(self, flag_values): + """Verifies that constraint is satisfied. + + flags library calls this method to verify Validator's constraint. + + Args: + flag_values: flags.FlagValues, the FlagValues instance to get flags from. + Raises: + Error: Raised if constraint is not satisfied. + """ + param = self._get_input_to_checker_function(flag_values) + if not self.checker(param): + raise _exceptions.ValidationError(self.message) + + def get_flags_names(self): + """Returns the names of the flags checked by this validator. + + Returns: + [string], names of the flags. + """ + raise NotImplementedError('This method should be overloaded') + + def print_flags_with_values(self, flag_values): + raise NotImplementedError('This method should be overloaded') + + def _get_input_to_checker_function(self, flag_values): + """Given flag values, returns the input to be given to checker. + + Args: + flag_values: flags.FlagValues, containing all flags. + Returns: + The input to be given to checker. The return type depends on the specific + validator. + """ + raise NotImplementedError('This method should be overloaded') + + +class SingleFlagValidator(Validator): + """Validator behind register_validator() method. + + Validates that a single flag passes its checker function. The checker function + takes the flag value and returns True (if value looks fine) or, if flag value + is not valid, either returns False or raises an Exception. + """ + + def __init__(self, flag_name, checker, message): + """Constructor. + + Args: + flag_name: string, name of the flag. + checker: function to verify the validator. + input - value of the corresponding flag (string, boolean, etc). + output - bool, True if validator constraint is satisfied. + If constraint is not satisfied, it should either return False or + raise flags.ValidationError(desired_error_message). + message: str, error message to be shown to the user if validator's + condition is not satisfied. + """ + super(SingleFlagValidator, self).__init__(checker, message) + self.flag_name = flag_name + + def get_flags_names(self): + return [self.flag_name] + + def print_flags_with_values(self, flag_values): + return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value) + + def _get_input_to_checker_function(self, flag_values): + """Given flag values, returns the input to be given to checker. + + Args: + flag_values: flags.FlagValues, the FlagValues instance to get flags from. + Returns: + object, the input to be given to checker. + """ + return flag_values[self.flag_name].value + + +class MultiFlagsValidator(Validator): + """Validator behind register_multi_flags_validator method. + + Validates that flag values pass their common checker function. The checker + function takes flag values and returns True (if values look fine) or, + if values are not valid, either returns False or raises an Exception. + """ + + def __init__(self, flag_names, checker, message): + """Constructor. + + Args: + flag_names: [str], containing names of the flags used by checker. + checker: function to verify the validator. + input - dict, with keys() being flag_names, and value for each + key being the value of the corresponding flag (string, boolean, + etc). + output - bool, True if validator constraint is satisfied. + If constraint is not satisfied, it should either return False or + raise flags.ValidationError(desired_error_message). + message: str, error message to be shown to the user if validator's + condition is not satisfied + """ + super(MultiFlagsValidator, self).__init__(checker, message) + self.flag_names = flag_names + + def _get_input_to_checker_function(self, flag_values): + """Given flag values, returns the input to be given to checker. + + Args: + flag_values: flags.FlagValues, the FlagValues instance to get flags from. + Returns: + dict, with keys() being self.lag_names, and value for each key + being the value of the corresponding flag (string, boolean, etc). + """ + return dict([key, flag_values[key].value] for key in self.flag_names) + + def print_flags_with_values(self, flag_values): + prefix = 'flags ' + flags_with_values = [] + for key in self.flag_names: + flags_with_values.append('%s=%s' % (key, flag_values[key].value)) + return prefix + ', '.join(flags_with_values) + + def get_flags_names(self): + return self.flag_names diff --git a/absl/flags/argparse_flags.py b/absl/flags/argparse_flags.py new file mode 100644 index 0000000..4f78f50 --- /dev/null +++ b/absl/flags/argparse_flags.py @@ -0,0 +1,390 @@ +# Copyright 2018 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module provides argparse integration with absl.flags. + +argparse_flags.ArgumentParser is a drop-in replacement for +argparse.ArgumentParser. It takes care of collecting and defining absl flags +in argparse. + +Here is a simple example: + + # Assume the following absl.flags is defined in another module: + # + # from absl import flags + # flags.DEFINE_string('echo', None, 'The echo message.') + # + parser = argparse_flags.ArgumentParser( + description='A demo of absl.flags and argparse integration.') + parser.add_argument('--header', help='Header message to print.') + + # The parser will also accept the absl flag `--echo`. + # The `header` value is available as `args.header` just like a regular + # argparse flag. The absl flag `--echo` continues to be available via + # `absl.flags.FLAGS` if you want to access it. + args = parser.parse_args() + + # Example usages: + # ./program --echo='A message.' --header='A header' + # ./program --header 'A header' --echo 'A message.' + + +Here is another example demonstrates subparsers: + + parser = argparse_flags.ArgumentParser(description='A subcommands demo.') + parser.add_argument('--header', help='The header message to print.') + + subparsers = parser.add_subparsers(help='The command to execute.') + + roll_dice_parser = subparsers.add_parser( + 'roll_dice', help='Roll a dice.', + # By default, absl flags can also be specified after the sub-command. + # To only allow them before sub-command, pass + # `inherited_absl_flags=None`. + inherited_absl_flags=None) + roll_dice_parser.add_argument('--num_faces', type=int, default=6) + roll_dice_parser.set_defaults(command=roll_dice) + + shuffle_parser = subparsers.add_parser('shuffle', help='Shuffle inputs.') + shuffle_parser.add_argument( + 'inputs', metavar='I', nargs='+', help='Inputs to shuffle.') + shuffle_parser.set_defaults(command=shuffle) + + args = parser.parse_args(argv[1:]) + args.command(args) + + # Example usages: + # ./program --echo='A message.' roll_dice --num_faces=6 + # ./program shuffle --echo='A message.' 1 2 3 4 + + +There are several differences between absl.flags and argparse_flags: + +1. Flags defined with absl.flags are parsed differently when using the + argparse parser. Notably: + + 1) absl.flags allows both single-dash and double-dash for any flag, and + doesn't distinguish them; argparse_flags only allows double-dash for + flag's regular name, and single-dash for flag's `short_name`. + 2) Boolean flags in absl.flags can be specified with `--bool`, `--nobool`, + as well as `--bool=true/false` (though not recommended); + in argparse_flags, it only allows `--bool`, `--nobool`. + +2. Help related flag differences: + 1) absl.flags does not define help flags, absl.app does that; argparse_flags + defines help flags unless passed with `add_help=False`. + 2) absl.app supports `--helpxml`; argparse_flags does not. + 3) argparse_flags supports `-h`; absl.app does not. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from absl import flags + + +_BUILT_IN_FLAGS = frozenset({ + 'help', + 'helpshort', + 'helpfull', + 'helpxml', + 'flagfile', + 'undefok', +}) + + +class ArgumentParser(argparse.ArgumentParser): + """Custom ArgumentParser class to support special absl flags.""" + + def __init__(self, **kwargs): + """Initializes ArgumentParser. + + Args: + **kwargs: same as argparse.ArgumentParser, except: + 1. It also accepts `inherited_absl_flags`: the absl flags to inherit. + The default is the global absl.flags.FLAGS instance. Pass None to + ignore absl flags. + 2. The `prefix_chars` argument must be the default value '-'. + + Raises: + ValueError: Raised when prefix_chars is not '-'. + """ + prefix_chars = kwargs.get('prefix_chars', '-') + if prefix_chars != '-': + raise ValueError( + 'argparse_flags.ArgumentParser only supports "-" as the prefix ' + 'character, found "{}".'.format(prefix_chars)) + + # Remove inherited_absl_flags before calling super. + self._inherited_absl_flags = kwargs.pop('inherited_absl_flags', flags.FLAGS) + # Now call super to initialize argparse.ArgumentParser before calling + # add_argument in _define_absl_flags. + super(ArgumentParser, self).__init__(**kwargs) + + if self.add_help: + # -h and --help are defined in super. + # Also add the --helpshort and --helpfull flags. + self.add_argument( + # Action 'help' defines a similar flag to -h/--help. + '--helpshort', action='help', + default=argparse.SUPPRESS, help=argparse.SUPPRESS) + self.add_argument( + '--helpfull', action=_HelpFullAction, + default=argparse.SUPPRESS, help='show full help message and exit') + + if self._inherited_absl_flags: + self.add_argument( + '--undefok', default=argparse.SUPPRESS, help=argparse.SUPPRESS) + self._define_absl_flags(self._inherited_absl_flags) + + def parse_known_args(self, args=None, namespace=None): + if args is None: + args = sys.argv[1:] + if self._inherited_absl_flags: + # Handle --flagfile. + # Explicitly specify force_gnu=True, since argparse behaves like + # gnu_getopt: flags can be specified after positional arguments. + args = self._inherited_absl_flags.read_flags_from_files( + args, force_gnu=True) + + undefok_missing = object() + undefok = getattr(namespace, 'undefok', undefok_missing) + + namespace, args = super(ArgumentParser, self).parse_known_args( + args, namespace) + + # For Python <= 2.7.8: https://bugs.python.org/issue9351, a bug where + # sub-parsers don't preserve existing namespace attributes. + # Restore the undefok attribute if a sub-parser dropped it. + if undefok is not undefok_missing: + namespace.undefok = undefok + + if self._inherited_absl_flags: + # Handle --undefok. At this point, `args` only contains unknown flags, + # so it won't strip defined flags that are also specified with --undefok. + # For Python <= 2.7.8: https://bugs.python.org/issue9351, a bug where + # sub-parsers don't preserve existing namespace attributes. The undefok + # attribute might not exist because a subparser dropped it. + if hasattr(namespace, 'undefok'): + args = _strip_undefok_args(namespace.undefok, args) + # absl flags are not exposed in the Namespace object. See Namespace: + # https://docs.python.org/3/library/argparse.html#argparse.Namespace. + del namespace.undefok + self._inherited_absl_flags.mark_as_parsed() + try: + self._inherited_absl_flags.validate_all_flags() + except flags.IllegalFlagValueError as e: + self.error(str(e)) + + return namespace, args + + def _define_absl_flags(self, absl_flags): + """Defines flags from absl_flags.""" + key_flags = set(absl_flags.get_key_flags_for_module(sys.argv[0])) + for name in absl_flags: + if name in _BUILT_IN_FLAGS: + # Do not inherit built-in flags. + continue + flag_instance = absl_flags[name] + # Each flags with short_name appears in FLAGS twice, so only define + # when the dictionary key is equal to the regular name. + if name == flag_instance.name: + # Suppress the flag in the help short message if it's not a main + # module's key flag. + suppress = flag_instance not in key_flags + self._define_absl_flag(flag_instance, suppress) + + def _define_absl_flag(self, flag_instance, suppress): + """Defines a flag from the flag_instance.""" + flag_name = flag_instance.name + short_name = flag_instance.short_name + argument_names = ['--' + flag_name] + if short_name: + argument_names.insert(0, '-' + short_name) + if suppress: + helptext = argparse.SUPPRESS + else: + # argparse help string uses %-formatting. Escape the literal %'s. + helptext = flag_instance.help.replace('%', '%%') + if flag_instance.boolean: + # Only add the `no` form to the long name. + argument_names.append('--no' + flag_name) + self.add_argument( + *argument_names, action=_BooleanFlagAction, help=helptext, + metavar=flag_instance.name.upper(), + flag_instance=flag_instance) + else: + self.add_argument( + *argument_names, action=_FlagAction, help=helptext, + metavar=flag_instance.name.upper(), + flag_instance=flag_instance) + + +class _FlagAction(argparse.Action): + """Action class for Abseil non-boolean flags.""" + + def __init__( + self, + option_strings, + dest, + help, # pylint: disable=redefined-builtin + metavar, + flag_instance, + default=argparse.SUPPRESS): + """Initializes _FlagAction. + + Args: + option_strings: See argparse.Action. + dest: Ignored. The flag is always defined with dest=argparse.SUPPRESS. + help: See argparse.Action. + metavar: See argparse.Action. + flag_instance: absl.flags.Flag, the absl flag instance. + default: Ignored. The flag always uses dest=argparse.SUPPRESS so it + doesn't affect the parsing result. + """ + del dest + self._flag_instance = flag_instance + super(_FlagAction, self).__init__( + option_strings=option_strings, + dest=argparse.SUPPRESS, + help=help, + metavar=metavar) + + def __call__(self, parser, namespace, values, option_string=None): + """See https://docs.python.org/3/library/argparse.html#action-classes.""" + self._flag_instance.parse(values) + self._flag_instance.using_default_value = False + + +class _BooleanFlagAction(argparse.Action): + """Action class for Abseil boolean flags.""" + + def __init__( + self, + option_strings, + dest, + help, # pylint: disable=redefined-builtin + metavar, + flag_instance, + default=argparse.SUPPRESS): + """Initializes _BooleanFlagAction. + + Args: + option_strings: See argparse.Action. + dest: Ignored. The flag is always defined with dest=argparse.SUPPRESS. + help: See argparse.Action. + metavar: See argparse.Action. + flag_instance: absl.flags.Flag, the absl flag instance. + default: Ignored. The flag always uses dest=argparse.SUPPRESS so it + doesn't affect the parsing result. + """ + del dest, default + self._flag_instance = flag_instance + flag_names = [self._flag_instance.name] + if self._flag_instance.short_name: + flag_names.append(self._flag_instance.short_name) + self._flag_names = frozenset(flag_names) + super(_BooleanFlagAction, self).__init__( + option_strings=option_strings, + dest=argparse.SUPPRESS, + nargs=0, # Does not accept values, only `--bool` or `--nobool`. + help=help, + metavar=metavar) + + def __call__(self, parser, namespace, values, option_string=None): + """See https://docs.python.org/3/library/argparse.html#action-classes.""" + if not isinstance(values, list) or values: + raise ValueError('values must be an empty list.') + if option_string.startswith('--'): + option = option_string[2:] + else: + option = option_string[1:] + if option in self._flag_names: + self._flag_instance.parse('true') + else: + if not option.startswith('no') or option[2:] not in self._flag_names: + raise ValueError('invalid option_string: ' + option_string) + self._flag_instance.parse('false') + self._flag_instance.using_default_value = False + + +class _HelpFullAction(argparse.Action): + """Action class for --helpfull flag.""" + + def __init__(self, option_strings, dest, default, help): # pylint: disable=redefined-builtin + """Initializes _HelpFullAction. + + Args: + option_strings: See argparse.Action. + dest: Ignored. The flag is always defined with dest=argparse.SUPPRESS. + default: Ignored. + help: See argparse.Action. + """ + del dest, default + super(_HelpFullAction, self).__init__( + option_strings=option_strings, + dest=argparse.SUPPRESS, + default=argparse.SUPPRESS, + nargs=0, + help=help) + + def __call__(self, parser, namespace, values, option_string=None): + """See https://docs.python.org/3/library/argparse.html#action-classes.""" + # This only prints flags when help is not argparse.SUPPRESS. + # It includes user defined argparse flags, as well as main module's + # key absl flags. Other absl flags use argparse.SUPPRESS, so they aren't + # printed here. + parser.print_help() + + absl_flags = parser._inherited_absl_flags # pylint: disable=protected-access + if absl_flags: + modules = sorted(absl_flags.flags_by_module_dict()) + main_module = sys.argv[0] + if main_module in modules: + # The main module flags are already printed in parser.print_help(). + modules.remove(main_module) + print(absl_flags._get_help_for_modules( # pylint: disable=protected-access + modules, prefix='', include_special_flags=True)) + parser.exit() + + +def _strip_undefok_args(undefok, args): + """Returns a new list of args after removing flags in --undefok.""" + if undefok: + undefok_names = set(name.strip() for name in undefok.split(',')) + undefok_names |= set('no' + name for name in undefok_names) + # Remove undefok flags. + args = [arg for arg in args if not _is_undefok(arg, undefok_names)] + return args + + +def _is_undefok(arg, undefok_names): + """Returns whether we can ignore arg based on a set of undefok flag names.""" + if not arg.startswith('-'): + return False + if arg.startswith('--'): + arg_without_dash = arg[2:] + else: + arg_without_dash = arg[1:] + if '=' in arg_without_dash: + name, _ = arg_without_dash.split('=', 1) + else: + name = arg_without_dash + if name in undefok_names: + return True + return False diff --git a/absl/flags/tests/__init__.py b/absl/flags/tests/__init__.py new file mode 100644 index 0000000..a3bd1cd --- /dev/null +++ b/absl/flags/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/absl/flags/tests/_argument_parser_test.py b/absl/flags/tests/_argument_parser_test.py new file mode 100644 index 0000000..4281c3f --- /dev/null +++ b/absl/flags/tests/_argument_parser_test.py @@ -0,0 +1,214 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Additional tests for flag argument parsers. + +Most of the argument parsers are covered in the flags_test.py. +""" + +import enum + +from absl.flags import _argument_parser +from absl.testing import absltest +from absl.testing import parameterized + + +class ArgumentParserTest(absltest.TestCase): + + def test_instance_cache(self): + parser1 = _argument_parser.FloatParser() + parser2 = _argument_parser.FloatParser() + self.assertIs(parser1, parser2) + + def test_parse_wrong_type(self): + parser = _argument_parser.ArgumentParser() + with self.assertRaises(TypeError): + parser.parse(0) + + if bytes is not str: + # In PY3, it does not accept bytes. + with self.assertRaises(TypeError): + parser.parse(b'') + + +class BooleanParserTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.parser = _argument_parser.BooleanParser() + + def test_parse_bytes(self): + with self.assertRaises(TypeError): + self.parser.parse(b'true') + + def test_parse_str(self): + self.assertTrue(self.parser.parse('true')) + + def test_parse_unicode(self): + self.assertTrue(self.parser.parse(u'true')) + + def test_parse_wrong_type(self): + with self.assertRaises(TypeError): + self.parser.parse(1.234) + + def test_parse_str_false(self): + self.assertFalse(self.parser.parse('false')) + + def test_parse_integer(self): + self.assertTrue(self.parser.parse(1)) + + def test_parse_invalid_integer(self): + with self.assertRaises(ValueError): + self.parser.parse(-1) + + def test_parse_invalid_str(self): + with self.assertRaises(ValueError): + self.parser.parse('nottrue') + + +class FloatParserTest(absltest.TestCase): + + def setUp(self): + self.parser = _argument_parser.FloatParser() + + def test_parse_string(self): + self.assertEqual(1.5, self.parser.parse('1.5')) + + def test_parse_wrong_type(self): + with self.assertRaises(TypeError): + self.parser.parse(False) + + +class IntegerParserTest(absltest.TestCase): + + def setUp(self): + self.parser = _argument_parser.IntegerParser() + + def test_parse_string(self): + self.assertEqual(1, self.parser.parse('1')) + + def test_parse_wrong_type(self): + with self.assertRaises(TypeError): + self.parser.parse(1e2) + with self.assertRaises(TypeError): + self.parser.parse(False) + + +class EnumParserTest(absltest.TestCase): + + def test_empty_values(self): + with self.assertRaises(ValueError): + _argument_parser.EnumParser([]) + + def test_parse(self): + parser = _argument_parser.EnumParser(['apple', 'banana']) + self.assertEqual('apple', parser.parse('apple')) + + def test_parse_not_found(self): + parser = _argument_parser.EnumParser(['apple', 'banana']) + with self.assertRaises(ValueError): + parser.parse('orange') + + +class Fruit(enum.Enum): + APPLE = 1 + BANANA = 2 + + +class EmptyEnum(enum.Enum): + pass + + +class MixedCaseEnum(enum.Enum): + APPLE = 1 + BANANA = 2 + apple = 3 + + +class EnumClassParserTest(parameterized.TestCase): + + def test_requires_enum(self): + with self.assertRaises(TypeError): + _argument_parser.EnumClassParser(['apple', 'banana']) + + def test_requires_non_empty_enum_class(self): + with self.assertRaises(ValueError): + _argument_parser.EnumClassParser(EmptyEnum) + + def test_case_sensitive_rejects_duplicates(self): + unused_normal_parser = _argument_parser.EnumClassParser(MixedCaseEnum) + with self.assertRaisesRegex(ValueError, 'Duplicate.+apple'): + _argument_parser.EnumClassParser(MixedCaseEnum, case_sensitive=False) + + def test_parse_string(self): + parser = _argument_parser.EnumClassParser(Fruit) + self.assertEqual(Fruit.APPLE, parser.parse('APPLE')) + + def test_parse_string_case_sensitive(self): + parser = _argument_parser.EnumClassParser(Fruit) + with self.assertRaises(ValueError): + parser.parse('apple') + + @parameterized.parameters('APPLE', 'apple', 'Apple') + def test_parse_string_case_insensitive(self, value): + parser = _argument_parser.EnumClassParser(Fruit, case_sensitive=False) + self.assertIs(Fruit.APPLE, parser.parse(value)) + + def test_parse_literal(self): + parser = _argument_parser.EnumClassParser(Fruit) + self.assertEqual(Fruit.APPLE, parser.parse(Fruit.APPLE)) + + def test_parse_not_found(self): + parser = _argument_parser.EnumClassParser(Fruit) + with self.assertRaises(ValueError): + parser.parse('ORANGE') + + @parameterized.parameters((Fruit.BANANA, False, 'BANANA'), + (Fruit.BANANA, True, 'banana')) + def test_serialize_parse(self, value, lowercase, expected): + serializer = _argument_parser.EnumClassSerializer(lowercase=lowercase) + parser = _argument_parser.EnumClassParser( + Fruit, case_sensitive=not lowercase) + serialized = serializer.serialize(value) + self.assertEqual(serialized, expected) + self.assertEqual(value, parser.parse(expected)) + + +class SerializerTest(parameterized.TestCase): + + def test_csv_serializer(self): + serializer = _argument_parser.CsvListSerializer('+') + self.assertEqual(serializer.serialize(['foo', 'bar']), 'foo+bar') + + @parameterized.parameters([ + dict(lowercase=False, expected='APPLE+BANANA'), + dict(lowercase=True, expected='apple+banana'), + ]) + def test_enum_class_list_serializer(self, lowercase, expected): + values = [Fruit.APPLE, Fruit.BANANA] + serializer = _argument_parser.EnumClassListSerializer( + list_sep='+', lowercase=lowercase) + serialized = serializer.serialize(values) + self.assertEqual(expected, serialized) + + +class HelperFunctionsTest(absltest.TestCase): + + def test_is_integer_type(self): + self.assertTrue(_argument_parser._is_integer_type(1)) + # Note that isinstance(False, int) == True. + self.assertFalse(_argument_parser._is_integer_type(False)) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/_flag_test.py b/absl/flags/tests/_flag_test.py new file mode 100644 index 0000000..492f117 --- /dev/null +++ b/absl/flags/tests/_flag_test.py @@ -0,0 +1,240 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Additional tests for Flag classes. + +Most of the Flag classes are covered in the flags_test.py. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import enum +import pickle + +from absl.flags import _argument_parser +from absl.flags import _exceptions +from absl.flags import _flag +from absl.testing import absltest +from absl.testing import parameterized + + +class FlagTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.flag = _flag.Flag( + _argument_parser.ArgumentParser(), + _argument_parser.ArgumentSerializer(), + 'fruit', 'apple', 'help') + + def test_default_unparsed(self): + flag = _flag.Flag( + _argument_parser.ArgumentParser(), + _argument_parser.ArgumentSerializer(), + 'fruit', 'apple', 'help') + self.assertEqual('apple', flag.default_unparsed) + + flag = _flag.Flag( + _argument_parser.IntegerParser(), + _argument_parser.ArgumentSerializer(), + 'number', '1', 'help') + self.assertEqual('1', flag.default_unparsed) + + flag = _flag.Flag( + _argument_parser.IntegerParser(), + _argument_parser.ArgumentSerializer(), + 'number', 1, 'help') + self.assertEqual(1, flag.default_unparsed) + + def test_no_truthiness(self): + with self.assertRaises(TypeError): + if self.flag: + self.fail('Flag instances must raise rather than be truthy.') + + def test_set_default_overrides_current_value(self): + self.assertEqual('apple', self.flag.value) + self.flag._set_default('orange') + self.assertEqual('orange', self.flag.value) + + def test_set_default_overrides_current_value_when_not_using_default(self): + self.flag.using_default_value = False + self.assertEqual('apple', self.flag.value) + self.flag._set_default('orange') + self.assertEqual('apple', self.flag.value) + + def test_pickle(self): + with self.assertRaisesRegex(TypeError, "can't pickle Flag objects"): + pickle.dumps(self.flag) + + def test_copy(self): + self.flag.value = 'orange' + + with self.assertRaisesRegex(TypeError, + 'Flag does not support shallow copies'): + copy.copy(self.flag) + + flag2 = copy.deepcopy(self.flag) + self.assertEqual(flag2.value, 'orange') + + flag2.value = 'mango' + self.assertEqual(flag2.value, 'mango') + self.assertEqual(self.flag.value, 'orange') + + +class BooleanFlagTest(parameterized.TestCase): + + @parameterized.parameters(('', '(no help available)'), + ('Is my test brilliant?', 'Is my test brilliant?')) + def test_help_text(self, helptext_input, helptext_output): + f = _flag.BooleanFlag('a_bool', False, helptext_input) + self.assertEqual(helptext_output, f.help) + + +class EnumFlagTest(parameterized.TestCase): + + @parameterized.parameters( + ('', '<apple|orange>: (no help available)'), + ('Type of fruit.', '<apple|orange>: Type of fruit.')) + def test_help_text(self, helptext_input, helptext_output): + f = _flag.EnumFlag('fruit', 'apple', helptext_input, ['apple', 'orange']) + self.assertEqual(helptext_output, f.help) + + def test_empty_values(self): + with self.assertRaises(ValueError): + _flag.EnumFlag('fruit', None, 'help', []) + + +class Fruit(enum.Enum): + APPLE = 1 + ORANGE = 2 + + +class EmptyEnum(enum.Enum): + pass + + +class EnumClassFlagTest(parameterized.TestCase): + + @parameterized.parameters( + ('', '<apple|orange>: (no help available)'), + ('Type of fruit.', '<apple|orange>: Type of fruit.')) + def test_help_text_case_insensitive(self, helptext_input, helptext_output): + f = _flag.EnumClassFlag('fruit', None, helptext_input, Fruit) + self.assertEqual(helptext_output, f.help) + + @parameterized.parameters( + ('', '<APPLE|ORANGE>: (no help available)'), + ('Type of fruit.', '<APPLE|ORANGE>: Type of fruit.')) + def test_help_text_case_sensitive(self, helptext_input, helptext_output): + f = _flag.EnumClassFlag( + 'fruit', None, helptext_input, Fruit, case_sensitive=True) + self.assertEqual(helptext_output, f.help) + + def test_requires_enum(self): + with self.assertRaises(TypeError): + _flag.EnumClassFlag('fruit', None, 'help', ['apple', 'orange']) + + def test_requires_non_empty_enum_class(self): + with self.assertRaises(ValueError): + _flag.EnumClassFlag('empty', None, 'help', EmptyEnum) + + def test_accepts_literal_default(self): + f = _flag.EnumClassFlag('fruit', Fruit.APPLE, 'A sample enum flag.', Fruit) + self.assertEqual(Fruit.APPLE, f.value) + + def test_accepts_string_default(self): + f = _flag.EnumClassFlag('fruit', 'ORANGE', 'A sample enum flag.', Fruit) + self.assertEqual(Fruit.ORANGE, f.value) + + def test_case_sensitive_rejects_default_with_wrong_case(self): + with self.assertRaises(_exceptions.IllegalFlagValueError): + _flag.EnumClassFlag( + 'fruit', 'oranGe', 'A sample enum flag.', Fruit, case_sensitive=True) + + def test_case_insensitive_accepts_string_default(self): + f = _flag.EnumClassFlag( + 'fruit', 'oranGe', 'A sample enum flag.', Fruit, case_sensitive=False) + self.assertEqual(Fruit.ORANGE, f.value) + + def test_default_value_does_not_exist(self): + with self.assertRaises(_exceptions.IllegalFlagValueError): + _flag.EnumClassFlag('fruit', 'BANANA', 'help', Fruit) + + +class MultiEnumClassFlagTest(parameterized.TestCase): + + @parameterized.named_parameters( + ('NoHelpSupplied', '', '<apple|orange>: (no help available);\n ' + + 'repeat this option to specify a list of values', False), + ('WithHelpSupplied', 'Type of fruit.', + '<APPLE|ORANGE>: Type of fruit.;\n ' + + 'repeat this option to specify a list of values', True)) + def test_help_text(self, helptext_input, helptext_output, case_sensitive): + f = _flag.MultiEnumClassFlag( + 'fruit', None, helptext_input, Fruit, case_sensitive=case_sensitive) + self.assertEqual(helptext_output, f.help) + + def test_requires_enum(self): + with self.assertRaises(TypeError): + _flag.MultiEnumClassFlag('fruit', None, 'help', ['apple', 'orange']) + + def test_requires_non_empty_enum_class(self): + with self.assertRaises(ValueError): + _flag.MultiEnumClassFlag('empty', None, 'help', EmptyEnum) + + def test_rejects_wrong_case_when_case_sensitive(self): + with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, + '<APPLE|ORANGE>'): + _flag.MultiEnumClassFlag( + 'fruit', ['APPLE', 'Orange'], + 'A sample enum flag.', + Fruit, + case_sensitive=True) + + def test_accepts_case_insensitive(self): + f = _flag.MultiEnumClassFlag('fruit', ['apple', 'APPLE'], + 'A sample enum flag.', Fruit) + self.assertListEqual([Fruit.APPLE, Fruit.APPLE], f.value) + + def test_accepts_literal_default(self): + f = _flag.MultiEnumClassFlag('fruit', Fruit.APPLE, 'A sample enum flag.', + Fruit) + self.assertListEqual([Fruit.APPLE], f.value) + + def test_accepts_list_of_literal_default(self): + f = _flag.MultiEnumClassFlag('fruit', [Fruit.APPLE, Fruit.ORANGE], + 'A sample enum flag.', Fruit) + self.assertListEqual([Fruit.APPLE, Fruit.ORANGE], f.value) + + def test_accepts_string_default(self): + f = _flag.MultiEnumClassFlag('fruit', 'ORANGE', 'A sample enum flag.', + Fruit) + self.assertListEqual([Fruit.ORANGE], f.value) + + def test_accepts_list_of_string_default(self): + f = _flag.MultiEnumClassFlag('fruit', ['ORANGE', 'APPLE'], + 'A sample enum flag.', Fruit) + self.assertListEqual([Fruit.ORANGE, Fruit.APPLE], f.value) + + def test_default_value_does_not_exist(self): + with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, + '<apple|banana>'): + _flag.MultiEnumClassFlag('fruit', 'BANANA', 'help', Fruit) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py new file mode 100644 index 0000000..46639f2 --- /dev/null +++ b/absl/flags/tests/_flagvalues_test.py @@ -0,0 +1,929 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for flags.FlagValues class.""" + +import collections +import copy +import pickle +import types +from unittest import mock + +from absl import logging +from absl.flags import _defines +from absl.flags import _exceptions +from absl.flags import _flagvalues +from absl.flags import _helpers +from absl.flags import _validators +from absl.flags.tests import module_foo +from absl.testing import absltest +from absl.testing import parameterized + + +class FlagValuesTest(absltest.TestCase): + + def test_bool_flags(self): + for arg, expected in (('--nothing', True), + ('--nothing=true', True), + ('--nothing=false', False), + ('--nonothing', False)): + fv = _flagvalues.FlagValues() + _defines.DEFINE_boolean('nothing', None, '', flag_values=fv) + fv(('./program', arg)) + self.assertIs(expected, fv.nothing) + + for arg in ('--nonothing=true', '--nonothing=false'): + fv = _flagvalues.FlagValues() + _defines.DEFINE_boolean('nothing', None, '', flag_values=fv) + with self.assertRaises(ValueError): + fv(('./program', arg)) + + def test_boolean_flag_parser_gets_string_argument(self): + for arg, expected in (('--nothing', 'true'), + ('--nothing=true', 'true'), + ('--nothing=false', 'false'), + ('--nonothing', 'false')): + fv = _flagvalues.FlagValues() + _defines.DEFINE_boolean('nothing', None, '', flag_values=fv) + with mock.patch.object(fv['nothing'].parser, 'parse') as mock_parse: + fv(('./program', arg)) + mock_parse.assert_called_once_with(expected) + + def test_unregistered_flags_are_cleaned_up(self): + fv = _flagvalues.FlagValues() + module, module_name = _helpers.get_calling_module_object_and_name() + + # Define first flag. + _defines.DEFINE_integer('cores', 4, '', flag_values=fv, short_name='c') + old_cores_flag = fv['cores'] + fv.register_key_flag_for_module(module_name, old_cores_flag) + self.assertEqual(fv.flags_by_module_dict(), + {module_name: [old_cores_flag]}) + self.assertEqual(fv.flags_by_module_id_dict(), + {id(module): [old_cores_flag]}) + self.assertEqual(fv.key_flags_by_module_dict(), + {module_name: [old_cores_flag]}) + + # Redefine the same flag. + _defines.DEFINE_integer( + 'cores', 4, '', flag_values=fv, short_name='c', allow_override=True) + new_cores_flag = fv['cores'] + self.assertNotEqual(old_cores_flag, new_cores_flag) + self.assertEqual(fv.flags_by_module_dict(), + {module_name: [new_cores_flag]}) + self.assertEqual(fv.flags_by_module_id_dict(), + {id(module): [new_cores_flag]}) + # old_cores_flag is removed from key flags, and the new_cores_flag is + # not automatically added because it must be registered explicitly. + self.assertEqual(fv.key_flags_by_module_dict(), {module_name: []}) + + # Define a new flag but with the same short_name. + _defines.DEFINE_integer( + 'changelist', + 0, + '', + flag_values=fv, + short_name='c', + allow_override=True) + old_changelist_flag = fv['changelist'] + fv.register_key_flag_for_module(module_name, old_changelist_flag) + # The short named flag -c is overridden to be the old_changelist_flag. + self.assertEqual(fv['c'], old_changelist_flag) + self.assertNotEqual(fv['c'], new_cores_flag) + self.assertEqual(fv.flags_by_module_dict(), + {module_name: [new_cores_flag, old_changelist_flag]}) + self.assertEqual(fv.flags_by_module_id_dict(), + {id(module): [new_cores_flag, old_changelist_flag]}) + self.assertEqual(fv.key_flags_by_module_dict(), + {module_name: [old_changelist_flag]}) + + # Define a flag only with the same long name. + _defines.DEFINE_integer( + 'changelist', + 0, + '', + flag_values=fv, + short_name='l', + allow_override=True) + new_changelist_flag = fv['changelist'] + self.assertNotEqual(old_changelist_flag, new_changelist_flag) + self.assertEqual(fv.flags_by_module_dict(), + {module_name: [new_cores_flag, + old_changelist_flag, + new_changelist_flag]}) + self.assertEqual(fv.flags_by_module_id_dict(), + {id(module): [new_cores_flag, + old_changelist_flag, + new_changelist_flag]}) + self.assertEqual(fv.key_flags_by_module_dict(), + {module_name: [old_changelist_flag]}) + + # Delete the new changelist's long name, it should still be registered + # because of its short name. + del fv.changelist + self.assertNotIn('changelist', fv) + self.assertEqual(fv.flags_by_module_dict(), + {module_name: [new_cores_flag, + old_changelist_flag, + new_changelist_flag]}) + self.assertEqual(fv.flags_by_module_id_dict(), + {id(module): [new_cores_flag, + old_changelist_flag, + new_changelist_flag]}) + self.assertEqual(fv.key_flags_by_module_dict(), + {module_name: [old_changelist_flag]}) + + # Delete the new changelist's short name, it should be removed. + del fv.l + self.assertNotIn('l', fv) + self.assertEqual(fv.flags_by_module_dict(), + {module_name: [new_cores_flag, + old_changelist_flag]}) + self.assertEqual(fv.flags_by_module_id_dict(), + {id(module): [new_cores_flag, + old_changelist_flag]}) + self.assertEqual(fv.key_flags_by_module_dict(), + {module_name: [old_changelist_flag]}) + + def _test_find_module_or_id_defining_flag(self, test_id): + """Tests for find_module_defining_flag and find_module_id_defining_flag. + + Args: + test_id: True to test find_module_id_defining_flag, False to test + find_module_defining_flag. + """ + fv = _flagvalues.FlagValues() + current_module, current_module_name = ( + _helpers.get_calling_module_object_and_name()) + alt_module_name = _flagvalues.__name__ + + if test_id: + current_module_or_id = id(current_module) + alt_module_or_id = id(_flagvalues) + testing_fn = fv.find_module_id_defining_flag + else: + current_module_or_id = current_module_name + alt_module_or_id = alt_module_name + testing_fn = fv.find_module_defining_flag + + # Define first flag. + _defines.DEFINE_integer('cores', 4, '', flag_values=fv, short_name='c') + module_or_id_cores = testing_fn('cores') + self.assertEqual(module_or_id_cores, current_module_or_id) + module_or_id_c = testing_fn('c') + self.assertEqual(module_or_id_c, current_module_or_id) + + # Redefine the same flag in another module. + _defines.DEFINE_integer( + 'cores', + 4, + '', + flag_values=fv, + module_name=alt_module_name, + short_name='c', + allow_override=True) + module_or_id_cores = testing_fn('cores') + self.assertEqual(module_or_id_cores, alt_module_or_id) + module_or_id_c = testing_fn('c') + self.assertEqual(module_or_id_c, alt_module_or_id) + + # Define a new flag but with the same short_name. + _defines.DEFINE_integer( + 'changelist', + 0, + '', + flag_values=fv, + short_name='c', + allow_override=True) + module_or_id_cores = testing_fn('cores') + self.assertEqual(module_or_id_cores, alt_module_or_id) + module_or_id_changelist = testing_fn('changelist') + self.assertEqual(module_or_id_changelist, current_module_or_id) + module_or_id_c = testing_fn('c') + self.assertEqual(module_or_id_c, current_module_or_id) + + # Define a flag in another module only with the same long name. + _defines.DEFINE_integer( + 'changelist', + 0, + '', + flag_values=fv, + module_name=alt_module_name, + short_name='l', + allow_override=True) + module_or_id_cores = testing_fn('cores') + self.assertEqual(module_or_id_cores, alt_module_or_id) + module_or_id_changelist = testing_fn('changelist') + self.assertEqual(module_or_id_changelist, alt_module_or_id) + module_or_id_c = testing_fn('c') + self.assertEqual(module_or_id_c, current_module_or_id) + module_or_id_l = testing_fn('l') + self.assertEqual(module_or_id_l, alt_module_or_id) + + # Delete the changelist flag, its short name should still be registered. + del fv.changelist + module_or_id_changelist = testing_fn('changelist') + self.assertIsNone(module_or_id_changelist) + module_or_id_c = testing_fn('c') + self.assertEqual(module_or_id_c, current_module_or_id) + module_or_id_l = testing_fn('l') + self.assertEqual(module_or_id_l, alt_module_or_id) + + def test_find_module_defining_flag(self): + self._test_find_module_or_id_defining_flag(test_id=False) + + def test_find_module_id_defining_flag(self): + self._test_find_module_or_id_defining_flag(test_id=True) + + def test_set_default(self): + fv = _flagvalues.FlagValues() + fv.mark_as_parsed() + with self.assertRaises(_exceptions.UnrecognizedFlagError): + fv.set_default('changelist', 1) + _defines.DEFINE_integer('changelist', 0, 'help', flag_values=fv) + self.assertEqual(0, fv.changelist) + fv.set_default('changelist', 2) + self.assertEqual(2, fv.changelist) + + def test_default_gnu_getopt_value(self): + self.assertTrue(_flagvalues.FlagValues().is_gnu_getopt()) + + def test_known_only_flags_in_gnustyle(self): + + def run_test(argv, defined_py_flags, expected_argv): + fv = _flagvalues.FlagValues() + fv.set_gnu_getopt(True) + for f in defined_py_flags: + if f.startswith('b'): + _defines.DEFINE_boolean(f, False, 'help', flag_values=fv) + else: + _defines.DEFINE_string(f, 'default', 'help', flag_values=fv) + output_argv = fv(argv, known_only=True) + self.assertEqual(expected_argv, output_argv) + + run_test( + argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '), + defined_py_flags=[], + expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' ')) + run_test( + argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '), + defined_py_flags=['f1'], + expected_argv='0 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' ')) + run_test( + argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '), + defined_py_flags=['f2'], + expected_argv='0 --f1=v1 cmd --b1 --f3 v3 --nob2'.split(' ')) + run_test( + argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '), + defined_py_flags=['b1'], + expected_argv='0 --f1=v1 cmd --f2 v2 --f3 v3 --nob2'.split(' ')) + run_test( + argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '), + defined_py_flags=['f3'], + expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --nob2'.split(' ')) + run_test( + argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '), + defined_py_flags=['b2'], + expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3'.split(' ')) + run_test( + argv=('0 --f1=v1 cmd --undefok=f1 --f2 v2 --b1 ' + '--f3 v3 --nob2').split(' '), + defined_py_flags=['b2'], + expected_argv='0 cmd --f2 v2 --b1 --f3 v3'.split(' ')) + run_test( + argv=('0 --f1=v1 cmd --undefok f1,f2 --f2 v2 --b1 ' + '--f3 v3 --nob2').split(' '), + defined_py_flags=['b2'], + # Note v2 is preserved here, since undefok requires the flag being + # specified in the form of --flag=value. + expected_argv='0 cmd v2 --b1 --f3 v3'.split(' ')) + + def test_invalid_flag_name(self): + with self.assertRaises(_exceptions.Error): + _defines.DEFINE_boolean('test ', 0, '') + + with self.assertRaises(_exceptions.Error): + _defines.DEFINE_boolean(' test', 0, '') + + with self.assertRaises(_exceptions.Error): + _defines.DEFINE_boolean('te st', 0, '') + + with self.assertRaises(_exceptions.Error): + _defines.DEFINE_boolean('', 0, '') + + with self.assertRaises(_exceptions.Error): + _defines.DEFINE_boolean(1, 0, '') + + def test_len(self): + fv = _flagvalues.FlagValues() + self.assertEmpty(fv) + self.assertFalse(fv) + + _defines.DEFINE_boolean('boolean', False, 'help', flag_values=fv) + self.assertLen(fv, 1) + self.assertTrue(fv) + + _defines.DEFINE_boolean( + 'bool', False, 'help', short_name='b', flag_values=fv) + self.assertLen(fv, 3) + self.assertTrue(fv) + + def test_pickle(self): + fv = _flagvalues.FlagValues() + with self.assertRaisesRegex(TypeError, "can't pickle FlagValues"): + pickle.dumps(fv) + + def test_copy(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_integer('answer', 0, 'help', flag_values=fv) + fv(['', '--answer=1']) + + with self.assertRaisesRegex(TypeError, + 'FlagValues does not support shallow copies'): + copy.copy(fv) + + fv2 = copy.deepcopy(fv) + self.assertEqual(fv2.answer, 1) + + fv2.answer = 42 + self.assertEqual(fv2.answer, 42) + self.assertEqual(fv.answer, 1) + + def test_conflicting_flags(self): + fv = _flagvalues.FlagValues() + with self.assertRaises(_exceptions.FlagNameConflictsWithMethodError): + _defines.DEFINE_boolean('is_gnu_getopt', False, 'help', flag_values=fv) + _defines.DEFINE_boolean( + 'is_gnu_getopt', + False, + 'help', + flag_values=fv, + allow_using_method_names=True) + self.assertFalse(fv['is_gnu_getopt'].value) + self.assertIsInstance(fv.is_gnu_getopt, types.MethodType) + + def test_get_flags_for_module(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_string('foo', None, 'help', flag_values=fv) + module_foo.define_flags(fv) + flags = fv.get_flags_for_module('__main__') + + self.assertEqual({'foo'}, {flag.name for flag in flags}) + + flags = fv.get_flags_for_module(module_foo) + self.assertEqual({'tmod_foo_bool', 'tmod_foo_int', 'tmod_foo_str'}, + {flag.name for flag in flags}) + + def test_get_help(self): + fv = _flagvalues.FlagValues() + self.assertMultiLineEqual('''\ +--flagfile: Insert flag definitions from the given file into the command line. + (default: '') +--undefok: comma-separated list of flag names that it is okay to specify on the + command line even if the program does not define a flag with that name. + IMPORTANT: flags in this list that have arguments MUST use the --flag=value + format. + (default: '')''', fv.get_help()) + + module_foo.define_flags(fv) + self.assertMultiLineEqual(''' +absl.flags.tests.module_bar: + --tmod_bar_t: Sample int flag. + (default: '4') + (an integer) + --tmod_bar_u: Sample int flag. + (default: '5') + (an integer) + --tmod_bar_v: Sample int flag. + (default: '6') + (an integer) + --[no]tmod_bar_x: Boolean flag. + (default: 'true') + --tmod_bar_y: String flag. + (default: 'default') + --[no]tmod_bar_z: Another boolean flag from module bar. + (default: 'false') + +absl.flags.tests.module_foo: + --[no]tmod_foo_bool: Boolean flag from module foo. + (default: 'true') + --tmod_foo_int: Sample int flag. + (default: '3') + (an integer) + --tmod_foo_str: String flag. + (default: 'default') + +absl.flags: + --flagfile: Insert flag definitions from the given file into the command line. + (default: '') + --undefok: comma-separated list of flag names that it is okay to specify on + the command line even if the program does not define a flag with that name. + IMPORTANT: flags in this list that have arguments MUST use the --flag=value + format. + (default: '')''', fv.get_help()) + + self.assertMultiLineEqual(''' +xxxxabsl.flags.tests.module_bar: +xxxx --tmod_bar_t: Sample int flag. +xxxx (default: '4') +xxxx (an integer) +xxxx --tmod_bar_u: Sample int flag. +xxxx (default: '5') +xxxx (an integer) +xxxx --tmod_bar_v: Sample int flag. +xxxx (default: '6') +xxxx (an integer) +xxxx --[no]tmod_bar_x: Boolean flag. +xxxx (default: 'true') +xxxx --tmod_bar_y: String flag. +xxxx (default: 'default') +xxxx --[no]tmod_bar_z: Another boolean flag from module bar. +xxxx (default: 'false') + +xxxxabsl.flags.tests.module_foo: +xxxx --[no]tmod_foo_bool: Boolean flag from module foo. +xxxx (default: 'true') +xxxx --tmod_foo_int: Sample int flag. +xxxx (default: '3') +xxxx (an integer) +xxxx --tmod_foo_str: String flag. +xxxx (default: 'default') + +xxxxabsl.flags: +xxxx --flagfile: Insert flag definitions from the given file into the command +xxxx line. +xxxx (default: '') +xxxx --undefok: comma-separated list of flag names that it is okay to specify +xxxx on the command line even if the program does not define a flag with that +xxxx name. IMPORTANT: flags in this list that have arguments MUST use the +xxxx --flag=value format. +xxxx (default: '')''', fv.get_help(prefix='xxxx')) + + self.assertMultiLineEqual(''' +absl.flags.tests.module_bar: + --tmod_bar_t: Sample int flag. + (default: '4') + (an integer) + --tmod_bar_u: Sample int flag. + (default: '5') + (an integer) + --tmod_bar_v: Sample int flag. + (default: '6') + (an integer) + --[no]tmod_bar_x: Boolean flag. + (default: 'true') + --tmod_bar_y: String flag. + (default: 'default') + --[no]tmod_bar_z: Another boolean flag from module bar. + (default: 'false') + +absl.flags.tests.module_foo: + --[no]tmod_foo_bool: Boolean flag from module foo. + (default: 'true') + --tmod_foo_int: Sample int flag. + (default: '3') + (an integer) + --tmod_foo_str: String flag. + (default: 'default')''', fv.get_help(include_special_flags=False)) + + def test_str(self): + fv = _flagvalues.FlagValues() + self.assertEqual(str(fv), fv.get_help()) + module_foo.define_flags(fv) + self.assertEqual(str(fv), fv.get_help()) + + def test_empty_argv(self): + fv = _flagvalues.FlagValues() + with self.assertRaises(ValueError): + fv([]) + + def test_invalid_argv(self): + fv = _flagvalues.FlagValues() + with self.assertRaises(TypeError): + fv('./program') + with self.assertRaises(TypeError): + fv(b'./program') + with self.assertRaises(TypeError): + fv(u'./program') + + def test_flags_dir(self): + flag_values = _flagvalues.FlagValues() + flag_name1 = 'bool_flag' + flag_name2 = 'string_flag' + flag_name3 = 'float_flag' + description = 'Description' + _defines.DEFINE_boolean( + flag_name1, None, description, flag_values=flag_values) + _defines.DEFINE_string( + flag_name2, None, description, flag_values=flag_values) + self.assertEqual(sorted([flag_name1, flag_name2]), dir(flag_values)) + + _defines.DEFINE_float( + flag_name3, None, description, flag_values=flag_values) + self.assertEqual( + sorted([flag_name1, flag_name2, flag_name3]), dir(flag_values)) + + def test_flags_into_string_deterministic(self): + flag_values = _flagvalues.FlagValues() + _defines.DEFINE_string( + 'fa', 'x', '', flag_values=flag_values, module_name='mb') + _defines.DEFINE_string( + 'fb', 'x', '', flag_values=flag_values, module_name='mb') + _defines.DEFINE_string( + 'fc', 'x', '', flag_values=flag_values, module_name='ma') + _defines.DEFINE_string( + 'fd', 'x', '', flag_values=flag_values, module_name='ma') + + expected = ('--fc=x\n' + '--fd=x\n' + '--fa=x\n' + '--fb=x\n') + + flags_by_module_items = sorted( + flag_values.flags_by_module_dict().items(), reverse=True) + for _, module_flags in flags_by_module_items: + module_flags.sort(reverse=True) + + flag_values.__dict__['__flags_by_module'] = collections.OrderedDict( + flags_by_module_items) + + actual = flag_values.flags_into_string() + self.assertEqual(expected, actual) + + def test_validate_all_flags(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_string('name', None, '', flag_values=fv) + _validators.mark_flag_as_required('name', flag_values=fv) + with self.assertRaises(_exceptions.IllegalFlagValueError): + fv.validate_all_flags() + fv.name = 'test' + fv.validate_all_flags() + + +class FlagValuesLoggingTest(absltest.TestCase): + """Test to make sure logging.* functions won't recurse. + + Logging may and does happen before flags initialization. We need to make + sure that any warnings trown by flagvalues do not result in unlimited + recursion. + """ + + def test_logging_do_not_recurse(self): + logging.info('test info') + try: + raise ValueError('test exception') + except ValueError: + logging.exception('test message') + + +class FlagSubstrMatchingTests(parameterized.TestCase): + """Tests related to flag substring matching.""" + + def _get_test_flag_values(self): + """Get a _flagvalues.FlagValues() instance, set up for tests.""" + flag_values = _flagvalues.FlagValues() + + _defines.DEFINE_string('strf', '', '', flag_values=flag_values) + _defines.DEFINE_boolean('boolf', 0, '', flag_values=flag_values) + + return flag_values + + # Test cases that should always make parsing raise an error. + # Tuples of strings with the argv to use. + FAIL_TEST_CASES = [ + ('./program', '--boo', '0'), + ('./program', '--boo=true', '0'), + ('./program', '--boo=0'), + ('./program', '--noboo'), + ('./program', '--st=blah'), + ('./program', '--st=de'), + ('./program', '--st=blah', '--boo'), + ('./program', '--st=blah', 'unused'), + ('./program', '--st=--blah'), + ('./program', '--st', '--blah'), + ] + + @parameterized.parameters(FAIL_TEST_CASES) + def test_raise(self, *argv): + """Test that raising works.""" + fv = self._get_test_flag_values() + with self.assertRaises(_exceptions.UnrecognizedFlagError): + fv(argv) + + @parameterized.parameters( + FAIL_TEST_CASES + [('./program', 'unused', '--st=blah')]) + def test_gnu_getopt_raise(self, *argv): + """Test that raising works when combined with GNU-style getopt.""" + fv = self._get_test_flag_values() + fv.set_gnu_getopt() + with self.assertRaises(_exceptions.UnrecognizedFlagError): + fv(argv) + + +class SettingUnknownFlagTest(absltest.TestCase): + + def setUp(self): + super(SettingUnknownFlagTest, self).setUp() + self.setter_called = 0 + + def set_undef(self, unused_name, unused_val): + self.setter_called += 1 + + def test_raise_on_undefined(self): + new_flags = _flagvalues.FlagValues() + with self.assertRaises(_exceptions.UnrecognizedFlagError): + new_flags.undefined_flag = 0 + + def test_not_raise(self): + new_flags = _flagvalues.FlagValues() + new_flags._register_unknown_flag_setter(self.set_undef) + new_flags.undefined_flag = 0 + self.assertEqual(self.setter_called, 1) + + def test_not_raise_on_undefined_if_undefok(self): + new_flags = _flagvalues.FlagValues() + args = ['0', '--foo', '--bar=1', '--undefok=foo,bar'] + unparsed = new_flags(args, known_only=True) + self.assertEqual(['0'], unparsed) + + def test_re_raise_undefined(self): + def setter(unused_name, unused_val): + raise NameError() + new_flags = _flagvalues.FlagValues() + new_flags._register_unknown_flag_setter(setter) + with self.assertRaises(_exceptions.UnrecognizedFlagError): + new_flags.undefined_flag = 0 + + def test_re_raise_invalid(self): + def setter(unused_name, unused_val): + raise ValueError() + new_flags = _flagvalues.FlagValues() + new_flags._register_unknown_flag_setter(setter) + with self.assertRaises(_exceptions.IllegalFlagValueError): + new_flags.undefined_flag = 0 + + +class SetAttributesTest(absltest.TestCase): + + def setUp(self): + super(SetAttributesTest, self).setUp() + self.new_flags = _flagvalues.FlagValues() + _defines.DEFINE_boolean( + 'defined_flag', None, '', flag_values=self.new_flags) + _defines.DEFINE_boolean( + 'another_defined_flag', None, '', flag_values=self.new_flags) + self.setter_called = 0 + + def set_undef(self, unused_name, unused_val): + self.setter_called += 1 + + def test_two_defined_flags(self): + self.new_flags._set_attributes( + defined_flag=False, another_defined_flag=False) + self.assertEqual(self.setter_called, 0) + + def test_one_defined_one_undefined_flag(self): + with self.assertRaises(_exceptions.UnrecognizedFlagError): + self.new_flags._set_attributes(defined_flag=False, undefined_flag=0) + + def test_register_unknown_flag_setter(self): + self.new_flags._register_unknown_flag_setter(self.set_undef) + self.new_flags._set_attributes(defined_flag=False, undefined_flag=0) + self.assertEqual(self.setter_called, 1) + + +class FlagsDashSyntaxTest(absltest.TestCase): + + def setUp(self): + super(FlagsDashSyntaxTest, self).setUp() + self.fv = _flagvalues.FlagValues() + _defines.DEFINE_string( + 'long_name', 'default', 'help', flag_values=self.fv, short_name='s') + + def test_long_name_one_dash(self): + self.fv(['./program', '-long_name=new']) + self.assertEqual('new', self.fv.long_name) + + def test_long_name_two_dashes(self): + self.fv(['./program', '--long_name=new']) + self.assertEqual('new', self.fv.long_name) + + def test_long_name_three_dashes(self): + with self.assertRaises(_exceptions.UnrecognizedFlagError): + self.fv(['./program', '---long_name=new']) + + def test_short_name_one_dash(self): + self.fv(['./program', '-s=new']) + self.assertEqual('new', self.fv.s) + + def test_short_name_two_dashes(self): + self.fv(['./program', '--s=new']) + self.assertEqual('new', self.fv.s) + + def test_short_name_three_dashes(self): + with self.assertRaises(_exceptions.UnrecognizedFlagError): + self.fv(['./program', '---s=new']) + + +class UnparseFlagsTest(absltest.TestCase): + + def test_using_default_value_none(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_string('default_none', None, 'help', flag_values=fv) + self.assertTrue(fv['default_none'].using_default_value) + fv(['', '--default_none=notNone']) + self.assertFalse(fv['default_none'].using_default_value) + fv.unparse_flags() + self.assertTrue(fv['default_none'].using_default_value) + fv(['', '--default_none=alsoNotNone']) + self.assertFalse(fv['default_none'].using_default_value) + fv.unparse_flags() + self.assertTrue(fv['default_none'].using_default_value) + + def test_using_default_value_not_none(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_string('default_foo', 'foo', 'help', flag_values=fv) + + fv.mark_as_parsed() + self.assertTrue(fv['default_foo'].using_default_value) + + fv(['', '--default_foo=foo']) + self.assertFalse(fv['default_foo'].using_default_value) + + fv(['', '--default_foo=notFoo']) + self.assertFalse(fv['default_foo'].using_default_value) + + fv.unparse_flags() + self.assertTrue(fv['default_foo'].using_default_value) + + fv(['', '--default_foo=alsoNotFoo']) + self.assertFalse(fv['default_foo'].using_default_value) + + def test_allow_overwrite_false(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_string( + 'default_none', None, 'help', allow_overwrite=False, flag_values=fv) + _defines.DEFINE_string( + 'default_foo', 'foo', 'help', allow_overwrite=False, flag_values=fv) + + fv.mark_as_parsed() + self.assertEqual('foo', fv.default_foo) + self.assertIsNone(fv.default_none) + + fv(['', '--default_foo=notFoo', '--default_none=notNone']) + self.assertEqual('notFoo', fv.default_foo) + self.assertEqual('notNone', fv.default_none) + + fv.unparse_flags() + self.assertEqual('foo', fv['default_foo'].value) + self.assertIsNone(fv['default_none'].value) + + fv(['', '--default_foo=alsoNotFoo', '--default_none=alsoNotNone']) + self.assertEqual('alsoNotFoo', fv.default_foo) + self.assertEqual('alsoNotNone', fv.default_none) + + def test_multi_string_default_none(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_multi_string('foo', None, 'help', flag_values=fv) + fv.mark_as_parsed() + self.assertIsNone(fv.foo) + fv(['', '--foo=aa']) + self.assertEqual(['aa'], fv.foo) + fv.unparse_flags() + self.assertIsNone(fv['foo'].value) + fv(['', '--foo=bb', '--foo=cc']) + self.assertEqual(['bb', 'cc'], fv.foo) + fv.unparse_flags() + self.assertIsNone(fv['foo'].value) + + def test_multi_string_default_string(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_multi_string('foo', 'xyz', 'help', flag_values=fv) + expected_default = ['xyz'] + fv.mark_as_parsed() + self.assertEqual(expected_default, fv.foo) + fv(['', '--foo=aa']) + self.assertEqual(['aa'], fv.foo) + fv.unparse_flags() + self.assertEqual(expected_default, fv['foo'].value) + fv(['', '--foo=bb', '--foo=cc']) + self.assertEqual(['bb', 'cc'], fv['foo'].value) + fv.unparse_flags() + self.assertEqual(expected_default, fv['foo'].value) + + def test_multi_string_default_list(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_multi_string( + 'foo', ['xx', 'yy', 'zz'], 'help', flag_values=fv) + expected_default = ['xx', 'yy', 'zz'] + fv.mark_as_parsed() + self.assertEqual(expected_default, fv.foo) + fv(['', '--foo=aa']) + self.assertEqual(['aa'], fv.foo) + fv.unparse_flags() + self.assertEqual(expected_default, fv['foo'].value) + fv(['', '--foo=bb', '--foo=cc']) + self.assertEqual(['bb', 'cc'], fv.foo) + fv.unparse_flags() + self.assertEqual(expected_default, fv['foo'].value) + + +class UnparsedFlagAccessTest(absltest.TestCase): + + def test_unparsed_flag_access(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_string('name', 'default', 'help', flag_values=fv) + with self.assertRaises(_exceptions.UnparsedFlagAccessError): + _ = fv.name + + def test_hasattr_raises_in_py3(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_string('name', 'default', 'help', flag_values=fv) + with self.assertRaises(_exceptions.UnparsedFlagAccessError): + _ = hasattr(fv, 'name') + + def test_unparsed_flags_access_raises_after_unparse_flags(self): + fv = _flagvalues.FlagValues() + _defines.DEFINE_string('a_str', 'default_value', 'help', flag_values=fv) + fv.mark_as_parsed() + self.assertEqual(fv.a_str, 'default_value') + fv.unparse_flags() + with self.assertRaises(_exceptions.UnparsedFlagAccessError): + _ = fv.a_str + + +class FlagHolderTest(absltest.TestCase): + + def setUp(self): + super(FlagHolderTest, self).setUp() + self.fv = _flagvalues.FlagValues() + self.name_flag = _defines.DEFINE_string( + 'name', 'default', 'help', flag_values=self.fv) + + def parse_flags(self, *argv): + self.fv.unparse_flags() + self.fv(['binary_name'] + list(argv)) + + def test_name(self): + self.assertEqual('name', self.name_flag.name) + + def test_value_before_flag_parsing(self): + with self.assertRaises(_exceptions.UnparsedFlagAccessError): + _ = self.name_flag.value + + def test_value_returns_default_value_if_not_explicitly_set(self): + self.parse_flags() + self.assertEqual('default', self.name_flag.value) + + def test_value_returns_explicitly_set_value(self): + self.parse_flags('--name=new_value') + self.assertEqual('new_value', self.name_flag.value) + + def test_present_returns_false_before_flag_parsing(self): + self.assertFalse(self.name_flag.present) + + def test_present_returns_false_if_not_explicitly_set(self): + self.parse_flags() + self.assertFalse(self.name_flag.present) + + def test_present_returns_true_if_explicitly_set(self): + self.parse_flags('--name=new_value') + self.assertTrue(self.name_flag.present) + + def test_allow_override(self): + first = _defines.DEFINE_integer( + 'int_flag', 1, 'help', flag_values=self.fv, allow_override=1) + second = _defines.DEFINE_integer( + 'int_flag', 2, 'help', flag_values=self.fv, allow_override=1) + self.parse_flags('--int_flag=3') + self.assertEqual(3, first.value) + self.assertEqual(3, second.value) + self.assertTrue(first.present) + self.assertTrue(second.present) + + def test_eq(self): + with self.assertRaises(TypeError): + self.name_flag == 'value' # pylint: disable=pointless-statement + + def test_eq_reflection(self): + with self.assertRaises(TypeError): + 'value' == self.name_flag # pylint: disable=pointless-statement + + def test_bool(self): + with self.assertRaises(TypeError): + bool(self.name_flag) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/_helpers_test.py b/absl/flags/tests/_helpers_test.py new file mode 100644 index 0000000..4746a79 --- /dev/null +++ b/absl/flags/tests/_helpers_test.py @@ -0,0 +1,173 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unittests for helpers module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from absl.flags import _helpers +from absl.flags.tests import module_bar +from absl.flags.tests import module_foo +from absl.testing import absltest + + +class FlagSuggestionTest(absltest.TestCase): + + def setUp(self): + self.longopts = [ + 'fsplit-ivs-in-unroller=', + 'fsplit-wide-types=', + 'fstack-protector=', + 'fstack-protector-all=', + 'fstrict-aliasing=', + 'fstrict-overflow=', + 'fthread-jumps=', + 'ftracer', + 'ftree-bit-ccp', + 'ftree-builtin-call-dce', + 'ftree-ccp', + 'ftree-ch'] + + def test_damerau_levenshtein_id(self): + self.assertEqual(0, _helpers._damerau_levenshtein('asdf', 'asdf')) + + def test_damerau_levenshtein_empty(self): + self.assertEqual(5, _helpers._damerau_levenshtein('', 'kites')) + self.assertEqual(6, _helpers._damerau_levenshtein('kitten', '')) + + def test_damerau_levenshtein_commutative(self): + self.assertEqual(2, _helpers._damerau_levenshtein('kitten', 'kites')) + self.assertEqual(2, _helpers._damerau_levenshtein('kites', 'kitten')) + + def test_damerau_levenshtein_transposition(self): + self.assertEqual(1, _helpers._damerau_levenshtein('kitten', 'ktiten')) + + def test_mispelled_suggestions(self): + suggestions = _helpers.get_flag_suggestions('fstack_protector_all', + self.longopts) + self.assertEqual(['fstack-protector-all'], suggestions) + + def test_ambiguous_prefix_suggestion(self): + suggestions = _helpers.get_flag_suggestions('fstack', self.longopts) + self.assertEqual(['fstack-protector', 'fstack-protector-all'], suggestions) + + def test_misspelled_ambiguous_prefix_suggestion(self): + suggestions = _helpers.get_flag_suggestions('stack', self.longopts) + self.assertEqual(['fstack-protector', 'fstack-protector-all'], suggestions) + + def test_crazy_suggestion(self): + suggestions = _helpers.get_flag_suggestions('asdfasdgasdfa', self.longopts) + self.assertEqual([], suggestions) + + def test_suggestions_are_sorted(self): + sorted_flags = sorted(['aab', 'aac', 'aad']) + misspelt_flag = 'aaa' + suggestions = _helpers.get_flag_suggestions(misspelt_flag, + reversed(sorted_flags)) + self.assertEqual(sorted_flags, suggestions) + + +class GetCallingModuleTest(absltest.TestCase): + """Test whether we correctly determine the module which defines the flag.""" + + def test_get_calling_module(self): + self.assertEqual(_helpers.get_calling_module(), sys.argv[0]) + self.assertEqual(module_foo.get_module_name(), + 'absl.flags.tests.module_foo') + self.assertEqual(module_bar.get_module_name(), + 'absl.flags.tests.module_bar') + + # We execute the following exec statements for their side-effect + # (i.e., not raising an error). They emphasize the case that not + # all code resides in one of the imported modules: Python is a + # really dynamic language, where we can dynamically construct some + # code and execute it. + code = ('from absl.flags import _helpers\n' + 'module_name = _helpers.get_calling_module()') + exec(code) # pylint: disable=exec-used + + # Next two exec statements executes code with a global environment + # that is different from the global environment of any imported + # module. + exec(code, {}) # pylint: disable=exec-used + # vars(self) returns a dictionary corresponding to the symbol + # table of the self object. dict(...) makes a distinct copy of + # this dictionary, such that any new symbol definition by the + # exec-ed code (e.g., import flags, module_name = ...) does not + # affect the symbol table of self. + exec(code, dict(vars(self))) # pylint: disable=exec-used + + # Next test is actually more involved: it checks not only that + # get_calling_module does not crash inside exec code, it also checks + # that it returns the expected value: the code executed via exec + # code is treated as being executed by the current module. We + # check it twice: first time by executing exec from the main + # module, second time by executing it from module_bar. + global_dict = {} + exec(code, global_dict) # pylint: disable=exec-used + self.assertEqual(global_dict['module_name'], + sys.argv[0]) + + global_dict = {} + module_bar.execute_code(code, global_dict) + self.assertEqual(global_dict['module_name'], + 'absl.flags.tests.module_bar') + + def test_get_calling_module_with_iteritems_error(self): + # This test checks that get_calling_module is using + # sys.modules.items(), instead of .iteritems(). + orig_sys_modules = sys.modules + + # Mock sys.modules: simulates error produced by importing a module + # in parallel with our iteration over sys.modules.iteritems(). + class SysModulesMock(dict): + + def __init__(self, original_content): + dict.__init__(self, original_content) + + def iteritems(self): + # Any dictionary method is fine, but not .iteritems(). + raise RuntimeError('dictionary changed size during iteration') + + sys.modules = SysModulesMock(orig_sys_modules) + try: + # _get_calling_module should still work as expected: + self.assertEqual(_helpers.get_calling_module(), sys.argv[0]) + self.assertEqual(module_foo.get_module_name(), + 'absl.flags.tests.module_foo') + finally: + sys.modules = orig_sys_modules + + +class IsBytesOrString(absltest.TestCase): + + def test_bytes(self): + self.assertTrue(_helpers.is_bytes_or_string(b'bytes')) + + def test_str(self): + self.assertTrue(_helpers.is_bytes_or_string('str')) + + def test_unicode(self): + self.assertTrue(_helpers.is_bytes_or_string(u'unicode')) + + def test_list(self): + self.assertFalse(_helpers.is_bytes_or_string(['str'])) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py new file mode 100644 index 0000000..f724813 --- /dev/null +++ b/absl/flags/tests/_validators_test.py @@ -0,0 +1,744 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Testing that flags validators framework does work. + +This file tests that each flag validator called when it should be, and that +failed validator will throw an exception, etc. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + + +from absl.flags import _defines +from absl.flags import _exceptions +from absl.flags import _flagvalues +from absl.flags import _validators +from absl.testing import absltest + + +class SingleFlagValidatorTest(absltest.TestCase): + """Testing _validators.register_validator() method.""" + + def setUp(self): + super(SingleFlagValidatorTest, self).setUp() + self.flag_values = _flagvalues.FlagValues() + self.call_args = [] + + def test_success(self): + def checker(x): + self.call_args.append(x) + return True + _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + 'test_flag', + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program',) + self.flag_values(argv) + self.assertIsNone(self.flag_values.test_flag) + self.flag_values.test_flag = 2 + self.assertEqual(2, self.flag_values.test_flag) + self.assertEqual([None, 2], self.call_args) + + def test_default_value_not_used_success(self): + def checker(x): + self.call_args.append(x) + return True + _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + 'test_flag', + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program', '--test_flag=1') + self.flag_values(argv) + self.assertEqual(1, self.flag_values.test_flag) + self.assertEqual([1], self.call_args) + + def test_validator_not_called_when_other_flag_is_changed(self): + def checker(x): + self.call_args.append(x) + return True + _defines.DEFINE_integer( + 'test_flag', 1, 'Usual integer flag', flag_values=self.flag_values) + _defines.DEFINE_integer( + 'other_flag', 2, 'Other integer flag', flag_values=self.flag_values) + _validators.register_validator( + 'test_flag', + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program',) + self.flag_values(argv) + self.assertEqual(1, self.flag_values.test_flag) + self.flag_values.other_flag = 3 + self.assertEqual([1], self.call_args) + + def test_exception_raised_if_checker_fails(self): + def checker(x): + self.call_args.append(x) + return x == 1 + _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + 'test_flag', + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program', '--test_flag=1') + self.flag_values(argv) + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values.test_flag = 2 + self.assertEqual('flag --test_flag=2: Errors happen', str(cm.exception)) + self.assertEqual([1, 2], self.call_args) + + def test_exception_raised_if_checker_raises_exception(self): + def checker(x): + self.call_args.append(x) + if x == 1: + return True + raise _exceptions.ValidationError('Specific message') + + _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + 'test_flag', + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program', '--test_flag=1') + self.flag_values(argv) + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values.test_flag = 2 + self.assertEqual('flag --test_flag=2: Specific message', str(cm.exception)) + self.assertEqual([1, 2], self.call_args) + + def test_error_message_when_checker_returns_false_on_start(self): + def checker(x): + self.call_args.append(x) + return False + _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + 'test_flag', + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program', '--test_flag=1') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual('flag --test_flag=1: Errors happen', str(cm.exception)) + self.assertEqual([1], self.call_args) + + def test_error_message_when_checker_raises_exception_on_start(self): + def checker(x): + self.call_args.append(x) + raise _exceptions.ValidationError('Specific message') + + _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + 'test_flag', + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program', '--test_flag=1') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual('flag --test_flag=1: Specific message', str(cm.exception)) + self.assertEqual([1], self.call_args) + + def test_validators_checked_in_order(self): + + def required(x): + self.calls.append('required') + return x is not None + + def even(x): + self.calls.append('even') + return x % 2 == 0 + + self.calls = [] + self._define_flag_and_validators(required, even) + self.assertEqual(['required', 'even'], self.calls) + + self.calls = [] + self._define_flag_and_validators(even, required) + self.assertEqual(['even', 'required'], self.calls) + + def _define_flag_and_validators(self, first_validator, second_validator): + local_flags = _flagvalues.FlagValues() + _defines.DEFINE_integer( + 'test_flag', 2, 'test flag', flag_values=local_flags) + _validators.register_validator( + 'test_flag', first_validator, message='', flag_values=local_flags) + _validators.register_validator( + 'test_flag', second_validator, message='', flag_values=local_flags) + argv = ('./program',) + local_flags(argv) + + def test_validator_as_decorator(self): + _defines.DEFINE_integer( + 'test_flag', None, 'Simple integer flag', flag_values=self.flag_values) + + @_validators.validator('test_flag', flag_values=self.flag_values) + def checker(x): + self.call_args.append(x) + return True + + argv = ('./program',) + self.flag_values(argv) + self.assertIsNone(self.flag_values.test_flag) + self.flag_values.test_flag = 2 + self.assertEqual(2, self.flag_values.test_flag) + self.assertEqual([None, 2], self.call_args) + # Check that 'Checker' is still a function and has not been replaced. + self.assertTrue(checker(3)) + self.assertEqual([None, 2, 3], self.call_args) + + +class MultiFlagsValidatorTest(absltest.TestCase): + """Test flags multi-flag validators.""" + + def setUp(self): + super(MultiFlagsValidatorTest, self).setUp() + self.flag_values = _flagvalues.FlagValues() + self.call_args = [] + _defines.DEFINE_integer( + 'foo', 1, 'Usual integer flag', flag_values=self.flag_values) + _defines.DEFINE_integer( + 'bar', 2, 'Usual integer flag', flag_values=self.flag_values) + + def test_success(self): + def checker(flags_dict): + self.call_args.append(flags_dict) + return True + _validators.register_multi_flags_validator( + ['foo', 'bar'], checker, flag_values=self.flag_values) + + argv = ('./program', '--bar=2') + self.flag_values(argv) + self.assertEqual(1, self.flag_values.foo) + self.assertEqual(2, self.flag_values.bar) + self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args) + self.flag_values.foo = 3 + self.assertEqual(3, self.flag_values.foo) + self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}], + self.call_args) + + def test_validator_not_called_when_other_flag_is_changed(self): + def checker(flags_dict): + self.call_args.append(flags_dict) + return True + _defines.DEFINE_integer( + 'other_flag', 3, 'Other integer flag', flag_values=self.flag_values) + _validators.register_multi_flags_validator( + ['foo', 'bar'], checker, flag_values=self.flag_values) + + argv = ('./program',) + self.flag_values(argv) + self.flag_values.other_flag = 3 + self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args) + + def test_exception_raised_if_checker_fails(self): + def checker(flags_dict): + self.call_args.append(flags_dict) + values = flags_dict.values() + # Make sure all the flags have different values. + return len(set(values)) == len(values) + _validators.register_multi_flags_validator( + ['foo', 'bar'], + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program',) + self.flag_values(argv) + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values.bar = 1 + self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception)) + self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], + self.call_args) + + def test_exception_raised_if_checker_raises_exception(self): + def checker(flags_dict): + self.call_args.append(flags_dict) + values = flags_dict.values() + # Make sure all the flags have different values. + if len(set(values)) != len(values): + raise _exceptions.ValidationError('Specific message') + return True + + _validators.register_multi_flags_validator( + ['foo', 'bar'], + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program',) + self.flag_values(argv) + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values.bar = 1 + self.assertEqual('flags foo=1, bar=1: Specific message', str(cm.exception)) + self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], + self.call_args) + + def test_decorator(self): + @_validators.multi_flags_validator( + ['foo', 'bar'], message='Errors happen', flag_values=self.flag_values) + def checker(flags_dict): # pylint: disable=unused-variable + self.call_args.append(flags_dict) + values = flags_dict.values() + # Make sure all the flags have different values. + return len(set(values)) == len(values) + + argv = ('./program',) + self.flag_values(argv) + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values.bar = 1 + self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception)) + self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], + self.call_args) + + +class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): + + def setUp(self): + super(MarkFlagsAsMutualExclusiveTest, self).setUp() + self.flag_values = _flagvalues.FlagValues() + + _defines.DEFINE_string( + 'flag_one', None, 'flag one', flag_values=self.flag_values) + _defines.DEFINE_string( + 'flag_two', None, 'flag two', flag_values=self.flag_values) + _defines.DEFINE_string( + 'flag_three', None, 'flag three', flag_values=self.flag_values) + _defines.DEFINE_integer( + 'int_flag_one', None, 'int flag one', flag_values=self.flag_values) + _defines.DEFINE_integer( + 'int_flag_two', None, 'int flag two', flag_values=self.flag_values) + _defines.DEFINE_multi_string( + 'multi_flag_one', None, 'multi flag one', flag_values=self.flag_values) + _defines.DEFINE_multi_string( + 'multi_flag_two', None, 'multi flag two', flag_values=self.flag_values) + _defines.DEFINE_boolean( + 'flag_not_none', False, 'false default', flag_values=self.flag_values) + + def _mark_flags_as_mutually_exclusive(self, flag_names, required): + _validators.mark_flags_as_mutual_exclusive( + flag_names, required=required, flag_values=self.flag_values) + + def test_no_flags_present(self): + self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], False) + argv = ('./program',) + + self.flag_values(argv) + self.assertIsNone(self.flag_values.flag_one) + self.assertIsNone(self.flag_values.flag_two) + + def test_no_flags_present_required(self): + self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True) + argv = ('./program',) + expected = ( + 'flags flag_one=None, flag_two=None: ' + 'Exactly one of (flag_one, flag_two) must have a value other than ' + 'None.') + + self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, + expected, self.flag_values, argv) + + def test_one_flag_present(self): + self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], False) + self.flag_values(('./program', '--flag_one=1')) + self.assertEqual('1', self.flag_values.flag_one) + + def test_one_flag_present_required(self): + self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True) + self.flag_values(('./program', '--flag_two=2')) + self.assertEqual('2', self.flag_values.flag_two) + + def test_one_flag_zero_required(self): + self._mark_flags_as_mutually_exclusive( + ['int_flag_one', 'int_flag_two'], True) + self.flag_values(('./program', '--int_flag_one=0')) + self.assertEqual(0, self.flag_values.int_flag_one) + + def test_mutual_exclusion_with_extra_flags(self): + self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True) + argv = ('./program', '--flag_two=2', '--flag_three=3') + + self.flag_values(argv) + self.assertEqual('2', self.flag_values.flag_two) + self.assertEqual('3', self.flag_values.flag_three) + + def test_mutual_exclusion_with_zero(self): + self._mark_flags_as_mutually_exclusive( + ['int_flag_one', 'int_flag_two'], False) + argv = ('./program', '--int_flag_one=0', '--int_flag_two=0') + expected = ( + 'flags int_flag_one=0, int_flag_two=0: ' + 'At most one of (int_flag_one, int_flag_two) must have a value other ' + 'than None.') + + self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, + expected, self.flag_values, argv) + + def test_multiple_flags_present(self): + self._mark_flags_as_mutually_exclusive( + ['flag_one', 'flag_two', 'flag_three'], False) + argv = ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3') + expected = ( + 'flags flag_one=1, flag_two=2, flag_three=3: ' + 'At most one of (flag_one, flag_two, flag_three) must have a value ' + 'other than None.') + + self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, + expected, self.flag_values, argv) + + def test_multiple_flags_present_required(self): + self._mark_flags_as_mutually_exclusive( + ['flag_one', 'flag_two', 'flag_three'], True) + argv = ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3') + expected = ( + 'flags flag_one=1, flag_two=2, flag_three=3: ' + 'Exactly one of (flag_one, flag_two, flag_three) must have a value ' + 'other than None.') + + self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, + expected, self.flag_values, argv) + + def test_no_multiflags_present(self): + self._mark_flags_as_mutually_exclusive( + ['multi_flag_one', 'multi_flag_two'], False) + argv = ('./program',) + self.flag_values(argv) + self.assertIsNone(self.flag_values.multi_flag_one) + self.assertIsNone(self.flag_values.multi_flag_two) + + def test_no_multistring_flags_present_required(self): + self._mark_flags_as_mutually_exclusive( + ['multi_flag_one', 'multi_flag_two'], True) + argv = ('./program',) + expected = ( + 'flags multi_flag_one=None, multi_flag_two=None: ' + 'Exactly one of (multi_flag_one, multi_flag_two) must have a value ' + 'other than None.') + + self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, + expected, self.flag_values, argv) + + def test_one_multiflag_present(self): + self._mark_flags_as_mutually_exclusive( + ['multi_flag_one', 'multi_flag_two'], True) + self.flag_values(('./program', '--multi_flag_one=1')) + self.assertEqual(['1'], self.flag_values.multi_flag_one) + + def test_one_multiflag_present_repeated(self): + self._mark_flags_as_mutually_exclusive( + ['multi_flag_one', 'multi_flag_two'], True) + self.flag_values(('./program', '--multi_flag_one=1', '--multi_flag_one=1b')) + self.assertEqual(['1', '1b'], self.flag_values.multi_flag_one) + + def test_multiple_multiflags_present(self): + self._mark_flags_as_mutually_exclusive( + ['multi_flag_one', 'multi_flag_two'], False) + argv = ('./program', '--multi_flag_one=1', '--multi_flag_two=2') + expected = ( + "flags multi_flag_one=['1'], multi_flag_two=['2']: " + 'At most one of (multi_flag_one, multi_flag_two) must have a value ' + 'other than None.') + + self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, + expected, self.flag_values, argv) + + def test_multiple_multiflags_present_required(self): + self._mark_flags_as_mutually_exclusive( + ['multi_flag_one', 'multi_flag_two'], True) + argv = ('./program', '--multi_flag_one=1', '--multi_flag_two=2') + expected = ( + "flags multi_flag_one=['1'], multi_flag_two=['2']: " + 'Exactly one of (multi_flag_one, multi_flag_two) must have a value ' + 'other than None.') + + self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, + expected, self.flag_values, argv) + + def test_flag_default_not_none_warning(self): + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter('always') + self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_not_none'], + False) + self.assertLen(caught_warnings, 1) + self.assertIn('--flag_not_none has a non-None default value', + str(caught_warnings[0].message)) + + +class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): + + def setUp(self): + super(MarkBoolFlagsAsMutualExclusiveTest, self).setUp() + self.flag_values = _flagvalues.FlagValues() + + _defines.DEFINE_boolean( + 'false_1', False, 'default false 1', flag_values=self.flag_values) + _defines.DEFINE_boolean( + 'false_2', False, 'default false 2', flag_values=self.flag_values) + _defines.DEFINE_boolean( + 'true_1', True, 'default true 1', flag_values=self.flag_values) + _defines.DEFINE_integer( + 'non_bool', None, 'non bool', flag_values=self.flag_values) + + def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required): + _validators.mark_bool_flags_as_mutual_exclusive( + flag_names, required=required, flag_values=self.flag_values) + + def test_no_flags_present(self): + self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False) + self.flag_values(('./program',)) + self.assertEqual(False, self.flag_values.false_1) + self.assertEqual(False, self.flag_values.false_2) + + def test_no_flags_present_required(self): + self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True) + argv = ('./program',) + expected = ( + 'flags false_1=False, false_2=False: ' + 'Exactly one of (false_1, false_2) must be True.') + + self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, + expected, self.flag_values, argv) + + def test_no_flags_present_with_default_true_required(self): + self._mark_bool_flags_as_mutually_exclusive(['false_1', 'true_1'], True) + self.flag_values(('./program',)) + self.assertEqual(False, self.flag_values.false_1) + self.assertEqual(True, self.flag_values.true_1) + + def test_two_flags_true(self): + self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False) + argv = ('./program', '--false_1', '--false_2') + expected = ( + 'flags false_1=True, false_2=True: At most one of (false_1, ' + 'false_2) must be True.') + + self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError, + expected, self.flag_values, argv) + + def test_non_bool_flag(self): + expected = ('Flag --non_bool is not Boolean, which is required for flags ' + 'used in mark_bool_flags_as_mutual_exclusive.') + with self.assertRaisesWithLiteralMatch(_exceptions.ValidationError, + expected): + self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'], + False) + + +class MarkFlagAsRequiredTest(absltest.TestCase): + + def setUp(self): + super(MarkFlagAsRequiredTest, self).setUp() + self.flag_values = _flagvalues.FlagValues() + + def test_success(self): + _defines.DEFINE_string( + 'string_flag', None, 'string flag', flag_values=self.flag_values) + _validators.mark_flag_as_required( + 'string_flag', flag_values=self.flag_values) + argv = ('./program', '--string_flag=value') + self.flag_values(argv) + self.assertEqual('value', self.flag_values.string_flag) + + def test_catch_none_as_default(self): + _defines.DEFINE_string( + 'string_flag', None, 'string flag', flag_values=self.flag_values) + _validators.mark_flag_as_required( + 'string_flag', flag_values=self.flag_values) + argv = ('./program',) + expected = ( + r'flag --string_flag=None: Flag --string_flag must have a value other ' + r'than None\.') + with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, expected): + self.flag_values(argv) + + def test_catch_setting_none_after_program_start(self): + _defines.DEFINE_string( + 'string_flag', 'value', 'string flag', flag_values=self.flag_values) + _validators.mark_flag_as_required( + 'string_flag', flag_values=self.flag_values) + argv = ('./program',) + self.flag_values(argv) + self.assertEqual('value', self.flag_values.string_flag) + expected = ('flag --string_flag=None: Flag --string_flag must have a value ' + 'other than None.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values.string_flag = None + self.assertEqual(expected, str(cm.exception)) + + def test_flag_default_not_none_warning(self): + _defines.DEFINE_string( + 'flag_not_none', '', 'empty default', flag_values=self.flag_values) + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter('always') + _validators.mark_flag_as_required( + 'flag_not_none', flag_values=self.flag_values) + + self.assertLen(caught_warnings, 1) + self.assertIn('--flag_not_none has a non-None default value', + str(caught_warnings[0].message)) + + +class MarkFlagsAsRequiredTest(absltest.TestCase): + + def setUp(self): + super(MarkFlagsAsRequiredTest, self).setUp() + self.flag_values = _flagvalues.FlagValues() + + def test_success(self): + _defines.DEFINE_string( + 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values) + _defines.DEFINE_string( + 'string_flag_2', None, 'string flag 2', flag_values=self.flag_values) + flag_names = ['string_flag_1', 'string_flag_2'] + _validators.mark_flags_as_required(flag_names, flag_values=self.flag_values) + argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2') + self.flag_values(argv) + self.assertEqual('value_1', self.flag_values.string_flag_1) + self.assertEqual('value_2', self.flag_values.string_flag_2) + + def test_catch_none_as_default(self): + _defines.DEFINE_string( + 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values) + _defines.DEFINE_string( + 'string_flag_2', None, 'string flag 2', flag_values=self.flag_values) + _validators.mark_flags_as_required( + ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values) + argv = ('./program', '--string_flag_1=value_1') + expected = ( + r'flag --string_flag_2=None: Flag --string_flag_2 must have a value ' + r'other than None\.') + with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, expected): + self.flag_values(argv) + + def test_catch_setting_none_after_program_start(self): + _defines.DEFINE_string( + 'string_flag_1', + 'value_1', + 'string flag 1', + flag_values=self.flag_values) + _defines.DEFINE_string( + 'string_flag_2', + 'value_2', + 'string flag 2', + flag_values=self.flag_values) + _validators.mark_flags_as_required( + ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values) + argv = ('./program', '--string_flag_1=value_1') + self.flag_values(argv) + self.assertEqual('value_1', self.flag_values.string_flag_1) + expected = ( + 'flag --string_flag_1=None: Flag --string_flag_1 must have a value ' + 'other than None.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values.string_flag_1 = None + self.assertEqual(expected, str(cm.exception)) + + def test_catch_multiple_flags_as_none_at_program_start(self): + _defines.DEFINE_float( + 'float_flag_1', + None, + 'string flag 1', + flag_values=self.flag_values) + _defines.DEFINE_float( + 'float_flag_2', + None, + 'string flag 2', + flag_values=self.flag_values) + _validators.mark_flags_as_required( + ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values) + argv = ('./program', '') + expected = ( + 'flag --float_flag_1=None: Flag --float_flag_1 must have a value ' + 'other than None.\n' + 'flag --float_flag_2=None: Flag --float_flag_2 must have a value ' + 'other than None.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + + def test_fail_fast_single_flag_and_skip_remaining_validators(self): + def raise_unexpected_error(x): + del x + raise _exceptions.ValidationError('Should not be raised.') + _defines.DEFINE_float( + 'flag_1', None, 'flag 1', flag_values=self.flag_values) + _defines.DEFINE_float( + 'flag_2', 4.2, 'flag 2', flag_values=self.flag_values) + _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values) + _validators.register_validator( + 'flag_1', raise_unexpected_error, flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_2', 'flag_1'], + raise_unexpected_error, + flag_values=self.flag_values) + argv = ('./program', '') + expected = ( + 'flag --flag_1=None: Flag --flag_1 must have a value other than None.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + + def test_fail_fast_multi_flag_and_skip_remaining_validators(self): + def raise_expected_error(x): + del x + raise _exceptions.ValidationError('Expected error.') + def raise_unexpected_error(x): + del x + raise _exceptions.ValidationError('Got unexpected error.') + _defines.DEFINE_float( + 'flag_1', 5.1, 'flag 1', flag_values=self.flag_values) + _defines.DEFINE_float( + 'flag_2', 10.0, 'flag 2', flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_1', 'flag_2'], + raise_expected_error, + flag_values=self.flag_values) + _validators.register_multi_flags_validator(['flag_2', 'flag_1'], + raise_unexpected_error, + flag_values=self.flag_values) + _validators.register_validator( + 'flag_1', raise_unexpected_error, flag_values=self.flag_values) + _validators.register_validator( + 'flag_2', raise_unexpected_error, flag_values=self.flag_values) + argv = ('./program', '') + expected = ('flags flag_1=5.1, flag_2=10.0: Expected error.') + with self.assertRaises(_exceptions.IllegalFlagValueError) as cm: + self.flag_values(argv) + self.assertEqual(expected, str(cm.exception)) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/argparse_flags_test.py b/absl/flags/tests/argparse_flags_test.py new file mode 100644 index 0000000..5e6f49a --- /dev/null +++ b/absl/flags/tests/argparse_flags_test.py @@ -0,0 +1,447 @@ +# Copyright 2018 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for absl.flags.argparse_flags.""" + +import io +import os +import subprocess +import sys +import tempfile +from unittest import mock + +from absl import flags +from absl import logging +from absl.flags import argparse_flags +from absl.testing import _bazelize_command +from absl.testing import absltest +from absl.testing import parameterized + + +class ArgparseFlagsTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self._absl_flags = flags.FlagValues() + flags.DEFINE_bool( + 'absl_bool', None, 'help for --absl_bool.', + short_name='b', flag_values=self._absl_flags) + # Add a boolean flag that starts with "no", to verify it can correctly + # handle the "no" prefixes in boolean flags. + flags.DEFINE_bool( + 'notice', None, 'help for --notice.', + flag_values=self._absl_flags) + flags.DEFINE_string( + 'absl_string', 'default', 'help for --absl_string=%.', + short_name='s', flag_values=self._absl_flags) + flags.DEFINE_integer( + 'absl_integer', 1, 'help for --absl_integer.', + flag_values=self._absl_flags) + flags.DEFINE_float( + 'absl_float', 1, 'help for --absl_integer.', + flag_values=self._absl_flags) + flags.DEFINE_enum( + 'absl_enum', 'apple', ['apple', 'orange'], 'help for --absl_enum.', + flag_values=self._absl_flags) + + def test_dash_as_prefix_char_only(self): + with self.assertRaises(ValueError): + argparse_flags.ArgumentParser(prefix_chars='/') + + def test_default_inherited_absl_flags_value(self): + parser = argparse_flags.ArgumentParser() + self.assertIs(parser._inherited_absl_flags, flags.FLAGS) + + def test_parse_absl_flags(self): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + self.assertFalse(self._absl_flags.is_parsed()) + self.assertTrue(self._absl_flags['absl_string'].using_default_value) + self.assertTrue(self._absl_flags['absl_integer'].using_default_value) + self.assertTrue(self._absl_flags['absl_float'].using_default_value) + self.assertTrue(self._absl_flags['absl_enum'].using_default_value) + + parser.parse_args( + ['--absl_string=new_string', '--absl_integer', '2']) + self.assertEqual(self._absl_flags.absl_string, 'new_string') + self.assertEqual(self._absl_flags.absl_integer, 2) + self.assertTrue(self._absl_flags.is_parsed()) + self.assertFalse(self._absl_flags['absl_string'].using_default_value) + self.assertFalse(self._absl_flags['absl_integer'].using_default_value) + self.assertTrue(self._absl_flags['absl_float'].using_default_value) + self.assertTrue(self._absl_flags['absl_enum'].using_default_value) + + @parameterized.named_parameters( + ('true', ['--absl_bool'], True), + ('false', ['--noabsl_bool'], False), + ('does_not_accept_equal_value', ['--absl_bool=true'], SystemExit), + ('does_not_accept_space_value', ['--absl_bool', 'true'], SystemExit), + ('long_name_single_dash', ['-absl_bool'], SystemExit), + ('short_name', ['-b'], True), + ('short_name_false', ['-nob'], SystemExit), + ('short_name_double_dash', ['--b'], SystemExit), + ('short_name_double_dash_false', ['--nob'], SystemExit), + ) + def test_parse_boolean_flags(self, args, expected): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + self.assertIsNone(self._absl_flags['absl_bool'].value) + self.assertIsNone(self._absl_flags['b'].value) + if isinstance(expected, bool): + parser.parse_args(args) + self.assertEqual(expected, self._absl_flags.absl_bool) + self.assertEqual(expected, self._absl_flags.b) + else: + with self.assertRaises(expected): + parser.parse_args(args) + + @parameterized.named_parameters( + ('true', ['--notice'], True), + ('false', ['--nonotice'], False), + ) + def test_parse_boolean_existing_no_prefix(self, args, expected): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + self.assertIsNone(self._absl_flags['notice'].value) + parser.parse_args(args) + self.assertEqual(expected, self._absl_flags.notice) + + def test_unrecognized_flag(self): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + with self.assertRaises(SystemExit): + parser.parse_args(['--unknown_flag=what']) + + def test_absl_validators(self): + + @flags.validator('absl_integer', flag_values=self._absl_flags) + def ensure_positive(value): + return value > 0 + + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + with self.assertRaises(SystemExit): + parser.parse_args(['--absl_integer', '-2']) + + del ensure_positive + + @parameterized.named_parameters( + ('regular_name_double_dash', '--absl_string=new_string', 'new_string'), + ('regular_name_single_dash', '-absl_string=new_string', SystemExit), + ('short_name_double_dash', '--s=new_string', SystemExit), + ('short_name_single_dash', '-s=new_string', 'new_string'), + ) + def test_dashes(self, argument, expected): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + if isinstance(expected, str): + parser.parse_args([argument]) + self.assertEqual(self._absl_flags.absl_string, expected) + else: + with self.assertRaises(expected): + parser.parse_args([argument]) + + def test_absl_flags_not_added_to_namespace(self): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + args = parser.parse_args(['--absl_string=new_string']) + self.assertIsNone(getattr(args, 'absl_string', None)) + + def test_mixed_flags_and_positional(self): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + parser.add_argument('--header', help='Header message to print.') + parser.add_argument('integers', metavar='N', type=int, nargs='+', + help='an integer for the accumulator') + + args = parser.parse_args( + ['--absl_string=new_string', '--header=HEADER', '--absl_integer', + '2', '3', '4']) + self.assertEqual(self._absl_flags.absl_string, 'new_string') + self.assertEqual(self._absl_flags.absl_integer, 2) + self.assertEqual(args.header, 'HEADER') + self.assertListEqual(args.integers, [3, 4]) + + def test_subparsers(self): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + parser.add_argument('--header', help='Header message to print.') + subparsers = parser.add_subparsers(help='The command to execute.') + + sub_parser = subparsers.add_parser( + 'sub_cmd', help='Sub command.', inherited_absl_flags=self._absl_flags) + sub_parser.add_argument('--sub_flag', help='Sub command flag.') + + def sub_command_func(): + pass + + sub_parser.set_defaults(command=sub_command_func) + + args = parser.parse_args([ + '--header=HEADER', '--absl_string=new_value', 'sub_cmd', + '--absl_integer=2', '--sub_flag=new_sub_flag_value']) + + self.assertEqual(args.header, 'HEADER') + self.assertEqual(self._absl_flags.absl_string, 'new_value') + self.assertEqual(args.command, sub_command_func) + self.assertEqual(self._absl_flags.absl_integer, 2) + self.assertEqual(args.sub_flag, 'new_sub_flag_value') + + def test_subparsers_no_inherit_in_subparser(self): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + subparsers = parser.add_subparsers(help='The command to execute.') + + subparsers.add_parser( + 'sub_cmd', help='Sub command.', + # Do not inherit absl flags in the subparser. + # This is the behavior that this test exercises. + inherited_absl_flags=None) + + with self.assertRaises(SystemExit): + parser.parse_args(['sub_cmd', '--absl_string=new_value']) + + def test_help_main_module_flags(self): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + help_message = parser.format_help() + + # Only the short name is shown in the usage string. + self.assertIn('[-s ABSL_STRING]', help_message) + # Both names are included in the options section. + self.assertIn('-s ABSL_STRING, --absl_string ABSL_STRING', help_message) + # Verify help messages. + self.assertIn('help for --absl_string=%.', help_message) + self.assertIn('<apple|orange>: help for --absl_enum.', help_message) + + def test_help_non_main_module_flags(self): + flags.DEFINE_string( + 'non_main_module_flag', 'default', 'help', + module_name='other.module', flag_values=self._absl_flags) + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + help_message = parser.format_help() + + # Non main module key flags are not printed in the help message. + self.assertNotIn('non_main_module_flag', help_message) + + def test_help_non_main_module_key_flags(self): + flags.DEFINE_string( + 'non_main_module_flag', 'default', 'help', + module_name='other.module', flag_values=self._absl_flags) + flags.declare_key_flag('non_main_module_flag', flag_values=self._absl_flags) + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + help_message = parser.format_help() + + # Main module key fags are printed in the help message, even if the flag + # is defined in another module. + self.assertIn('non_main_module_flag', help_message) + + @parameterized.named_parameters( + ('h', ['-h']), + ('help', ['--help']), + ('helpshort', ['--helpshort']), + ('helpfull', ['--helpfull']), + ) + def test_help_flags(self, args): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + with self.assertRaises(SystemExit): + parser.parse_args(args) + + @parameterized.named_parameters( + ('h', ['-h']), + ('help', ['--help']), + ('helpshort', ['--helpshort']), + ('helpfull', ['--helpfull']), + ) + def test_no_help_flags(self, args): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags, add_help=False) + with mock.patch.object(parser, 'print_help'): + with self.assertRaises(SystemExit): + parser.parse_args(args) + parser.print_help.assert_not_called() + + def test_helpfull_message(self): + flags.DEFINE_string( + 'non_main_module_flag', 'default', 'help', + module_name='other.module', flag_values=self._absl_flags) + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + with self.assertRaises(SystemExit),\ + mock.patch.object(sys, 'stdout', new=io.StringIO()) as mock_stdout: + parser.parse_args(['--helpfull']) + stdout_message = mock_stdout.getvalue() + logging.info('captured stdout message:\n%s', stdout_message) + self.assertIn('--non_main_module_flag', stdout_message) + self.assertIn('other.module', stdout_message) + # Make sure the main module is not included. + self.assertNotIn(sys.argv[0], stdout_message) + # Special flags defined in absl.flags. + self.assertIn('absl.flags:', stdout_message) + self.assertIn('--flagfile', stdout_message) + self.assertIn('--undefok', stdout_message) + + @parameterized.named_parameters( + ('at_end', + ('1', '--absl_string=value_from_cmd', '--flagfile='), + 'value_from_file'), + ('at_beginning', + ('--flagfile=', '1', '--absl_string=value_from_cmd'), + 'value_from_cmd'), + ) + def test_flagfile(self, cmd_args, expected_absl_string_value): + # Set gnu_getopt to False, to verify it's ignored by argparse_flags. + self._absl_flags.set_gnu_getopt(False) + + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + parser.add_argument('--header', help='Header message to print.') + parser.add_argument('integers', metavar='N', type=int, nargs='+', + help='an integer for the accumulator') + flagfile = tempfile.NamedTemporaryFile( + dir=absltest.TEST_TMPDIR.value, delete=False) + self.addCleanup(os.unlink, flagfile.name) + with flagfile: + flagfile.write(b''' +# The flag file. +--absl_string=value_from_file +--absl_integer=1 +--header=header_from_file +''') + + expand_flagfile = lambda x: x + flagfile.name if x == '--flagfile=' else x + cmd_args = [expand_flagfile(x) for x in cmd_args] + args = parser.parse_args(cmd_args) + + self.assertEqual([1], args.integers) + self.assertEqual('header_from_file', args.header) + self.assertEqual(expected_absl_string_value, self._absl_flags.absl_string) + + @parameterized.parameters( + ('positional', {'positional'}, False), + ('--not_existed', {'existed'}, False), + ('--empty', set(), False), + ('-single_dash', {'single_dash'}, True), + ('--double_dash', {'double_dash'}, True), + ('--with_value=value', {'with_value'}, True), + ) + def test_is_undefok(self, arg, undefok_names, is_undefok): + self.assertEqual(is_undefok, argparse_flags._is_undefok(arg, undefok_names)) + + @parameterized.named_parameters( + ('single', 'single', ['--single'], []), + ('multiple', 'first,second', ['--first', '--second'], []), + ('single_dash', 'dash', ['-dash'], []), + ('mixed_dash', 'mixed', ['-mixed', '--mixed'], []), + ('value', 'name', ['--name=value'], []), + ('boolean_positive', 'bool', ['--bool'], []), + ('boolean_negative', 'bool', ['--nobool'], []), + ('left_over', 'strip', ['--first', '--strip', '--last'], + ['--first', '--last']), + ) + def test_strip_undefok_args(self, undefok, args, expected_args): + actual_args = argparse_flags._strip_undefok_args(undefok, args) + self.assertListEqual(expected_args, actual_args) + + @parameterized.named_parameters( + ('at_end', ['--unknown', '--undefok=unknown']), + ('at_beginning', ['--undefok=unknown', '--unknown']), + ('multiple', ['--unknown', '--undefok=unknown,another_unknown']), + ('with_value', ['--unknown=value', '--undefok=unknown']), + ('maybe_boolean', ['--nounknown', '--undefok=unknown']), + ('with_space', ['--unknown', '--undefok', 'unknown']), + ) + def test_undefok_flag_correct_use(self, cmd_args): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + args = parser.parse_args(cmd_args) # Make sure it doesn't raise. + # Make sure `undefok` is not exposed in namespace. + sentinel = object() + self.assertIs(sentinel, getattr(args, 'undefok', sentinel)) + + def test_undefok_flag_existing(self): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + parser.parse_args( + ['--absl_string=new_value', '--undefok=absl_string']) + self.assertEqual('new_value', self._absl_flags.absl_string) + + @parameterized.named_parameters( + ('no_equal', ['--unknown', 'value', '--undefok=unknown']), + ('single_dash', ['--unknown', '-undefok=unknown']), + ) + def test_undefok_flag_incorrect_use(self, cmd_args): + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags) + with self.assertRaises(SystemExit): + parser.parse_args(cmd_args) + + def test_argument_default(self): + # Regression test for https://github.com/abseil/abseil-py/issues/171. + parser = argparse_flags.ArgumentParser( + inherited_absl_flags=self._absl_flags, argument_default=23) + parser.add_argument( + '--magic_number', type=int, help='The magic number to use.') + args = parser.parse_args([]) + self.assertEqual(args.magic_number, 23) + + +class ArgparseWithAppRunTest(parameterized.TestCase): + + @parameterized.named_parameters( + ('simple', + 'main_simple', 'parse_flags_simple', + ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.'], + ['I am argparse.', 'I am absl.']), + ('subcommand_roll_dice', + 'main_subcommands', 'parse_flags_subcommands', + ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.', + 'roll_dice', '--num_faces=12'], + ['I am argparse.', 'I am absl.', 'Rolled a dice: ']), + ('subcommand_shuffle', + 'main_subcommands', 'parse_flags_subcommands', + ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.', + 'shuffle', 'a', 'b', 'c'], + ['I am argparse.', 'I am absl.', 'Shuffled: ']), + ) + def test_argparse_with_app_run( + self, main_func_name, flags_parser_func_name, args, output_strings): + env = os.environ.copy() + env['MAIN_FUNC'] = main_func_name + env['FLAGS_PARSER_FUNC'] = flags_parser_func_name + helper = _bazelize_command.get_executable_path( + 'absl/flags/tests/argparse_flags_test_helper') + try: + stdout = subprocess.check_output( + [helper] + args, env=env, universal_newlines=True) + except subprocess.CalledProcessError as e: + error_info = ('ERROR: argparse_helper failed\n' + 'Command: {}\n' + 'Exit code: {}\n' + '----- output -----\n{}' + '------------------') + error_info = error_info.format(e.cmd, e.returncode, + e.output + '\n' if e.output else '<empty>') + print(error_info, file=sys.stderr) + raise + + for output_string in output_strings: + self.assertIn(output_string, stdout) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/argparse_flags_test_helper.py b/absl/flags/tests/argparse_flags_test_helper.py new file mode 100644 index 0000000..8cf42e6 --- /dev/null +++ b/absl/flags/tests/argparse_flags_test_helper.py @@ -0,0 +1,89 @@ +# Copyright 2018 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test helper for argparse_flags_test.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import random + +from absl import app +from absl import flags +from absl.flags import argparse_flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string('absl_echo', None, 'The echo message from absl.flags.') + + +def parse_flags_simple(argv): + """Simple example for absl.flags + argparse.""" + parser = argparse_flags.ArgumentParser( + description='A simple example of argparse_flags.') + parser.add_argument( + '--argparse_echo', help='The echo message from argparse_flags') + return parser.parse_args(argv[1:]) + + +def main_simple(args): + print('--absl_echo is', FLAGS.absl_echo) + print('--argparse_echo is', args.argparse_echo) + + +def roll_dice(args): + print('Rolled a dice:', random.randint(1, args.num_faces)) + + +def shuffle(args): + inputs = list(args.inputs) + random.shuffle(inputs) + print('Shuffled:', ' '.join(inputs)) + + +def parse_flags_subcommands(argv): + """Subcommands example for absl.flags + argparse.""" + parser = argparse_flags.ArgumentParser( + description='A subcommands example of argparse_flags.') + parser.add_argument('--argparse_echo', + help='The echo message from argparse_flags') + + subparsers = parser.add_subparsers(help='The command to execute.') + + roll_dice_parser = subparsers.add_parser( + 'roll_dice', help='Roll a dice.') + roll_dice_parser.add_argument('--num_faces', type=int, default=6) + roll_dice_parser.set_defaults(command=roll_dice) + + shuffle_parser = subparsers.add_parser( + 'shuffle', help='Shuffle inputs.') + shuffle_parser.add_argument( + 'inputs', metavar='I', nargs='+', help='Inputs to shuffle.') + shuffle_parser.set_defaults(command=shuffle) + + return parser.parse_args(argv[1:]) + + +def main_subcommands(args): + main_simple(args) + args.command(args) + + +if __name__ == '__main__': + main_func_name = os.environ['MAIN_FUNC'] + flags_parser_func_name = os.environ['FLAGS_PARSER_FUNC'] + app.run(main=globals()[main_func_name], + flags_parser=globals()[flags_parser_func_name]) diff --git a/absl/flags/tests/flags_formatting_test.py b/absl/flags/tests/flags_formatting_test.py new file mode 100644 index 0000000..bb547ce --- /dev/null +++ b/absl/flags/tests/flags_formatting_test.py @@ -0,0 +1,217 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import flags +from absl.flags import _helpers +from absl.testing import absltest + +FLAGS = flags.FLAGS + + +class FlagsUnitTest(absltest.TestCase): + """Flags formatting Unit Test.""" + + def test_get_help_width(self): + """Verify that get_help_width() reflects _help_width.""" + default_help_width = _helpers._DEFAULT_HELP_WIDTH # Save. + self.assertEqual(80, _helpers._DEFAULT_HELP_WIDTH) + self.assertEqual(_helpers._DEFAULT_HELP_WIDTH, flags.get_help_width()) + _helpers._DEFAULT_HELP_WIDTH = 10 + self.assertEqual(_helpers._DEFAULT_HELP_WIDTH, flags.get_help_width()) + _helpers._DEFAULT_HELP_WIDTH = default_help_width # restore + + def test_text_wrap(self): + """Test that wrapping works as expected. + + Also tests that it is using global flags._help_width by default. + """ + default_help_width = _helpers._DEFAULT_HELP_WIDTH + _helpers._DEFAULT_HELP_WIDTH = 10 + + # Generate a string with length 40, no spaces + text = '' + expect = [] + for n in range(4): + line = str(n) + line += '123456789' + text += line + expect.append(line) + + # Verify we still break + wrapped = flags.text_wrap(text).split('\n') + self.assertEqual(4, len(wrapped)) + self.assertEqual(expect, wrapped) + + wrapped = flags.text_wrap(text, 80).split('\n') + self.assertEqual(1, len(wrapped)) + self.assertEqual([text], wrapped) + + # Normal case, breaking at word boundaries and rewriting new lines + input_value = 'a b c d e f g h' + expect = {1: ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'], + 2: ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'], + 3: ['a b', 'c d', 'e f', 'g h'], + 4: ['a b', 'c d', 'e f', 'g h'], + 5: ['a b c', 'd e f', 'g h'], + 6: ['a b c', 'd e f', 'g h'], + 7: ['a b c d', 'e f g h'], + 8: ['a b c d', 'e f g h'], + 9: ['a b c d e', 'f g h'], + 10: ['a b c d e', 'f g h'], + 11: ['a b c d e f', 'g h'], + 12: ['a b c d e f', 'g h'], + 13: ['a b c d e f g', 'h'], + 14: ['a b c d e f g', 'h'], + 15: ['a b c d e f g h']} + for width, exp in expect.items(): + self.assertEqual(exp, flags.text_wrap(input_value, width).split('\n')) + + # We turn lines with only whitespace into empty lines + # We strip from the right up to the first new line + self.assertEqual('', flags.text_wrap(' ')) + self.assertEqual('\n', flags.text_wrap(' \n ')) + self.assertEqual('\n', flags.text_wrap('\n\n')) + self.assertEqual('\n\n', flags.text_wrap('\n\n\n')) + self.assertEqual('\n', flags.text_wrap('\n ')) + self.assertEqual('a\n\nb', flags.text_wrap('a\n \nb')) + self.assertEqual('a\n\n\nb', flags.text_wrap('a\n \n \nb')) + self.assertEqual('a\nb', flags.text_wrap(' a\nb ')) + self.assertEqual('\na\nb', flags.text_wrap('\na\nb\n')) + self.assertEqual('\na\nb\n', flags.text_wrap(' \na\nb\n ')) + self.assertEqual('\na\nb\n', flags.text_wrap(' \na\nb\n\n')) + + # Double newline. + self.assertEqual('a\n\nb', flags.text_wrap(' a\n\n b')) + + # We respect prefix + self.assertEqual(' a\n b\n c', flags.text_wrap('a\nb\nc', 80, ' ')) + self.assertEqual('a\n b\n c', flags.text_wrap('a\nb\nc', 80, ' ', '')) + + # tabs + self.assertEqual('a\n b c', + flags.text_wrap('a\nb\tc', 80, ' ', '')) + self.assertEqual('a\n bb c', + flags.text_wrap('a\nbb\tc', 80, ' ', '')) + self.assertEqual('a\n bbb c', + flags.text_wrap('a\nbbb\tc', 80, ' ', '')) + self.assertEqual('a\n bbbb c', + flags.text_wrap('a\nbbbb\tc', 80, ' ', '')) + self.assertEqual('a\n b\n c\n d', + flags.text_wrap('a\nb\tc\td', 3, ' ', '')) + self.assertEqual('a\n b\n c\n d', + flags.text_wrap('a\nb\tc\td', 4, ' ', '')) + self.assertEqual('a\n b\n c\n d', + flags.text_wrap('a\nb\tc\td', 5, ' ', '')) + self.assertEqual('a\n b c\n d', + flags.text_wrap('a\nb\tc\td', 6, ' ', '')) + self.assertEqual('a\n b c\n d', + flags.text_wrap('a\nb\tc\td', 7, ' ', '')) + self.assertEqual('a\n b c\n d', + flags.text_wrap('a\nb\tc\td', 8, ' ', '')) + self.assertEqual('a\n b c\n d', + flags.text_wrap('a\nb\tc\td', 9, ' ', '')) + self.assertEqual('a\n b c d', + flags.text_wrap('a\nb\tc\td', 10, ' ', '')) + + # multiple tabs + self.assertEqual('a c', + flags.text_wrap('a\t\tc', 80, ' ', '')) + + _helpers._DEFAULT_HELP_WIDTH = default_help_width # restore + + def test_doc_to_help(self): + self.assertEqual('', flags.doc_to_help(' ')) + self.assertEqual('', flags.doc_to_help(' \n ')) + self.assertEqual('a\n\nb', flags.doc_to_help('a\n \nb')) + self.assertEqual('a\n\n\nb', flags.doc_to_help('a\n \n \nb')) + self.assertEqual('a b', flags.doc_to_help(' a\nb ')) + self.assertEqual('a b', flags.doc_to_help('\na\nb\n')) + self.assertEqual('a\n\nb', flags.doc_to_help('\na\n\nb\n')) + self.assertEqual('a b', flags.doc_to_help(' \na\nb\n ')) + # Different first line, one line empty - erm double new line. + self.assertEqual('a b c\n\nd', flags.doc_to_help('a\n b\n c\n\n d')) + self.assertEqual('a b\n c d', flags.doc_to_help('a\n b\n \tc\n d')) + self.assertEqual('a b\n c\n d', + flags.doc_to_help('a\n b\n \tc\n \td')) + + def test_doc_to_help_flag_values(self): + # !!!!!!!!!!!!!!!!!!!! + # The following doc string is taken as is directly from flags.py:FlagValues + # The intention of this test is to verify 'live' performance + # !!!!!!!!!!!!!!!!!!!! + """Used as a registry for 'Flag' objects. + + A 'FlagValues' can then scan command line arguments, passing flag + arguments through to the 'Flag' objects that it owns. It also + provides easy access to the flag values. Typically only one + 'FlagValues' object is needed by an application: flags.FLAGS + + This class is heavily overloaded: + + 'Flag' objects are registered via __setitem__: + FLAGS['longname'] = x # register a new flag + + The .value member of the registered 'Flag' objects can be accessed as + members of this 'FlagValues' object, through __getattr__. Both the + long and short name of the original 'Flag' objects can be used to + access its value: + FLAGS.longname # parsed flag value + FLAGS.x # parsed flag value (short name) + + Command line arguments are scanned and passed to the registered 'Flag' + objects through the __call__ method. Unparsed arguments, including + argv[0] (e.g. the program name) are returned. + argv = FLAGS(sys.argv) # scan command line arguments + + The original registered Flag objects can be retrieved through the use + """ + doc = flags.doc_to_help(self.test_doc_to_help_flag_values.__doc__) + # Test the general outline of the converted docs + lines = doc.splitlines() + self.assertEqual(17, len(lines)) + empty_lines = [index for index in range(len(lines)) if not lines[index]] + self.assertEqual([1, 3, 5, 8, 12, 15], empty_lines) + # test that some starting prefix is kept + flags_lines = [index for index in range(len(lines)) + if lines[index].startswith(' FLAGS')] + self.assertEqual([7, 10, 11], flags_lines) + # but other, especially common space has been removed + space_lines = [index for index in range(len(lines)) + if lines[index] and lines[index][0].isspace()] + self.assertEqual([7, 10, 11, 14], space_lines) + # No right space was kept + rspace_lines = [index for index in range(len(lines)) + if lines[index] != lines[index].rstrip()] + self.assertEqual([], rspace_lines) + # test double spaces are kept + self.assertEqual(True, lines[2].endswith('application: flags.FLAGS')) + + def test_text_wrap_raises_on_excessive_indent(self): + """Ensure an indent longer than line length raises.""" + self.assertRaises(ValueError, + flags.text_wrap, 'dummy', length=10, indent=' ' * 10) + + def test_text_wrap_raises_on_excessive_first_line(self): + """Ensure a first line indent longer than line length raises.""" + self.assertRaises( + ValueError, + flags.text_wrap, 'dummy', length=80, firstline_indent=' ' * 80) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/flags_helpxml_test.py b/absl/flags/tests/flags_helpxml_test.py new file mode 100644 index 0000000..e2168ac --- /dev/null +++ b/absl/flags/tests/flags_helpxml_test.py @@ -0,0 +1,659 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the XML-format help generated by the flags.py module.""" + +import enum +import io +import os +import string +import sys +import xml.dom.minidom +import xml.sax.saxutils + +from absl import flags +from absl.flags import _helpers +from absl.flags.tests import module_bar +from absl.testing import absltest + + +class CreateXMLDOMElement(absltest.TestCase): + + def _check(self, name, value, expected_output): + doc = xml.dom.minidom.Document() + node = _helpers.create_xml_dom_element(doc, name, value) + output = node.toprettyxml(' ', encoding='utf-8') + self.assertEqual(expected_output, output) + + def test_create_xml_dom_element(self): + self._check('tag', '', b'<tag></tag>\n') + self._check('tag', 'plain text', b'<tag>plain text</tag>\n') + self._check('tag', '(x < y) && (a >= b)', + b'<tag>(x < y) && (a >= b)</tag>\n') + + # If the value is bytes with invalid unicode: + bytes_with_invalid_unicodes = b'\x81\xff' + # In python 3 the string representation is "b'\x81\xff'" so they are kept + # as "b'\x81\xff'". + self._check('tag', bytes_with_invalid_unicodes, + b"<tag>b'\\x81\\xff'</tag>\n") + + # Some unicode chars are illegal in xml + # (http://www.w3.org/TR/REC-xml/#charsets): + self._check('tag', u'\x0b\x02\x08\ufffe', b'<tag></tag>\n') + + # Valid unicode will be encoded: + self._check('tag', u'\xff', b'<tag>\xc3\xbf</tag>\n') + + +def _list_separators_in_xmlformat(separators, indent=''): + """Generates XML encoding of a list of list separators. + + Args: + separators: A list of list separators. Usually, this should be a + string whose characters are the valid list separators, e.g., ',' + means that both comma (',') and space (' ') are valid list + separators. + indent: A string that is added at the beginning of each generated + XML element. + + Returns: + A string. + """ + result = '' + separators = list(separators) + separators.sort() + for sep_char in separators: + result += ('%s<list_separator>%s</list_separator>\n' % + (indent, repr(sep_char))) + return result + + +class FlagCreateXMLDOMElement(absltest.TestCase): + """Test the create_xml_dom_element method for a single flag at a time. + + There is one test* method for each kind of DEFINE_* declaration. + """ + + def setUp(self): + # self.fv is a FlagValues object, just like flags.FLAGS. Each + # test registers one flag with this FlagValues. + self.fv = flags.FlagValues() + + def _check_flag_help_in_xml(self, flag_name, module_name, + expected_output, is_key=False): + flag_obj = self.fv[flag_name] + doc = xml.dom.minidom.Document() + element = flag_obj._create_xml_dom_element(doc, module_name, is_key=is_key) + output = element.toprettyxml(indent=' ') + self.assertMultiLineEqual(expected_output, output) + + def test_flag_help_in_xml_int(self): + flags.DEFINE_integer('index', 17, 'An integer flag', flag_values=self.fv) + expected_output_pattern = ( + '<flag>\n' + ' <file>module.name</file>\n' + ' <name>index</name>\n' + ' <meaning>An integer flag</meaning>\n' + ' <default>17</default>\n' + ' <current>%d</current>\n' + ' <type>int</type>\n' + '</flag>\n') + self._check_flag_help_in_xml('index', 'module.name', + expected_output_pattern % 17) + # Check that the output is correct even when the current value of + # a flag is different from the default one. + self.fv['index'].value = 20 + self._check_flag_help_in_xml('index', 'module.name', + expected_output_pattern % 20) + + def test_flag_help_in_xml_int_with_bounds(self): + flags.DEFINE_integer('nb_iters', 17, 'An integer flag', + lower_bound=5, upper_bound=27, + flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <key>yes</key>\n' + ' <file>module.name</file>\n' + ' <name>nb_iters</name>\n' + ' <meaning>An integer flag</meaning>\n' + ' <default>17</default>\n' + ' <current>17</current>\n' + ' <type>int</type>\n' + ' <lower_bound>5</lower_bound>\n' + ' <upper_bound>27</upper_bound>\n' + '</flag>\n') + self._check_flag_help_in_xml('nb_iters', 'module.name', expected_output, + is_key=True) + + def test_flag_help_in_xml_string(self): + flags.DEFINE_string('file_path', '/path/to/my/dir', 'A test string flag.', + flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <file>simple_module</file>\n' + ' <name>file_path</name>\n' + ' <meaning>A test string flag.</meaning>\n' + ' <default>/path/to/my/dir</default>\n' + ' <current>/path/to/my/dir</current>\n' + ' <type>string</type>\n' + '</flag>\n') + self._check_flag_help_in_xml('file_path', 'simple_module', expected_output) + + def test_flag_help_in_xml_string_with_xmlillegal_chars(self): + flags.DEFINE_string('file_path', '/path/to/\x08my/dir', + 'A test string flag.', flag_values=self.fv) + # '\x08' is not a legal character in XML 1.0 documents. Our + # current code purges such characters from the generated XML. + expected_output = ( + '<flag>\n' + ' <file>simple_module</file>\n' + ' <name>file_path</name>\n' + ' <meaning>A test string flag.</meaning>\n' + ' <default>/path/to/my/dir</default>\n' + ' <current>/path/to/my/dir</current>\n' + ' <type>string</type>\n' + '</flag>\n') + self._check_flag_help_in_xml('file_path', 'simple_module', expected_output) + + def test_flag_help_in_xml_boolean(self): + flags.DEFINE_boolean('use_gpu', False, 'Use gpu for performance.', + flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <key>yes</key>\n' + ' <file>a_module</file>\n' + ' <name>use_gpu</name>\n' + ' <meaning>Use gpu for performance.</meaning>\n' + ' <default>false</default>\n' + ' <current>false</current>\n' + ' <type>bool</type>\n' + '</flag>\n') + self._check_flag_help_in_xml('use_gpu', 'a_module', expected_output, + is_key=True) + + def test_flag_help_in_xml_enum(self): + flags.DEFINE_enum('cc_version', 'stable', ['stable', 'experimental'], + 'Compiler version to use.', flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>cc_version</name>\n' + ' <meaning><stable|experimental>: ' + 'Compiler version to use.</meaning>\n' + ' <default>stable</default>\n' + ' <current>stable</current>\n' + ' <type>string enum</type>\n' + ' <enum_value>stable</enum_value>\n' + ' <enum_value>experimental</enum_value>\n' + '</flag>\n') + self._check_flag_help_in_xml('cc_version', 'tool', expected_output) + + def test_flag_help_in_xml_enum_class(self): + class Version(enum.Enum): + STABLE = 0 + EXPERIMENTAL = 1 + + flags.DEFINE_enum_class('cc_version', 'STABLE', Version, + 'Compiler version to use.', flag_values=self.fv) + expected_output = ('<flag>\n' + ' <file>tool</file>\n' + ' <name>cc_version</name>\n' + ' <meaning><stable|experimental>: ' + 'Compiler version to use.</meaning>\n' + ' <default>stable</default>\n' + ' <current>Version.STABLE</current>\n' + ' <type>enum class</type>\n' + ' <enum_value>STABLE</enum_value>\n' + ' <enum_value>EXPERIMENTAL</enum_value>\n' + '</flag>\n') + self._check_flag_help_in_xml('cc_version', 'tool', expected_output) + + def test_flag_help_in_xml_comma_separated_list(self): + flags.DEFINE_list('files', 'a.cc,a.h,archive/old.zip', + 'Files to process.', flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>files</name>\n' + ' <meaning>Files to process.</meaning>\n' + ' <default>a.cc,a.h,archive/old.zip</default>\n' + ' <current>[\'a.cc\', \'a.h\', \'archive/old.zip\']</current>\n' + ' <type>comma separated list of strings</type>\n' + ' <list_separator>\',\'</list_separator>\n' + '</flag>\n') + self._check_flag_help_in_xml('files', 'tool', expected_output) + + def test_list_as_default_argument_comma_separated_list(self): + flags.DEFINE_list('allow_users', ['alice', 'bob'], + 'Users with access.', flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>allow_users</name>\n' + ' <meaning>Users with access.</meaning>\n' + ' <default>alice,bob</default>\n' + ' <current>[\'alice\', \'bob\']</current>\n' + ' <type>comma separated list of strings</type>\n' + ' <list_separator>\',\'</list_separator>\n' + '</flag>\n') + self._check_flag_help_in_xml('allow_users', 'tool', expected_output) + + def test_none_as_default_arguments_comma_separated_list(self): + flags.DEFINE_list('allow_users', None, + 'Users with access.', flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>allow_users</name>\n' + ' <meaning>Users with access.</meaning>\n' + ' <default></default>\n' + ' <current>None</current>\n' + ' <type>comma separated list of strings</type>\n' + ' <list_separator>\',\'</list_separator>\n' + '</flag>\n') + self._check_flag_help_in_xml('allow_users', 'tool', expected_output) + + def test_flag_help_in_xml_space_separated_list(self): + flags.DEFINE_spaceseplist('dirs', 'src libs bin', + 'Directories to search.', flag_values=self.fv) + expected_separators = sorted(string.whitespace) + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>dirs</name>\n' + ' <meaning>Directories to search.</meaning>\n' + ' <default>src libs bin</default>\n' + ' <current>[\'src\', \'libs\', \'bin\']</current>\n' + ' <type>whitespace separated list of strings</type>\n' + 'LIST_SEPARATORS' + '</flag>\n').replace('LIST_SEPARATORS', + _list_separators_in_xmlformat(expected_separators, + indent=' ')) + self._check_flag_help_in_xml('dirs', 'tool', expected_output) + + def test_flag_help_in_xml_space_separated_list_with_comma_compat(self): + flags.DEFINE_spaceseplist('dirs', 'src libs,bin', + 'Directories to search.', comma_compat=True, + flag_values=self.fv) + expected_separators = sorted(string.whitespace + ',') + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>dirs</name>\n' + ' <meaning>Directories to search.</meaning>\n' + ' <default>src libs bin</default>\n' + ' <current>[\'src\', \'libs\', \'bin\']</current>\n' + ' <type>whitespace or comma separated list of strings</type>\n' + 'LIST_SEPARATORS' + '</flag>\n').replace('LIST_SEPARATORS', + _list_separators_in_xmlformat(expected_separators, + indent=' ')) + self._check_flag_help_in_xml('dirs', 'tool', expected_output) + + def test_flag_help_in_xml_multi_string(self): + flags.DEFINE_multi_string('to_delete', ['a.cc', 'b.h'], + 'Files to delete', flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>to_delete</name>\n' + ' <meaning>Files to delete;\n' + ' repeat this option to specify a list of values</meaning>\n' + ' <default>[\'a.cc\', \'b.h\']</default>\n' + ' <current>[\'a.cc\', \'b.h\']</current>\n' + ' <type>multi string</type>\n' + '</flag>\n') + self._check_flag_help_in_xml('to_delete', 'tool', expected_output) + + def test_flag_help_in_xml_multi_int(self): + flags.DEFINE_multi_integer('cols', [5, 7, 23], + 'Columns to select', flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>cols</name>\n' + ' <meaning>Columns to select;\n ' + 'repeat this option to specify a list of values</meaning>\n' + ' <default>[5, 7, 23]</default>\n' + ' <current>[5, 7, 23]</current>\n' + ' <type>multi int</type>\n' + '</flag>\n') + self._check_flag_help_in_xml('cols', 'tool', expected_output) + + def test_flag_help_in_xml_multi_enum(self): + flags.DEFINE_multi_enum('flavours', ['APPLE', 'BANANA'], + ['APPLE', 'BANANA', 'CHERRY'], + 'Compilation flavour.', flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>flavours</name>\n' + ' <meaning><APPLE|BANANA|CHERRY>: Compilation flavour.;\n' + ' repeat this option to specify a list of values</meaning>\n' + ' <default>[\'APPLE\', \'BANANA\']</default>\n' + ' <current>[\'APPLE\', \'BANANA\']</current>\n' + ' <type>multi string enum</type>\n' + ' <enum_value>APPLE</enum_value>\n' + ' <enum_value>BANANA</enum_value>\n' + ' <enum_value>CHERRY</enum_value>\n' + '</flag>\n') + self._check_flag_help_in_xml('flavours', 'tool', expected_output) + + def test_flag_help_in_xml_multi_enum_class_singleton_default(self): + class Fruit(enum.Enum): + ORANGE = 0 + BANANA = 1 + + flags.DEFINE_multi_enum_class('fruit', ['ORANGE'], + Fruit, + 'The fruit flag.', flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>fruit</name>\n' + ' <meaning><orange|banana>: The fruit flag.;\n' + ' repeat this option to specify a list of values</meaning>\n' + ' <default>orange</default>\n' + ' <current>orange</current>\n' + ' <type>multi enum class</type>\n' + ' <enum_value>ORANGE</enum_value>\n' + ' <enum_value>BANANA</enum_value>\n' + '</flag>\n') + self._check_flag_help_in_xml('fruit', 'tool', expected_output) + + def test_flag_help_in_xml_multi_enum_class_list_default(self): + class Fruit(enum.Enum): + ORANGE = 0 + BANANA = 1 + + flags.DEFINE_multi_enum_class('fruit', ['ORANGE', 'BANANA'], + Fruit, + 'The fruit flag.', flag_values=self.fv) + expected_output = ( + '<flag>\n' + ' <file>tool</file>\n' + ' <name>fruit</name>\n' + ' <meaning><orange|banana>: The fruit flag.;\n' + ' repeat this option to specify a list of values</meaning>\n' + ' <default>orange,banana</default>\n' + ' <current>orange,banana</current>\n' + ' <type>multi enum class</type>\n' + ' <enum_value>ORANGE</enum_value>\n' + ' <enum_value>BANANA</enum_value>\n' + '</flag>\n') + self._check_flag_help_in_xml('fruit', 'tool', expected_output) + +# The next EXPECTED_HELP_XML_* constants are parts of a template for +# the expected XML output from WriteHelpInXMLFormatTest below. When +# we assemble these parts into a single big string, we'll take into +# account the ordering between the name of the main module and the +# name of module_bar. Next, we'll fill in the docstring for this +# module (%(usage_doc)s), the name of the main module +# (%(main_module_name)s) and the name of the module module_bar +# (%(module_bar_name)s). See WriteHelpInXMLFormatTest below. +EXPECTED_HELP_XML_START = """\ +<?xml version="1.0" encoding="utf-8"?> +<AllFlags> + <program>%(basename_of_argv0)s</program> + <usage>%(usage_doc)s</usage> +""" + +EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE = """\ + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>allow_users</name> + <meaning>Users with access.</meaning> + <default>alice,bob</default> + <current>['alice', 'bob']</current> + <type>comma separated list of strings</type> + <list_separator>','</list_separator> + </flag> + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>cc_version</name> + <meaning><stable|experimental>: Compiler version to use.</meaning> + <default>stable</default> + <current>stable</current> + <type>string enum</type> + <enum_value>stable</enum_value> + <enum_value>experimental</enum_value> + </flag> + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>cols</name> + <meaning>Columns to select; + repeat this option to specify a list of values</meaning> + <default>[5, 7, 23]</default> + <current>[5, 7, 23]</current> + <type>multi int</type> + </flag> + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>dirs</name> + <meaning>Directories to create.</meaning> + <default>src libs bins</default> + <current>['src', 'libs', 'bins']</current> + <type>whitespace separated list of strings</type> +%(whitespace_separators)s </flag> + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>file_path</name> + <meaning>A test string flag.</meaning> + <default>/path/to/my/dir</default> + <current>/path/to/my/dir</current> + <type>string</type> + </flag> + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>files</name> + <meaning>Files to process.</meaning> + <default>a.cc,a.h,archive/old.zip</default> + <current>['a.cc', 'a.h', 'archive/old.zip']</current> + <type>comma separated list of strings</type> + <list_separator>\',\'</list_separator> + </flag> + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>flavours</name> + <meaning><APPLE|BANANA|CHERRY>: Compilation flavour.; + repeat this option to specify a list of values</meaning> + <default>['APPLE', 'BANANA']</default> + <current>['APPLE', 'BANANA']</current> + <type>multi string enum</type> + <enum_value>APPLE</enum_value> + <enum_value>BANANA</enum_value> + <enum_value>CHERRY</enum_value> + </flag> + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>index</name> + <meaning>An integer flag</meaning> + <default>17</default> + <current>17</current> + <type>int</type> + </flag> + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>nb_iters</name> + <meaning>An integer flag</meaning> + <default>17</default> + <current>17</current> + <type>int</type> + <lower_bound>5</lower_bound> + <upper_bound>27</upper_bound> + </flag> + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>to_delete</name> + <meaning>Files to delete; + repeat this option to specify a list of values</meaning> + <default>['a.cc', 'b.h']</default> + <current>['a.cc', 'b.h']</current> + <type>multi string</type> + </flag> + <flag> + <key>yes</key> + <file>%(main_module_name)s</file> + <name>use_gpu</name> + <meaning>Use gpu for performance.</meaning> + <default>false</default> + <current>false</current> + <type>bool</type> + </flag> +""" + +EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR = """\ + <flag> + <file>%(module_bar_name)s</file> + <name>tmod_bar_t</name> + <meaning>Sample int flag.</meaning> + <default>4</default> + <current>4</current> + <type>int</type> + </flag> + <flag> + <key>yes</key> + <file>%(module_bar_name)s</file> + <name>tmod_bar_u</name> + <meaning>Sample int flag.</meaning> + <default>5</default> + <current>5</current> + <type>int</type> + </flag> + <flag> + <file>%(module_bar_name)s</file> + <name>tmod_bar_v</name> + <meaning>Sample int flag.</meaning> + <default>6</default> + <current>6</current> + <type>int</type> + </flag> + <flag> + <file>%(module_bar_name)s</file> + <name>tmod_bar_x</name> + <meaning>Boolean flag.</meaning> + <default>true</default> + <current>true</current> + <type>bool</type> + </flag> + <flag> + <file>%(module_bar_name)s</file> + <name>tmod_bar_y</name> + <meaning>String flag.</meaning> + <default>default</default> + <current>default</current> + <type>string</type> + </flag> + <flag> + <key>yes</key> + <file>%(module_bar_name)s</file> + <name>tmod_bar_z</name> + <meaning>Another boolean flag from module bar.</meaning> + <default>false</default> + <current>false</current> + <type>bool</type> + </flag> +""" + +EXPECTED_HELP_XML_END = """\ +</AllFlags> +""" + + +class WriteHelpInXMLFormatTest(absltest.TestCase): + """Big test of FlagValues.write_help_in_xml_format, with several flags.""" + + def test_write_help_in_xmlformat(self): + fv = flags.FlagValues() + # Since these flags are defined by the top module, they are all key. + flags.DEFINE_integer('index', 17, 'An integer flag', flag_values=fv) + flags.DEFINE_integer('nb_iters', 17, 'An integer flag', + lower_bound=5, upper_bound=27, flag_values=fv) + flags.DEFINE_string('file_path', '/path/to/my/dir', 'A test string flag.', + flag_values=fv) + flags.DEFINE_boolean('use_gpu', False, 'Use gpu for performance.', + flag_values=fv) + flags.DEFINE_enum('cc_version', 'stable', ['stable', 'experimental'], + 'Compiler version to use.', flag_values=fv) + flags.DEFINE_list('files', 'a.cc,a.h,archive/old.zip', + 'Files to process.', flag_values=fv) + flags.DEFINE_list('allow_users', ['alice', 'bob'], + 'Users with access.', flag_values=fv) + flags.DEFINE_spaceseplist('dirs', 'src libs bins', + 'Directories to create.', flag_values=fv) + flags.DEFINE_multi_string('to_delete', ['a.cc', 'b.h'], + 'Files to delete', flag_values=fv) + flags.DEFINE_multi_integer('cols', [5, 7, 23], + 'Columns to select', flag_values=fv) + flags.DEFINE_multi_enum('flavours', ['APPLE', 'BANANA'], + ['APPLE', 'BANANA', 'CHERRY'], + 'Compilation flavour.', flag_values=fv) + # Define a few flags in a different module. + module_bar.define_flags(flag_values=fv) + # And declare only a few of them to be key. This way, we have + # different kinds of flags, defined in different modules, and not + # all of them are key flags. + flags.declare_key_flag('tmod_bar_z', flag_values=fv) + flags.declare_key_flag('tmod_bar_u', flag_values=fv) + + # Generate flag help in XML format in the StringIO sio. + sio = io.StringIO() + fv.write_help_in_xml_format(sio) + + # Check that we got the expected result. + expected_output_template = EXPECTED_HELP_XML_START + main_module_name = sys.argv[0] + module_bar_name = module_bar.__name__ + + if main_module_name < module_bar_name: + expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE + expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR + else: + expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR + expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE + + expected_output_template += EXPECTED_HELP_XML_END + + # XML representation of the whitespace list separators. + whitespace_separators = _list_separators_in_xmlformat(string.whitespace, + indent=' ') + expected_output = ( + expected_output_template % + {'basename_of_argv0': os.path.basename(sys.argv[0]), + 'usage_doc': sys.modules['__main__'].__doc__, + 'main_module_name': main_module_name, + 'module_bar_name': module_bar_name, + 'whitespace_separators': whitespace_separators}) + + actual_output = sio.getvalue() + self.assertMultiLineEqual(expected_output, actual_output) + + # Also check that our result is valid XML. minidom.parseString + # throws an xml.parsers.expat.ExpatError in case of an error. + xml.dom.minidom.parseString(actual_output) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/flags_numeric_bounds_test.py b/absl/flags/tests/flags_numeric_bounds_test.py new file mode 100644 index 0000000..d3c2a95 --- /dev/null +++ b/absl/flags/tests/flags_numeric_bounds_test.py @@ -0,0 +1,105 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for lower/upper bounds validators for numeric flags.""" + +from unittest import mock +from absl import flags +from absl.flags import _validators +from absl.testing import absltest + + +class NumericFlagBoundsTest(absltest.TestCase): + + def setUp(self): + super(NumericFlagBoundsTest, self).setUp() + self.flag_values = flags.FlagValues() + + def test_no_validator_if_no_bounds(self): + """Validator is not registered if lower and upper bound are None.""" + with mock.patch.object(_validators, 'register_validator' + ) as register_validator: + flags.DEFINE_integer('positive_flag', None, 'positive int', + lower_bound=0, flag_values=self.flag_values) + register_validator.assert_called_once_with( + 'positive_flag', mock.ANY, flag_values=self.flag_values) + with mock.patch.object(_validators, 'register_validator' + ) as register_validator: + flags.DEFINE_integer('int_flag', None, 'just int', + flag_values=self.flag_values) + register_validator.assert_not_called() + + def test_success(self): + flags.DEFINE_integer('int_flag', 5, 'Just integer', + flag_values=self.flag_values) + argv = ('./program', '--int_flag=13') + self.flag_values(argv) + self.assertEqual(13, self.flag_values.int_flag) + self.flag_values.int_flag = 25 + self.assertEqual(25, self.flag_values.int_flag) + + def test_success_if_none(self): + flags.DEFINE_integer('int_flag', None, '', + lower_bound=0, upper_bound=5, + flag_values=self.flag_values) + argv = ('./program',) + self.flag_values(argv) + self.assertIsNone(self.flag_values.int_flag) + + def test_success_if_exactly_equals(self): + flags.DEFINE_float('float_flag', None, '', + lower_bound=1, upper_bound=1, + flag_values=self.flag_values) + argv = ('./program', '--float_flag=1') + self.flag_values(argv) + self.assertEqual(1, self.flag_values.float_flag) + + def test_exception_if_smaller(self): + flags.DEFINE_integer('int_flag', None, '', + lower_bound=0, upper_bound=5, + flag_values=self.flag_values) + argv = ('./program', '--int_flag=-1') + try: + self.flag_values(argv) + except flags.IllegalFlagValueError as e: + text = 'flag --int_flag=-1: -1 is not an integer in the range [0, 5]' + self.assertEqual(text, str(e)) + + +class SettingFlagAfterStartTest(absltest.TestCase): + + def setUp(self): + self.flag_values = flags.FlagValues() + + def test_success(self): + flags.DEFINE_integer('int_flag', None, 'Just integer', + flag_values=self.flag_values) + argv = ('./program', '--int_flag=13') + self.flag_values(argv) + self.assertEqual(13, self.flag_values.int_flag) + self.flag_values.int_flag = 25 + self.assertEqual(25, self.flag_values.int_flag) + + def test_exception_if_setting_integer_flag_outside_bounds(self): + flags.DEFINE_integer('int_flag', None, 'Just integer', lower_bound=0, + flag_values=self.flag_values) + argv = ('./program', '--int_flag=13') + self.flag_values(argv) + self.assertEqual(13, self.flag_values.int_flag) + with self.assertRaises(flags.IllegalFlagValueError): + self.flag_values.int_flag = -2 + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py new file mode 100644 index 0000000..8a42bc9 --- /dev/null +++ b/absl/flags/tests/flags_test.py @@ -0,0 +1,2922 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for absl.flags used as a package.""" + +import contextlib +import enum +import io +import os +import shutil +import sys +import tempfile +import unittest + +from absl import flags +from absl.flags import _exceptions +from absl.flags import _helpers +from absl.flags.tests import module_bar +from absl.flags.tests import module_baz +from absl.flags.tests import module_foo +from absl.testing import absltest + +FLAGS = flags.FLAGS + + +@contextlib.contextmanager +def _use_gnu_getopt(flag_values, use_gnu_get_opt): + old_use_gnu_get_opt = flag_values.is_gnu_getopt() + flag_values.set_gnu_getopt(use_gnu_get_opt) + yield + flag_values.set_gnu_getopt(old_use_gnu_get_opt) + + +class FlagDictToArgsTest(absltest.TestCase): + + def test_flatten_google_flag_map(self): + arg_dict = { + 'week-end': None, + 'estudia': False, + 'trabaja': False, + 'party': True, + 'monday': 'party', + 'score': 42, + 'loadthatstuff': [42, 'hello', 'goodbye'], + } + self.assertSameElements( + ('--week-end', '--noestudia', '--notrabaja', '--party', + '--monday=party', '--score=42', '--loadthatstuff=42,hello,goodbye'), + flags.flag_dict_to_args(arg_dict)) + + def test_flatten_google_flag_map_with_multi_flag(self): + arg_dict = { + 'some_list': ['value1', 'value2'], + 'some_multi_string': ['value3', 'value4'], + } + self.assertSameElements( + ('--some_list=value1,value2', '--some_multi_string=value3', + '--some_multi_string=value4'), + flags.flag_dict_to_args(arg_dict, multi_flags={'some_multi_string'})) + + +class Fruit(enum.Enum): + APPLE = object() + ORANGE = object() + + +class CaseSensitiveFruit(enum.Enum): + apple = 1 + orange = 2 + APPLE = 3 + + +class EmptyEnum(enum.Enum): + pass + + +class AliasFlagsTest(absltest.TestCase): + + def setUp(self): + super(AliasFlagsTest, self).setUp() + self.flags = flags.FlagValues() + + @property + def alias(self): + return self.flags['alias'] + + @property + def aliased(self): + return self.flags['aliased'] + + def define_alias(self, *args, **kwargs): + flags.DEFINE_alias(*args, flag_values=self.flags, **kwargs) + + def define_integer(self, *args, **kwargs): + flags.DEFINE_integer(*args, flag_values=self.flags, **kwargs) + + def define_multi_integer(self, *args, **kwargs): + flags.DEFINE_multi_integer(*args, flag_values=self.flags, **kwargs) + + def define_string(self, *args, **kwargs): + flags.DEFINE_string(*args, flag_values=self.flags, **kwargs) + + def assert_alias_mirrors_aliased(self, alias, aliased, ignore_due_to_bug=()): + # A few sanity checks to avoid false success + self.assertIn('FlagAlias', alias.__class__.__qualname__) + self.assertIsNot(alias, aliased) + self.assertNotEqual(aliased.name, alias.name) + + alias_state = {} + aliased_state = {} + attrs = { + 'allow_hide_cpp', + 'allow_override', + 'allow_override_cpp', + 'allow_overwrite', + 'allow_using_method_names', + 'boolean', + 'default', + 'default_as_str', + 'default_unparsed', + # TODO(rlevasseur): This should match, but a bug prevents it from being + # in sync. + # 'using_default_value', + 'value', + } + attrs.difference_update(ignore_due_to_bug) + + for attr in attrs: + alias_state[attr] = getattr(alias, attr) + aliased_state[attr] = getattr(aliased, attr) + + self.assertEqual(aliased_state, alias_state, 'LHS is aliased; RHS is alias') + + def test_serialize_multi(self): + self.define_multi_integer('aliased', [0, 1], '') + self.define_alias('alias', 'aliased') + + actual = self.alias.serialize() + # TODO(rlevasseur): This should check for --alias=0\n--alias=1, but + # a bug causes it to serialize incorrectly. + self.assertEqual('--alias=[0, 1]', actual) + + def test_allow_overwrite_false(self): + self.define_integer('aliased', None, 'help', allow_overwrite=False) + self.define_alias('alias', 'aliased') + + with self.assertRaisesRegex(flags.IllegalFlagValueError, 'already defined'): + self.flags(['./program', '--alias=1', '--aliased=2']) + + self.assertEqual(1, self.alias.value) + self.assertEqual(1, self.aliased.value) + + def test_aliasing_multi_no_default(self): + + def define_flags(): + self.flags = flags.FlagValues() + self.define_multi_integer('aliased', None, 'help') + self.define_alias('alias', 'aliased') + + with self.subTest('after defining'): + define_flags() + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + self.assertIsNone(self.alias.value) + + with self.subTest('set alias'): + define_flags() + self.flags(['./program', '--alias=1', '--alias=2']) + self.assertEqual([1, 2], self.alias.value) + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + + with self.subTest('set aliased'): + define_flags() + self.flags(['./program', '--aliased=1', '--aliased=2']) + self.assertEqual([1, 2], self.alias.value) + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + + with self.subTest('not setting anything'): + define_flags() + self.flags(['./program']) + self.assertEqual(None, self.alias.value) + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + + def test_aliasing_multi_with_default(self): + + def define_flags(): + self.flags = flags.FlagValues() + self.define_multi_integer('aliased', [0], 'help') + self.define_alias('alias', 'aliased') + + with self.subTest('after defining'): + define_flags() + self.assertEqual([0], self.alias.default) + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + + with self.subTest('set alias'): + define_flags() + self.flags(['./program', '--alias=1', '--alias=2']) + self.assertEqual([1, 2], self.alias.value) + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + + self.assertEqual(2, self.alias.present) + # TODO(rlevasseur): This should assert 0, but a bug with aliases and + # MultiFlag causes the alias to increment aliased's present counter. + self.assertEqual(2, self.aliased.present) + + with self.subTest('set aliased'): + define_flags() + self.flags(['./program', '--aliased=1', '--aliased=2']) + self.assertEqual([1, 2], self.alias.value) + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + self.assertEqual(0, self.alias.present) + + # TODO(rlevasseur): This should assert 0, but a bug with aliases and + # MultiFlag causes the alias to increment aliased present counter. + self.assertEqual(2, self.aliased.present) + + with self.subTest('not setting anything'): + define_flags() + self.flags(['./program']) + self.assertEqual([0], self.alias.value) + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + self.assertEqual(0, self.alias.present) + self.assertEqual(0, self.aliased.present) + + def test_aliasing_regular(self): + + def define_flags(): + self.flags = flags.FlagValues() + self.define_string('aliased', '', 'help') + self.define_alias('alias', 'aliased') + + define_flags() + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + + self.flags(['./program', '--alias=1']) + self.assertEqual('1', self.alias.value) + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + self.assertEqual(1, self.alias.present) + self.assertEqual('--alias=1', self.alias.serialize()) + self.assertEqual(1, self.aliased.present) + + define_flags() + self.flags(['./program', '--aliased=2']) + self.assertEqual('2', self.alias.value) + self.assert_alias_mirrors_aliased(self.alias, self.aliased) + self.assertEqual(0, self.alias.present) + self.assertEqual('--alias=2', self.alias.serialize()) + self.assertEqual(1, self.aliased.present) + + def test_defining_alias_doesnt_affect_aliased_state_regular(self): + self.define_string('aliased', 'default', 'help') + self.define_alias('alias', 'aliased') + + self.assertEqual(0, self.aliased.present) + self.assertEqual(0, self.alias.present) + + def test_defining_alias_doesnt_affect_aliased_state_multi(self): + self.define_multi_integer('aliased', [0], 'help') + self.define_alias('alias', 'aliased') + + self.assertEqual([0], self.aliased.value) + self.assertEqual([0], self.aliased.default) + self.assertEqual(0, self.aliased.present) + + self.assertEqual([0], self.aliased.value) + self.assertEqual([0], self.aliased.default) + self.assertEqual(0, self.alias.present) + + +class FlagsUnitTest(absltest.TestCase): + """Flags Unit Test.""" + + maxDiff = None + + def test_flags(self): + """Test normal usage with no (expected) errors.""" + # Define flags + number_test_framework_flags = len(FLAGS) + repeat_help = 'how many times to repeat (0-5)' + flags.DEFINE_integer( + 'repeat', 4, repeat_help, lower_bound=0, short_name='r') + flags.DEFINE_string('name', 'Bob', 'namehelp') + flags.DEFINE_boolean('debug', 0, 'debughelp') + flags.DEFINE_boolean('q', 1, 'quiet mode') + flags.DEFINE_boolean('quack', 0, "superstring of 'q'") + flags.DEFINE_boolean('noexec', 1, 'boolean flag with no as prefix') + flags.DEFINE_float('float', 3.14, 'using floats') + flags.DEFINE_integer('octal', '0o666', 'using octals') + flags.DEFINE_integer('decimal', '666', 'using decimals') + flags.DEFINE_integer('hexadecimal', '0x666', 'using hexadecimals') + flags.DEFINE_integer('x', 3, 'how eXtreme to be') + flags.DEFINE_integer('l', 0x7fffffff00000000, 'how long to be') + flags.DEFINE_list('args', 'v=1,"vmodule=a=0,b=2"', 'a list of arguments') + flags.DEFINE_list('letters', 'a,b,c', 'a list of letters') + flags.DEFINE_list('numbers', [1, 2, 3], 'a list of numbers') + flags.DEFINE_enum('kwery', None, ['who', 'what', 'Why', 'where', 'when'], + '?') + flags.DEFINE_enum( + 'sense', None, ['Case', 'case', 'CASE'], '?', case_sensitive=True) + flags.DEFINE_enum( + 'cases', + None, ['UPPER', 'lower', 'Initial', 'Ot_HeR'], + '?', + case_sensitive=False) + flags.DEFINE_enum( + 'funny', + None, ['Joke', 'ha', 'ha', 'ha', 'ha'], + '?', + case_sensitive=True) + flags.DEFINE_enum( + 'blah', + None, ['bla', 'Blah', 'BLAH', 'blah'], + '?', + case_sensitive=False) + flags.DEFINE_string( + 'only_once', None, 'test only sets this once', allow_overwrite=False) + flags.DEFINE_string( + 'universe', + None, + 'test tries to set this three times', + allow_overwrite=False) + + # Specify number of flags defined above. The short_name defined + # for 'repeat' counts as an extra flag. + number_defined_flags = 22 + 1 + self.assertLen(FLAGS, number_defined_flags + number_test_framework_flags) + + self.assertEqual(FLAGS.repeat, 4) + self.assertEqual(FLAGS.name, 'Bob') + self.assertEqual(FLAGS.debug, 0) + self.assertEqual(FLAGS.q, 1) + self.assertEqual(FLAGS.octal, 0o666) + self.assertEqual(FLAGS.decimal, 666) + self.assertEqual(FLAGS.hexadecimal, 0x666) + self.assertEqual(FLAGS.x, 3) + self.assertEqual(FLAGS.l, 0x7fffffff00000000) + self.assertEqual(FLAGS.args, ['v=1', 'vmodule=a=0,b=2']) + self.assertEqual(FLAGS.letters, ['a', 'b', 'c']) + self.assertEqual(FLAGS.numbers, [1, 2, 3]) + self.assertIsNone(FLAGS.kwery) + self.assertIsNone(FLAGS.sense) + self.assertIsNone(FLAGS.cases) + self.assertIsNone(FLAGS.funny) + self.assertIsNone(FLAGS.blah) + + flag_values = FLAGS.flag_values_dict() + self.assertEqual(flag_values['repeat'], 4) + self.assertEqual(flag_values['name'], 'Bob') + self.assertEqual(flag_values['debug'], 0) + self.assertEqual(flag_values['r'], 4) # Short for repeat. + self.assertEqual(flag_values['q'], 1) + self.assertEqual(flag_values['quack'], 0) + self.assertEqual(flag_values['x'], 3) + self.assertEqual(flag_values['l'], 0x7fffffff00000000) + self.assertEqual(flag_values['args'], ['v=1', 'vmodule=a=0,b=2']) + self.assertEqual(flag_values['letters'], ['a', 'b', 'c']) + self.assertEqual(flag_values['numbers'], [1, 2, 3]) + self.assertIsNone(flag_values['kwery']) + self.assertIsNone(flag_values['sense']) + self.assertIsNone(flag_values['cases']) + self.assertIsNone(flag_values['funny']) + self.assertIsNone(flag_values['blah']) + + # Verify string form of defaults + self.assertEqual(FLAGS['repeat'].default_as_str, "'4'") + self.assertEqual(FLAGS['name'].default_as_str, "'Bob'") + self.assertEqual(FLAGS['debug'].default_as_str, "'false'") + self.assertEqual(FLAGS['q'].default_as_str, "'true'") + self.assertEqual(FLAGS['quack'].default_as_str, "'false'") + self.assertEqual(FLAGS['noexec'].default_as_str, "'true'") + self.assertEqual(FLAGS['x'].default_as_str, "'3'") + self.assertEqual(FLAGS['l'].default_as_str, "'9223372032559808512'") + self.assertEqual(FLAGS['args'].default_as_str, '\'v=1,"vmodule=a=0,b=2"\'') + self.assertEqual(FLAGS['letters'].default_as_str, "'a,b,c'") + self.assertEqual(FLAGS['numbers'].default_as_str, "'1,2,3'") + + # Verify that the iterator for flags yields all the keys + keys = list(FLAGS) + keys.sort() + reg_flags = list(FLAGS._flags()) + reg_flags.sort() + self.assertEqual(keys, reg_flags) + + # Parse flags + # .. empty command line + argv = ('./program',) + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + + # .. non-empty command line + argv = ('./program', '--debug', '--name=Bob', '-q', '--x=8') + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(FLAGS['debug'].present, 1) + FLAGS['debug'].present = 0 # Reset + self.assertEqual(FLAGS['name'].present, 1) + FLAGS['name'].present = 0 # Reset + self.assertEqual(FLAGS['q'].present, 1) + FLAGS['q'].present = 0 # Reset + self.assertEqual(FLAGS['x'].present, 1) + FLAGS['x'].present = 0 # Reset + + # Flags list. + self.assertLen(FLAGS, number_defined_flags + number_test_framework_flags) + self.assertIn('name', FLAGS) + self.assertIn('debug', FLAGS) + self.assertIn('repeat', FLAGS) + self.assertIn('r', FLAGS) + self.assertIn('q', FLAGS) + self.assertIn('quack', FLAGS) + self.assertIn('x', FLAGS) + self.assertIn('l', FLAGS) + self.assertIn('args', FLAGS) + self.assertIn('letters', FLAGS) + self.assertIn('numbers', FLAGS) + + # __contains__ + self.assertIn('name', FLAGS) + self.assertNotIn('name2', FLAGS) + + # try deleting a flag + del FLAGS.r + self.assertLen(FLAGS, + number_defined_flags - 1 + number_test_framework_flags) + self.assertNotIn('r', FLAGS) + + # .. command line with extra stuff + argv = ('./program', '--debug', '--name=Bob', 'extra') + argv = FLAGS(argv) + self.assertLen(argv, 2, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(argv[1], 'extra', 'extra argument not preserved') + self.assertEqual(FLAGS['debug'].present, 1) + FLAGS['debug'].present = 0 # Reset + self.assertEqual(FLAGS['name'].present, 1) + FLAGS['name'].present = 0 # Reset + + # Test reset + argv = ('./program', '--debug') + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(FLAGS['debug'].present, 1) + self.assertTrue(FLAGS['debug'].value) + FLAGS.unparse_flags() + self.assertEqual(FLAGS['debug'].present, 0) + self.assertFalse(FLAGS['debug'].value) + + # Test that reset restores default value when default value is None. + argv = ('./program', '--kwery=who') + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(FLAGS['kwery'].present, 1) + self.assertEqual(FLAGS['kwery'].value, 'who') + FLAGS.unparse_flags() + argv = ('./program', '--kwery=Why') + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(FLAGS['kwery'].present, 1) + self.assertEqual(FLAGS['kwery'].value, 'Why') + FLAGS.unparse_flags() + self.assertEqual(FLAGS['kwery'].present, 0) + self.assertIsNone(FLAGS['kwery'].value) + + # Test case sensitive enum. + argv = ('./program', '--sense=CASE') + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(FLAGS['sense'].present, 1) + self.assertEqual(FLAGS['sense'].value, 'CASE') + FLAGS.unparse_flags() + argv = ('./program', '--sense=Case') + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(FLAGS['sense'].present, 1) + self.assertEqual(FLAGS['sense'].value, 'Case') + FLAGS.unparse_flags() + + # Test case insensitive enum. + argv = ('./program', '--cases=upper') + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(FLAGS['cases'].present, 1) + self.assertEqual(FLAGS['cases'].value, 'UPPER') + FLAGS.unparse_flags() + + # Test case sensitive enum with duplicates. + argv = ('./program', '--funny=ha') + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(FLAGS['funny'].present, 1) + self.assertEqual(FLAGS['funny'].value, 'ha') + FLAGS.unparse_flags() + + # Test case insensitive enum with duplicates. + argv = ('./program', '--blah=bLah') + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(FLAGS['blah'].present, 1) + self.assertEqual(FLAGS['blah'].value, 'Blah') + FLAGS.unparse_flags() + argv = ('./program', '--blah=BLAH') + argv = FLAGS(argv) + self.assertLen(argv, 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(FLAGS['blah'].present, 1) + self.assertEqual(FLAGS['blah'].value, 'Blah') + FLAGS.unparse_flags() + + # Test integer argument passing + argv = ('./program', '--x', '0x12345') + argv = FLAGS(argv) + self.assertEqual(FLAGS.x, 0x12345) + self.assertEqual(type(FLAGS.x), int) + + argv = ('./program', '--x', '0x1234567890ABCDEF1234567890ABCDEF') + argv = FLAGS(argv) + self.assertEqual(FLAGS.x, 0x1234567890ABCDEF1234567890ABCDEF) + self.assertIsInstance(FLAGS.x, int) + + argv = ('./program', '--x', '0o12345') + argv = FLAGS(argv) + self.assertEqual(FLAGS.x, 0o12345) + self.assertEqual(type(FLAGS.x), int) + + # Treat 0-prefixed parameters as base-10, not base-8 + argv = ('./program', '--x', '012345') + argv = FLAGS(argv) + self.assertEqual(FLAGS.x, 12345) + self.assertEqual(type(FLAGS.x), int) + + argv = ('./program', '--x', '0123459') + argv = FLAGS(argv) + self.assertEqual(FLAGS.x, 123459) + self.assertEqual(type(FLAGS.x), int) + + argv = ('./program', '--x', '0x123efg') + with self.assertRaises(flags.IllegalFlagValueError): + argv = FLAGS(argv) + + # Test boolean argument parsing + flags.DEFINE_boolean('test0', None, 'test boolean parsing') + argv = ('./program', '--notest0') + argv = FLAGS(argv) + self.assertEqual(FLAGS.test0, 0) + + flags.DEFINE_boolean('test1', None, 'test boolean parsing') + argv = ('./program', '--test1') + argv = FLAGS(argv) + self.assertEqual(FLAGS.test1, 1) + + FLAGS.test0 = None + argv = ('./program', '--test0=false') + argv = FLAGS(argv) + self.assertEqual(FLAGS.test0, 0) + + FLAGS.test1 = None + argv = ('./program', '--test1=true') + argv = FLAGS(argv) + self.assertEqual(FLAGS.test1, 1) + + FLAGS.test0 = None + argv = ('./program', '--test0=0') + argv = FLAGS(argv) + self.assertEqual(FLAGS.test0, 0) + + FLAGS.test1 = None + argv = ('./program', '--test1=1') + argv = FLAGS(argv) + self.assertEqual(FLAGS.test1, 1) + + # Test booleans that already have 'no' as a prefix + FLAGS.noexec = None + argv = ('./program', '--nonoexec', '--name', 'Bob') + argv = FLAGS(argv) + self.assertEqual(FLAGS.noexec, 0) + + FLAGS.noexec = None + argv = ('./program', '--name', 'Bob', '--noexec') + argv = FLAGS(argv) + self.assertEqual(FLAGS.noexec, 1) + + # Test unassigned booleans + flags.DEFINE_boolean('testnone', None, 'test boolean parsing') + argv = ('./program',) + argv = FLAGS(argv) + self.assertIsNone(FLAGS.testnone) + + # Test get with default + flags.DEFINE_boolean('testget1', None, 'test parsing with defaults') + flags.DEFINE_boolean('testget2', None, 'test parsing with defaults') + flags.DEFINE_boolean('testget3', None, 'test parsing with defaults') + flags.DEFINE_integer('testget4', None, 'test parsing with defaults') + argv = ('./program', '--testget1', '--notestget2') + argv = FLAGS(argv) + self.assertEqual(FLAGS.get_flag_value('testget1', 'foo'), 1) + self.assertEqual(FLAGS.get_flag_value('testget2', 'foo'), 0) + self.assertEqual(FLAGS.get_flag_value('testget3', 'foo'), 'foo') + self.assertEqual(FLAGS.get_flag_value('testget4', 'foo'), 'foo') + + # test list code + lists = [['hello', 'moo', 'boo', '1'], []] + + flags.DEFINE_list('testcomma_list', '', 'test comma list parsing') + flags.DEFINE_spaceseplist('testspace_list', '', 'tests space list parsing') + flags.DEFINE_spaceseplist( + 'testspace_or_comma_list', + '', + 'tests space list parsing with comma compatibility', + comma_compat=True) + + for name, sep in (('testcomma_list', ','), ('testspace_list', + ' '), ('testspace_list', '\n'), + ('testspace_or_comma_list', + ' '), ('testspace_or_comma_list', + '\n'), ('testspace_or_comma_list', ',')): + for lst in lists: + argv = ('./program', '--%s=%s' % (name, sep.join(lst))) + argv = FLAGS(argv) + self.assertEqual(getattr(FLAGS, name), lst) + + # Test help text + flags_help = str(FLAGS) + self.assertNotEqual( + flags_help.find('repeat'), -1, 'cannot find flag in help') + self.assertNotEqual( + flags_help.find(repeat_help), -1, 'cannot find help string in help') + + # Test flag specified twice + argv = ('./program', '--repeat=4', '--repeat=2', '--debug', '--nodebug') + argv = FLAGS(argv) + self.assertEqual(FLAGS.get_flag_value('repeat', None), 2) + self.assertEqual(FLAGS.get_flag_value('debug', None), 0) + + # Test MultiFlag with single default value + flags.DEFINE_multi_string( + 's_str', + 'sing1', + 'string option that can occur multiple times', + short_name='s') + self.assertEqual(FLAGS.get_flag_value('s_str', None), ['sing1']) + + # Test MultiFlag with list of default values + multi_string_defs = ['def1', 'def2'] + flags.DEFINE_multi_string( + 'm_str', + multi_string_defs, + 'string option that can occur multiple times', + short_name='m') + self.assertEqual(FLAGS.get_flag_value('m_str', None), multi_string_defs) + + # Test flag specified multiple times with a MultiFlag + argv = ('./program', '--m_str=str1', '-m', 'str2') + argv = FLAGS(argv) + self.assertEqual(FLAGS.get_flag_value('m_str', None), ['str1', 'str2']) + + # A flag with allow_overwrite set to False should behave normally when it + # is only specified once + argv = ('./program', '--only_once=singlevalue') + argv = FLAGS(argv) + self.assertEqual(FLAGS.get_flag_value('only_once', None), 'singlevalue') + + # A flag with allow_overwrite set to False should complain when it is + # specified more than once + argv = ('./program', '--universe=ptolemaic', '--universe=copernicean', + '--universe=euclidean') + self.assertRaisesWithLiteralMatch( + flags.IllegalFlagValueError, + 'flag --universe=copernicean: already defined as ptolemaic', FLAGS, + argv) + + # Test single-letter flags; should support both single and double dash + argv = ('./program', '-q') + argv = FLAGS(argv) + self.assertEqual(FLAGS.get_flag_value('q', None), 1) + + argv = ('./program', '--q', '--x', '9', '--noquack') + argv = FLAGS(argv) + self.assertEqual(FLAGS.get_flag_value('q', None), 1) + self.assertEqual(FLAGS.get_flag_value('x', None), 9) + self.assertEqual(FLAGS.get_flag_value('quack', None), 0) + + argv = ('./program', '--noq', '--x=10', '--quack') + argv = FLAGS(argv) + self.assertEqual(FLAGS.get_flag_value('q', None), 0) + self.assertEqual(FLAGS.get_flag_value('x', None), 10) + self.assertEqual(FLAGS.get_flag_value('quack', None), 1) + + #################################### + # Test flag serialization code: + + old_testcomma_list = FLAGS.testcomma_list + old_testspace_list = FLAGS.testspace_list + old_testspace_or_comma_list = FLAGS.testspace_or_comma_list + + argv = ('./program', FLAGS['test0'].serialize(), FLAGS['test1'].serialize(), + FLAGS['s_str'].serialize()) + + argv = FLAGS(argv) + self.assertEqual(FLAGS['test0'].serialize(), '--notest0') + self.assertEqual(FLAGS['test1'].serialize(), '--test1') + self.assertEqual(FLAGS['s_str'].serialize(), '--s_str=sing1') + + self.assertEqual(FLAGS['testnone'].serialize(), '') + + testcomma_list1 = ['aa', 'bb'] + testspace_list1 = ['aa', 'bb', 'cc'] + testspace_or_comma_list1 = ['aa', 'bb', 'cc', 'dd'] + FLAGS.testcomma_list = list(testcomma_list1) + FLAGS.testspace_list = list(testspace_list1) + FLAGS.testspace_or_comma_list = list(testspace_or_comma_list1) + argv = ('./program', FLAGS['testcomma_list'].serialize(), + FLAGS['testspace_list'].serialize(), + FLAGS['testspace_or_comma_list'].serialize()) + argv = FLAGS(argv) + self.assertEqual(FLAGS.testcomma_list, testcomma_list1) + self.assertEqual(FLAGS.testspace_list, testspace_list1) + self.assertEqual(FLAGS.testspace_or_comma_list, testspace_or_comma_list1) + + testcomma_list1 = ['aa some spaces', 'bb'] + testspace_list1 = ['aa', 'bb,some,commas,', 'cc'] + testspace_or_comma_list1 = ['aa', 'bb,some,commas,', 'cc'] + FLAGS.testcomma_list = list(testcomma_list1) + FLAGS.testspace_list = list(testspace_list1) + FLAGS.testspace_or_comma_list = list(testspace_or_comma_list1) + argv = ('./program', FLAGS['testcomma_list'].serialize(), + FLAGS['testspace_list'].serialize(), + FLAGS['testspace_or_comma_list'].serialize()) + argv = FLAGS(argv) + self.assertEqual(FLAGS.testcomma_list, testcomma_list1) + self.assertEqual(FLAGS.testspace_list, testspace_list1) + # We don't expect idempotency when commas are placed in an item value and + # comma_compat is enabled. + self.assertEqual(FLAGS.testspace_or_comma_list, + ['aa', 'bb', 'some', 'commas', 'cc']) + + FLAGS.testcomma_list = old_testcomma_list + FLAGS.testspace_list = old_testspace_list + FLAGS.testspace_or_comma_list = old_testspace_or_comma_list + + #################################### + # Test flag-update: + + def args_list(): + # Exclude flags that have different default values based on the + # environment. + flags_to_exclude = {'log_dir', 'test_srcdir', 'test_tmpdir'} + flagnames = set(FLAGS) - flags_to_exclude + + nonbool_flags = [] + truebool_flags = [] + falsebool_flags = [] + for name in flagnames: + flag_value = FLAGS.get_flag_value(name, None) + if not isinstance(FLAGS[name], flags.BooleanFlag): + nonbool_flags.append('--%s %s' % (name, flag_value)) + elif flag_value: + truebool_flags.append('--%s' % name) + else: + falsebool_flags.append('--no%s' % name) + all_flags = nonbool_flags + truebool_flags + falsebool_flags + all_flags.sort() + return all_flags + + argv = ('./program', '--repeat=3', '--name=giants', '--nodebug') + + FLAGS(argv) + self.assertEqual(FLAGS.get_flag_value('repeat', None), 3) + self.assertEqual(FLAGS.get_flag_value('name', None), 'giants') + self.assertEqual(FLAGS.get_flag_value('debug', None), 0) + self.assertListEqual([ + '--alsologtostderr', + "--args ['v=1', 'vmodule=a=0,b=2']", + '--blah None', + '--cases None', + '--decimal 666', + '--float 3.14', + '--funny None', + '--hexadecimal 1638', + '--kwery None', + '--l 9223372032559808512', + "--letters ['a', 'b', 'c']", + '--logger_levels {}', + "--m ['str1', 'str2']", + "--m_str ['str1', 'str2']", + '--name giants', + '--no?', + '--nodebug', + '--noexec', + '--nohelp', + '--nohelpfull', + '--nohelpshort', + '--nohelpxml', + '--nologtostderr', + '--noonly_check_args', + '--nopdb_post_mortem', + '--noq', + '--norun_with_pdb', + '--norun_with_profiling', + '--notest0', + '--notestget2', + '--notestget3', + '--notestnone', + '--numbers [1, 2, 3]', + '--octal 438', + '--only_once singlevalue', + '--pdb False', + '--profile_file None', + '--quack', + '--repeat 3', + "--s ['sing1']", + "--s_str ['sing1']", + '--sense None', + '--showprefixforinfo', + '--stderrthreshold fatal', + '--test1', + '--test_random_seed 301', + '--test_randomize_ordering_seed ', + '--testcomma_list []', + '--testget1', + '--testget4 None', + '--testspace_list []', + '--testspace_or_comma_list []', + '--tmod_baz_x', + '--universe ptolemaic', + '--use_cprofile_for_profiling', + '--v -1', + '--verbosity -1', + '--x 10', + '--xml_output_file ', + ], args_list()) + + argv = ('./program', '--debug', '--m_str=upd1', '-s', 'upd2') + FLAGS(argv) + self.assertEqual(FLAGS.get_flag_value('repeat', None), 3) + self.assertEqual(FLAGS.get_flag_value('name', None), 'giants') + self.assertEqual(FLAGS.get_flag_value('debug', None), 1) + + # items appended to existing non-default value lists for --m/--m_str + # new value overwrites default value (not appended to it) for --s/--s_str + self.assertListEqual([ + '--alsologtostderr', + "--args ['v=1', 'vmodule=a=0,b=2']", + '--blah None', + '--cases None', + '--debug', + '--decimal 666', + '--float 3.14', + '--funny None', + '--hexadecimal 1638', + '--kwery None', + '--l 9223372032559808512', + "--letters ['a', 'b', 'c']", + '--logger_levels {}', + "--m ['str1', 'str2', 'upd1']", + "--m_str ['str1', 'str2', 'upd1']", + '--name giants', + '--no?', + '--noexec', + '--nohelp', + '--nohelpfull', + '--nohelpshort', + '--nohelpxml', + '--nologtostderr', + '--noonly_check_args', + '--nopdb_post_mortem', + '--noq', + '--norun_with_pdb', + '--norun_with_profiling', + '--notest0', + '--notestget2', + '--notestget3', + '--notestnone', + '--numbers [1, 2, 3]', + '--octal 438', + '--only_once singlevalue', + '--pdb False', + '--profile_file None', + '--quack', + '--repeat 3', + "--s ['sing1', 'upd2']", + "--s_str ['sing1', 'upd2']", + '--sense None', + '--showprefixforinfo', + '--stderrthreshold fatal', + '--test1', + '--test_random_seed 301', + '--test_randomize_ordering_seed ', + '--testcomma_list []', + '--testget1', + '--testget4 None', + '--testspace_list []', + '--testspace_or_comma_list []', + '--tmod_baz_x', + '--universe ptolemaic', + '--use_cprofile_for_profiling', + '--v -1', + '--verbosity -1', + '--x 10', + '--xml_output_file ', + ], args_list()) + + #################################### + # Test all kind of error conditions. + + # Argument not in enum exception + argv = ('./program', '--kwery=WHEN') + self.assertRaises(flags.IllegalFlagValueError, FLAGS, argv) + argv = ('./program', '--kwery=why') + self.assertRaises(flags.IllegalFlagValueError, FLAGS, argv) + + # Duplicate flag detection + with self.assertRaises(flags.DuplicateFlagError): + flags.DEFINE_boolean('run', 0, 'runhelp', short_name='q') + + # Duplicate short flag detection + with self.assertRaisesRegex( + flags.DuplicateFlagError, + r"The flag 'z' is defined twice\. .*First from.*, Second from"): + flags.DEFINE_boolean('zoom1', 0, 'runhelp z1', short_name='z') + flags.DEFINE_boolean('zoom2', 0, 'runhelp z2', short_name='z') + raise AssertionError('duplicate short flag detection failed') + + # Duplicate mixed flag detection + with self.assertRaisesRegex( + flags.DuplicateFlagError, + r"The flag 's' is defined twice\. .*First from.*, Second from"): + flags.DEFINE_boolean('short1', 0, 'runhelp s1', short_name='s') + flags.DEFINE_boolean('s', 0, 'runhelp s2') + + # Check that duplicate flag detection detects definition sites + # correctly. + flagnames = ['repeated'] + original_flags = flags.FlagValues() + flags.DEFINE_boolean( + flagnames[0], + False, + 'Flag about to be repeated.', + flag_values=original_flags) + duplicate_flags = module_foo.duplicate_flags(flagnames) + with self.assertRaisesRegex(flags.DuplicateFlagError, + 'flags_test.*module_foo'): + original_flags.append_flag_values(duplicate_flags) + + # Make sure allow_override works + try: + flags.DEFINE_boolean( + 'dup1', 0, 'runhelp d11', short_name='u', allow_override=0) + flag = FLAGS._flags()['dup1'] + self.assertEqual(flag.default, 0) + + flags.DEFINE_boolean( + 'dup1', 1, 'runhelp d12', short_name='u', allow_override=1) + flag = FLAGS._flags()['dup1'] + self.assertEqual(flag.default, 1) + except flags.DuplicateFlagError: + raise AssertionError('allow_override did not permit a flag duplication') + + # Make sure allow_override works + try: + flags.DEFINE_boolean( + 'dup2', 0, 'runhelp d21', short_name='u', allow_override=1) + flag = FLAGS._flags()['dup2'] + self.assertEqual(flag.default, 0) + + flags.DEFINE_boolean( + 'dup2', 1, 'runhelp d22', short_name='u', allow_override=0) + flag = FLAGS._flags()['dup2'] + self.assertEqual(flag.default, 1) + except flags.DuplicateFlagError: + raise AssertionError('allow_override did not permit a flag duplication') + + # Make sure that re-importing a module does not cause a DuplicateFlagError + # to be raised. + try: + sys.modules.pop('absl.flags.tests.module_baz') + import absl.flags.tests.module_baz + del absl + except flags.DuplicateFlagError: + raise AssertionError('Module reimport caused flag duplication error') + + # Make sure that when we override, the help string gets updated correctly + flags.DEFINE_boolean( + 'dup3', 0, 'runhelp d31', short_name='u', allow_override=1) + flags.DEFINE_boolean( + 'dup3', 1, 'runhelp d32', short_name='u', allow_override=1) + self.assertEqual(str(FLAGS).find('runhelp d31'), -1) + self.assertNotEqual(str(FLAGS).find('runhelp d32'), -1) + + # Make sure append_flag_values works + new_flags = flags.FlagValues() + flags.DEFINE_boolean('new1', 0, 'runhelp n1', flag_values=new_flags) + flags.DEFINE_boolean('new2', 0, 'runhelp n2', flag_values=new_flags) + self.assertEqual(len(new_flags._flags()), 2) + old_len = len(FLAGS._flags()) + FLAGS.append_flag_values(new_flags) + self.assertEqual(len(FLAGS._flags()) - old_len, 2) + self.assertEqual('new1' in FLAGS._flags(), True) + self.assertEqual('new2' in FLAGS._flags(), True) + + # Then test that removing those flags works + FLAGS.remove_flag_values(new_flags) + self.assertEqual(len(FLAGS._flags()), old_len) + self.assertFalse('new1' in FLAGS._flags()) + self.assertFalse('new2' in FLAGS._flags()) + + # Make sure append_flag_values works with flags with shortnames. + new_flags = flags.FlagValues() + flags.DEFINE_boolean('new3', 0, 'runhelp n3', flag_values=new_flags) + flags.DEFINE_boolean( + 'new4', 0, 'runhelp n4', flag_values=new_flags, short_name='n4') + self.assertEqual(len(new_flags._flags()), 3) + old_len = len(FLAGS._flags()) + FLAGS.append_flag_values(new_flags) + self.assertEqual(len(FLAGS._flags()) - old_len, 3) + self.assertIn('new3', FLAGS._flags()) + self.assertIn('new4', FLAGS._flags()) + self.assertIn('n4', FLAGS._flags()) + self.assertEqual(FLAGS._flags()['n4'], FLAGS._flags()['new4']) + + # Then test removing them + FLAGS.remove_flag_values(new_flags) + self.assertEqual(len(FLAGS._flags()), old_len) + self.assertFalse('new3' in FLAGS._flags()) + self.assertFalse('new4' in FLAGS._flags()) + self.assertFalse('n4' in FLAGS._flags()) + + # Make sure append_flag_values fails on duplicates + flags.DEFINE_boolean('dup4', 0, 'runhelp d41') + new_flags = flags.FlagValues() + flags.DEFINE_boolean('dup4', 0, 'runhelp d42', flag_values=new_flags) + with self.assertRaises(flags.DuplicateFlagError): + FLAGS.append_flag_values(new_flags) + + # Integer out of bounds + with self.assertRaises(flags.IllegalFlagValueError): + argv = ('./program', '--repeat=-4') + FLAGS(argv) + + # Non-integer + with self.assertRaises(flags.IllegalFlagValueError): + argv = ('./program', '--repeat=2.5') + FLAGS(argv) + + # Missing required argument + with self.assertRaises(flags.Error): + argv = ('./program', '--name') + FLAGS(argv) + + # Non-boolean arguments for boolean + with self.assertRaises(flags.IllegalFlagValueError): + argv = ('./program', '--debug=goofup') + FLAGS(argv) + + with self.assertRaises(flags.IllegalFlagValueError): + argv = ('./program', '--debug=42') + FLAGS(argv) + + # Non-numeric argument for integer flag --repeat + with self.assertRaises(flags.IllegalFlagValueError): + argv = ('./program', '--repeat', 'Bob', 'extra') + FLAGS(argv) + + # Aliases of existing flags + with self.assertRaises(flags.UnrecognizedFlagError): + flags.DEFINE_alias('alias_not_a_flag', 'not_a_flag') + + # Programmtically modify alias and aliased flag + flags.DEFINE_alias('alias_octal', 'octal') + FLAGS.octal = 0o2222 + self.assertEqual(0o2222, FLAGS.octal) + self.assertEqual(0o2222, FLAGS.alias_octal) + FLAGS.alias_octal = 0o4444 + self.assertEqual(0o4444, FLAGS.octal) + self.assertEqual(0o4444, FLAGS.alias_octal) + + # Setting alias preserves the default of the original + flags.DEFINE_alias('alias_name', 'name') + flags.DEFINE_alias('alias_debug', 'debug') + flags.DEFINE_alias('alias_decimal', 'decimal') + flags.DEFINE_alias('alias_float', 'float') + flags.DEFINE_alias('alias_letters', 'letters') + self.assertEqual(FLAGS['name'].default, FLAGS.alias_name) + self.assertEqual(FLAGS['debug'].default, FLAGS.alias_debug) + self.assertEqual(int(FLAGS['decimal'].default), FLAGS.alias_decimal) + self.assertEqual(float(FLAGS['float'].default), FLAGS.alias_float) + self.assertSameElements(FLAGS['letters'].default, FLAGS.alias_letters) + + # Original flags set on command line + argv = ('./program', '--name=Martin', '--debug=True', '--decimal=777', + '--letters=x,y,z') + FLAGS(argv) + self.assertEqual('Martin', FLAGS.name) + self.assertEqual('Martin', FLAGS.alias_name) + self.assertTrue(FLAGS.debug) + self.assertTrue(FLAGS.alias_debug) + self.assertEqual(777, FLAGS.decimal) + self.assertEqual(777, FLAGS.alias_decimal) + self.assertSameElements(['x', 'y', 'z'], FLAGS.letters) + self.assertSameElements(['x', 'y', 'z'], FLAGS.alias_letters) + + # Alias flags set on command line + argv = ('./program', '--alias_name=Auston', '--alias_debug=False', + '--alias_decimal=888', '--alias_letters=l,m,n') + FLAGS(argv) + self.assertEqual('Auston', FLAGS.name) + self.assertEqual('Auston', FLAGS.alias_name) + self.assertFalse(FLAGS.debug) + self.assertFalse(FLAGS.alias_debug) + self.assertEqual(888, FLAGS.decimal) + self.assertEqual(888, FLAGS.alias_decimal) + self.assertSameElements(['l', 'm', 'n'], FLAGS.letters) + self.assertSameElements(['l', 'm', 'n'], FLAGS.alias_letters) + + # Make sure importing a module does not change flag value parsed + # from commandline. + flags.DEFINE_integer( + 'dup5', 1, 'runhelp d51', short_name='u5', allow_override=0) + self.assertEqual(FLAGS.dup5, 1) + self.assertEqual(FLAGS.dup5, 1) + argv = ('./program', '--dup5=3') + FLAGS(argv) + self.assertEqual(FLAGS.dup5, 3) + flags.DEFINE_integer( + 'dup5', 2, 'runhelp d52', short_name='u5', allow_override=1) + self.assertEqual(FLAGS.dup5, 3) + + # Make sure importing a module does not change user defined flag value. + flags.DEFINE_integer( + 'dup6', 1, 'runhelp d61', short_name='u6', allow_override=0) + self.assertEqual(FLAGS.dup6, 1) + FLAGS.dup6 = 3 + self.assertEqual(FLAGS.dup6, 3) + flags.DEFINE_integer( + 'dup6', 2, 'runhelp d62', short_name='u6', allow_override=1) + self.assertEqual(FLAGS.dup6, 3) + + # Make sure importing a module does not change user defined flag value + # even if it is the 'default' value. + flags.DEFINE_integer( + 'dup7', 1, 'runhelp d71', short_name='u7', allow_override=0) + self.assertEqual(FLAGS.dup7, 1) + FLAGS.dup7 = 1 + self.assertEqual(FLAGS.dup7, 1) + flags.DEFINE_integer( + 'dup7', 2, 'runhelp d72', short_name='u7', allow_override=1) + self.assertEqual(FLAGS.dup7, 1) + + # Test module_help(). + helpstr = FLAGS.module_help(module_baz) + + expected_help = '\n' + module_baz.__name__ + ':' + """ + --[no]tmod_baz_x: Boolean flag. + (default: 'true')""" + + self.assertMultiLineEqual(expected_help, helpstr) + + # Test main_module_help(). This must be part of test_flags because + # it depends on dup1/2/3/etc being introduced first. + helpstr = FLAGS.main_module_help() + + expected_help = '\n' + sys.argv[0] + ':' + """ + --[no]alias_debug: Alias for --debug. + (default: 'false') + --alias_decimal: Alias for --decimal. + (default: '666') + (an integer) + --alias_float: Alias for --float. + (default: '3.14') + (a number) + --alias_letters: Alias for --letters. + (default: 'a,b,c') + (a comma separated list) + --alias_name: Alias for --name. + (default: 'Bob') + --alias_octal: Alias for --octal. + (default: '438') + (an integer) + --args: a list of arguments + (default: 'v=1,"vmodule=a=0,b=2"') + (a comma separated list) + --blah: <bla|Blah|BLAH|blah>: ? + --cases: <UPPER|lower|Initial|Ot_HeR>: ? + --[no]debug: debughelp + (default: 'false') + --decimal: using decimals + (default: '666') + (an integer) + -u,--[no]dup1: runhelp d12 + (default: 'true') + -u,--[no]dup2: runhelp d22 + (default: 'true') + -u,--[no]dup3: runhelp d32 + (default: 'true') + --[no]dup4: runhelp d41 + (default: 'false') + -u5,--dup5: runhelp d51 + (default: '1') + (an integer) + -u6,--dup6: runhelp d61 + (default: '1') + (an integer) + -u7,--dup7: runhelp d71 + (default: '1') + (an integer) + --float: using floats + (default: '3.14') + (a number) + --funny: <Joke|ha|ha|ha|ha>: ? + --hexadecimal: using hexadecimals + (default: '1638') + (an integer) + --kwery: <who|what|Why|where|when>: ? + --l: how long to be + (default: '9223372032559808512') + (an integer) + --letters: a list of letters + (default: 'a,b,c') + (a comma separated list) + -m,--m_str: string option that can occur multiple times; + repeat this option to specify a list of values + (default: "['def1', 'def2']") + --name: namehelp + (default: 'Bob') + --[no]noexec: boolean flag with no as prefix + (default: 'true') + --numbers: a list of numbers + (default: '1,2,3') + (a comma separated list) + --octal: using octals + (default: '438') + (an integer) + --only_once: test only sets this once + --[no]q: quiet mode + (default: 'true') + --[no]quack: superstring of 'q' + (default: 'false') + -r,--repeat: how many times to repeat (0-5) + (default: '4') + (a non-negative integer) + -s,--s_str: string option that can occur multiple times; + repeat this option to specify a list of values + (default: "['sing1']") + --sense: <Case|case|CASE>: ? + --[no]test0: test boolean parsing + --[no]test1: test boolean parsing + --testcomma_list: test comma list parsing + (default: '') + (a comma separated list) + --[no]testget1: test parsing with defaults + --[no]testget2: test parsing with defaults + --[no]testget3: test parsing with defaults + --testget4: test parsing with defaults + (an integer) + --[no]testnone: test boolean parsing + --testspace_list: tests space list parsing + (default: '') + (a whitespace separated list) + --testspace_or_comma_list: tests space list parsing with comma compatibility + (default: '') + (a whitespace or comma separated list) + --universe: test tries to set this three times + --x: how eXtreme to be + (default: '3') + (an integer) + -z,--[no]zoom1: runhelp z1 + (default: 'false')""" + + self.assertMultiLineEqual(expected_help, helpstr) + + def test_string_flag_with_wrong_type(self): + fv = flags.FlagValues() + with self.assertRaises(flags.IllegalFlagValueError): + flags.DEFINE_string('name', False, 'help', flag_values=fv) + with self.assertRaises(flags.IllegalFlagValueError): + flags.DEFINE_string('name2', 0, 'help', flag_values=fv) + + def test_integer_flag_with_wrong_type(self): + fv = flags.FlagValues() + with self.assertRaises(flags.IllegalFlagValueError): + flags.DEFINE_integer('name', 1e2, 'help', flag_values=fv) + with self.assertRaises(flags.IllegalFlagValueError): + flags.DEFINE_integer('name', [], 'help', flag_values=fv) + with self.assertRaises(flags.IllegalFlagValueError): + flags.DEFINE_integer('name', False, 'help', flag_values=fv) + + def test_float_flag_with_wrong_type(self): + fv = flags.FlagValues() + with self.assertRaises(flags.IllegalFlagValueError): + flags.DEFINE_float('name', False, 'help', flag_values=fv) + + def test_enum_flag_with_empty_values(self): + fv = flags.FlagValues() + with self.assertRaises(ValueError): + flags.DEFINE_enum('fruit', None, [], 'help', flag_values=fv) + + def test_define_enum_class_flag(self): + fv = flags.FlagValues() + flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=fv) + fv.mark_as_parsed() + + self.assertIsNone(fv.fruit) + + def test_parse_enum_class_flag(self): + fv = flags.FlagValues() + flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=fv) + + argv = ('./program', '--fruit=orange') + argv = fv(argv) + self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(fv['fruit'].present, 1) + self.assertEqual(fv['fruit'].value, Fruit.ORANGE) + fv.unparse_flags() + argv = ('./program', '--fruit=APPLE') + argv = fv(argv) + self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(fv['fruit'].present, 1) + self.assertEqual(fv['fruit'].value, Fruit.APPLE) + fv.unparse_flags() + + def test_enum_class_flag_help_message(self): + fv = flags.FlagValues() + flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=fv) + + helpstr = fv.main_module_help() + expected_help = '\n%s:\n --fruit: <apple|orange>: ?' % sys.argv[0] + + self.assertEqual(helpstr, expected_help) + + def test_enum_class_flag_with_wrong_default_value_type(self): + fv = flags.FlagValues() + with self.assertRaises(_exceptions.IllegalFlagValueError): + flags.DEFINE_enum_class('fruit', 1, Fruit, 'help', flag_values=fv) + + def test_enum_class_flag_requires_enum_class(self): + fv = flags.FlagValues() + with self.assertRaises(TypeError): + flags.DEFINE_enum_class( + 'fruit', None, ['apple', 'orange'], 'help', flag_values=fv) + + def test_enum_class_flag_requires_non_empty_enum_class(self): + fv = flags.FlagValues() + with self.assertRaises(ValueError): + flags.DEFINE_enum_class('empty', None, EmptyEnum, 'help', flag_values=fv) + + def test_required_flag(self): + fv = flags.FlagValues() + fl = flags.DEFINE_integer( + name='int_flag', + default=None, + help='help', + required=True, + flag_values=fv) + # Since the flag is required, the FlagHolder should ensure value returned + # is not None. + self.assertTrue(fl._ensure_non_none_value) + + def test_illegal_required_flag(self): + fv = flags.FlagValues() + with self.assertRaises(ValueError): + flags.DEFINE_integer( + name='int_flag', + default=3, + help='help', + required=True, + flag_values=fv) + + +class MultiNumericalFlagsTest(absltest.TestCase): + + def test_multi_numerical_flags(self): + """Test multi_int and multi_float flags.""" + fv = flags.FlagValues() + int_defaults = [77, 88] + flags.DEFINE_multi_integer( + 'm_int', + int_defaults, + 'integer option that can occur multiple times', + short_name='mi', + flag_values=fv) + self.assertListEqual(fv['m_int'].default, int_defaults) + argv = ('./program', '--m_int=-99', '--mi=101') + fv(argv) + self.assertListEqual(fv.get_flag_value('m_int', None), [-99, 101]) + + float_defaults = [2.2, 3] + flags.DEFINE_multi_float( + 'm_float', + float_defaults, + 'float option that can occur multiple times', + short_name='mf', + flag_values=fv) + for (expected, actual) in zip(float_defaults, + fv.get_flag_value('m_float', None)): + self.assertAlmostEqual(expected, actual) + argv = ('./program', '--m_float=-17', '--mf=2.78e9') + fv(argv) + expected_floats = [-17.0, 2.78e9] + for (expected, actual) in zip(expected_floats, + fv.get_flag_value('m_float', None)): + self.assertAlmostEqual(expected, actual) + + def test_multi_numerical_with_tuples(self): + """Verify multi_int/float accept tuples as default values.""" + flags.DEFINE_multi_integer( + 'm_int_tuple', (77, 88), + 'integer option that can occur multiple times', + short_name='mi_tuple') + self.assertListEqual(FLAGS.get_flag_value('m_int_tuple', None), [77, 88]) + + dict_with_float_keys = {2.2: 'hello', 3: 'happy'} + float_defaults = dict_with_float_keys.keys() + flags.DEFINE_multi_float( + 'm_float_tuple', + float_defaults, + 'float option that can occur multiple times', + short_name='mf_tuple') + for (expected, actual) in zip(float_defaults, + FLAGS.get_flag_value('m_float_tuple', None)): + self.assertAlmostEqual(expected, actual) + + def test_single_value_default(self): + """Test multi_int and multi_float flags with a single default value.""" + int_default = 77 + flags.DEFINE_multi_integer('m_int1', int_default, + 'integer option that can occur multiple times') + self.assertListEqual(FLAGS.get_flag_value('m_int1', None), [int_default]) + + float_default = 2.2 + flags.DEFINE_multi_float('m_float1', float_default, + 'float option that can occur multiple times') + actual = FLAGS.get_flag_value('m_float1', None) + self.assertEqual(1, len(actual)) + self.assertAlmostEqual(actual[0], float_default) + + def test_bad_multi_numerical_flags(self): + """Test multi_int and multi_float flags with non-parseable values.""" + + # Test non-parseable defaults. + self.assertRaisesRegex( + flags.IllegalFlagValueError, + r"flag --m_int2=abc: invalid literal for int\(\) with base 10: 'abc'", + flags.DEFINE_multi_integer, 'm_int2', ['abc'], 'desc') + + self.assertRaisesRegex( + flags.IllegalFlagValueError, r'flag --m_float2=abc: ' + r'(invalid literal for float\(\)|could not convert string to float): ' + r"'?abc'?", flags.DEFINE_multi_float, 'm_float2', ['abc'], 'desc') + + # Test non-parseable command line values. + fv = flags.FlagValues() + flags.DEFINE_multi_integer( + 'm_int2', + '77', + 'integer option that can occur multiple times', + flag_values=fv) + argv = ('./program', '--m_int2=def') + self.assertRaisesRegex( + flags.IllegalFlagValueError, + r"flag --m_int2=def: invalid literal for int\(\) with base 10: 'def'", + fv, argv) + + flags.DEFINE_multi_float( + 'm_float2', + 2.2, + 'float option that can occur multiple times', + flag_values=fv) + argv = ('./program', '--m_float2=def') + self.assertRaisesRegex( + flags.IllegalFlagValueError, r'flag --m_float2=def: ' + r'(invalid literal for float\(\)|could not convert string to float): ' + r"'?def'?", fv, argv) + + +class MultiEnumFlagsTest(absltest.TestCase): + + def test_multi_enum_flags(self): + """Test multi_enum flags.""" + fv = flags.FlagValues() + + enum_defaults = ['FOO', 'BAZ'] + flags.DEFINE_multi_enum( + 'm_enum', + enum_defaults, ['FOO', 'BAR', 'BAZ', 'WHOOSH'], + 'Enum option that can occur multiple times', + short_name='me', + flag_values=fv) + self.assertListEqual(fv['m_enum'].default, enum_defaults) + argv = ('./program', '--m_enum=WHOOSH', '--me=FOO') + fv(argv) + self.assertListEqual(fv.get_flag_value('m_enum', None), ['WHOOSH', 'FOO']) + + def test_help_text(self): + """Test multi_enum flag's help text.""" + fv = flags.FlagValues() + + flags.DEFINE_multi_enum( + 'm_enum', + None, ['FOO', 'BAR'], + 'Enum option that can occur multiple times', + flag_values=fv) + self.assertRegex( + fv['m_enum'].help, + r'<FOO\|BAR>: Enum option that can occur multiple times;\s+' + 'repeat this option to specify a list of values') + + def test_single_value_default(self): + """Test multi_enum flags with a single default value.""" + fv = flags.FlagValues() + enum_default = 'FOO' + flags.DEFINE_multi_enum( + 'm_enum1', + enum_default, ['FOO', 'BAR', 'BAZ', 'WHOOSH'], + 'enum option that can occur multiple times', + flag_values=fv) + self.assertListEqual(fv['m_enum1'].default, [enum_default]) + + def test_case_sensitivity(self): + """Test case sensitivity of multi_enum flag.""" + fv = flags.FlagValues() + # Test case insensitive enum. + flags.DEFINE_multi_enum( + 'm_enum2', ['whoosh'], ['FOO', 'BAR', 'BAZ', 'WHOOSH'], + 'Enum option that can occur multiple times', + short_name='me2', + case_sensitive=False, + flag_values=fv) + argv = ('./program', '--m_enum2=bar', '--me2=fOo') + fv(argv) + self.assertListEqual(fv.get_flag_value('m_enum2', None), ['BAR', 'FOO']) + + # Test case sensitive enum. + flags.DEFINE_multi_enum( + 'm_enum3', ['BAR'], ['FOO', 'BAR', 'BAZ', 'WHOOSH'], + 'Enum option that can occur multiple times', + short_name='me3', + case_sensitive=True, + flag_values=fv) + argv = ('./program', '--m_enum3=bar', '--me3=fOo') + self.assertRaisesRegex( + flags.IllegalFlagValueError, + r'flag --m_enum3=invalid: value should be one of <FOO|BAR|BAZ|WHOOSH>', + fv, argv) + + def test_bad_multi_enum_flags(self): + """Test multi_enum with invalid values.""" + + # Test defaults that are not in the permitted list of enums. + self.assertRaisesRegex( + flags.IllegalFlagValueError, + r'flag --m_enum=INVALID: value should be one of <FOO|BAR|BAZ>', + flags.DEFINE_multi_enum, 'm_enum', ['INVALID'], ['FOO', 'BAR', 'BAZ'], + 'desc') + + self.assertRaisesRegex( + flags.IllegalFlagValueError, + r'flag --m_enum=1234: value should be one of <FOO|BAR|BAZ>', + flags.DEFINE_multi_enum, 'm_enum2', [1234], ['FOO', 'BAR', 'BAZ'], + 'desc') + + # Test command-line values that are not in the permitted list of enums. + flags.DEFINE_multi_enum('m_enum4', 'FOO', ['FOO', 'BAR', 'BAZ'], + 'enum option that can occur multiple times') + argv = ('./program', '--m_enum4=INVALID') + self.assertRaisesRegex( + flags.IllegalFlagValueError, + r'flag --m_enum4=invalid: value should be one of <FOO|BAR|BAZ>', FLAGS, + argv) + + +class MultiEnumClassFlagsTest(absltest.TestCase): + + def test_define_results_in_registered_flag_with_none(self): + fv = flags.FlagValues() + enum_defaults = None + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) + fv.mark_as_parsed() + + self.assertIsNone(fv.fruit) + + def test_help_text(self): + fv = flags.FlagValues() + enum_defaults = None + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) + + self.assertRegex( + fv['fruit'].help, + r'<apple\|orange>: Enum option that can occur multiple times;\s+' + 'repeat this option to specify a list of values') + + def test_define_results_in_registered_flag_with_string(self): + fv = flags.FlagValues() + enum_defaults = 'apple' + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) + fv.mark_as_parsed() + + self.assertListEqual(fv.fruit, [Fruit.APPLE]) + + def test_define_results_in_registered_flag_with_enum(self): + fv = flags.FlagValues() + enum_defaults = Fruit.APPLE + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) + fv.mark_as_parsed() + + self.assertListEqual(fv.fruit, [Fruit.APPLE]) + + def test_define_results_in_registered_flag_with_string_list(self): + fv = flags.FlagValues() + enum_defaults = ['apple', 'APPLE'] + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + CaseSensitiveFruit, + 'Enum option that can occur multiple times', + flag_values=fv, + case_sensitive=True) + fv.mark_as_parsed() + + self.assertListEqual(fv.fruit, + [CaseSensitiveFruit.apple, CaseSensitiveFruit.APPLE]) + + def test_define_results_in_registered_flag_with_enum_list(self): + fv = flags.FlagValues() + enum_defaults = [Fruit.APPLE, Fruit.ORANGE] + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) + fv.mark_as_parsed() + + self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.ORANGE]) + + def test_from_command_line_returns_multiple(self): + fv = flags.FlagValues() + enum_defaults = [Fruit.APPLE] + flags.DEFINE_multi_enum_class( + 'fruit', + enum_defaults, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv) + argv = ('./program', '--fruit=Apple', '--fruit=orange') + fv(argv) + self.assertListEqual(fv.fruit, [Fruit.APPLE, Fruit.ORANGE]) + + def test_bad_multi_enum_class_flags_from_definition(self): + with self.assertRaisesRegex( + flags.IllegalFlagValueError, + 'flag --fruit=INVALID: value should be one of <apple|orange|APPLE>'): + flags.DEFINE_multi_enum_class('fruit', ['INVALID'], Fruit, 'desc') + + def test_bad_multi_enum_class_flags_from_commandline(self): + fv = flags.FlagValues() + enum_defaults = [Fruit.APPLE] + flags.DEFINE_multi_enum_class( + 'fruit', enum_defaults, Fruit, 'desc', flag_values=fv) + argv = ('./program', '--fruit=INVALID') + with self.assertRaisesRegex( + flags.IllegalFlagValueError, + 'flag --fruit=INVALID: value should be one of <apple|orange|APPLE>'): + fv(argv) + + +class UnicodeFlagsTest(absltest.TestCase): + """Testing proper unicode support for flags.""" + + def test_unicode_default_and_helpstring(self): + fv = flags.FlagValues() + flags.DEFINE_string( + 'unicode_str', + b'\xC3\x80\xC3\xBD'.decode('utf-8'), + b'help:\xC3\xAA'.decode('utf-8'), + flag_values=fv) + argv = ('./program',) + fv(argv) # should not raise any exceptions + + argv = ('./program', '--unicode_str=foo') + fv(argv) # should not raise any exceptions + + def test_unicode_in_list(self): + fv = flags.FlagValues() + flags.DEFINE_list( + 'unicode_list', + ['abc', b'\xC3\x80'.decode('utf-8'), b'\xC3\xBD'.decode('utf-8')], + b'help:\xC3\xAB'.decode('utf-8'), + flag_values=fv) + argv = ('./program',) + fv(argv) # should not raise any exceptions + + argv = ('./program', '--unicode_list=hello,there') + fv(argv) # should not raise any exceptions + + def test_xmloutput(self): + fv = flags.FlagValues() + flags.DEFINE_string( + 'unicode1', + b'\xC3\x80\xC3\xBD'.decode('utf-8'), + b'help:\xC3\xAC'.decode('utf-8'), + flag_values=fv) + flags.DEFINE_list( + 'unicode2', + ['abc', b'\xC3\x80'.decode('utf-8'), b'\xC3\xBD'.decode('utf-8')], + b'help:\xC3\xAD'.decode('utf-8'), + flag_values=fv) + flags.DEFINE_list( + 'non_unicode', ['abc', 'def', 'ghi'], + b'help:\xC3\xAD'.decode('utf-8'), + flag_values=fv) + + outfile = io.StringIO() + fv.write_help_in_xml_format(outfile) + actual_output = outfile.getvalue() + + # The xml output is large, so we just check parts of it. + self.assertIn( + b'<name>unicode1</name>\n' + b' <meaning>help:\xc3\xac</meaning>\n' + b' <default>\xc3\x80\xc3\xbd</default>\n' + b' <current>\xc3\x80\xc3\xbd</current>'.decode('utf-8'), + actual_output) + self.assertIn( + b'<name>unicode2</name>\n' + b' <meaning>help:\xc3\xad</meaning>\n' + b' <default>abc,\xc3\x80,\xc3\xbd</default>\n' + b" <current>['abc', '\xc3\x80', '\xc3\xbd']" + b'</current>'.decode('utf-8'), actual_output) + self.assertIn( + b'<name>non_unicode</name>\n' + b' <meaning>help:\xc3\xad</meaning>\n' + b' <default>abc,def,ghi</default>\n' + b" <current>['abc', 'def', 'ghi']" + b'</current>'.decode('utf-8'), actual_output) + + +class LoadFromFlagFileTest(absltest.TestCase): + """Testing loading flags from a file and parsing them.""" + + def setUp(self): + self.flag_values = flags.FlagValues() + flags.DEFINE_string( + 'unittest_message1', + 'Foo!', + 'You Add Here.', + flag_values=self.flag_values) + flags.DEFINE_string( + 'unittest_message2', + 'Bar!', + 'Hello, Sailor!', + flag_values=self.flag_values) + flags.DEFINE_boolean( + 'unittest_boolflag', + 0, + 'Some Boolean thing', + flag_values=self.flag_values) + flags.DEFINE_integer( + 'unittest_number', + 12345, + 'Some integer', + lower_bound=0, + flag_values=self.flag_values) + flags.DEFINE_list( + 'UnitTestList', '1,2,3', 'Some list', flag_values=self.flag_values) + self.tmp_path = None + self.flag_values.mark_as_parsed() + + def tearDown(self): + self._remove_test_files() + + def _setup_test_files(self): + """Creates and sets up some dummy flagfile files with bogus flags.""" + + # Figure out where to create temporary files + self.assertFalse(self.tmp_path) + self.tmp_path = tempfile.mkdtemp() + + tmp_flag_file_1 = open(self.tmp_path + '/UnitTestFile1.tst', 'w') + tmp_flag_file_2 = open(self.tmp_path + '/UnitTestFile2.tst', 'w') + tmp_flag_file_3 = open(self.tmp_path + '/UnitTestFile3.tst', 'w') + tmp_flag_file_4 = open(self.tmp_path + '/UnitTestFile4.tst', 'w') + + # put some dummy flags in our test files + tmp_flag_file_1.write('#A Fake Comment\n') + tmp_flag_file_1.write('--unittest_message1=tempFile1!\n') + tmp_flag_file_1.write('\n') + tmp_flag_file_1.write('--unittest_number=54321\n') + tmp_flag_file_1.write('--nounittest_boolflag\n') + file_list = [tmp_flag_file_1.name] + # this one includes test file 1 + tmp_flag_file_2.write('//A Different Fake Comment\n') + tmp_flag_file_2.write('--flagfile=%s\n' % tmp_flag_file_1.name) + tmp_flag_file_2.write('--unittest_message2=setFromTempFile2\n') + tmp_flag_file_2.write('\t\t\n') + tmp_flag_file_2.write('--unittest_number=6789a\n') + file_list.append(tmp_flag_file_2.name) + # this file points to itself + tmp_flag_file_3.write('--flagfile=%s\n' % tmp_flag_file_3.name) + tmp_flag_file_3.write('--unittest_message1=setFromTempFile3\n') + tmp_flag_file_3.write('#YAFC\n') + tmp_flag_file_3.write('--unittest_boolflag\n') + file_list.append(tmp_flag_file_3.name) + # this file is unreadable + tmp_flag_file_4.write('--flagfile=%s\n' % tmp_flag_file_3.name) + tmp_flag_file_4.write('--unittest_message1=setFromTempFile4\n') + tmp_flag_file_4.write('--unittest_message1=setFromTempFile4\n') + os.chmod(self.tmp_path + '/UnitTestFile4.tst', 0) + file_list.append(tmp_flag_file_4.name) + + tmp_flag_file_1.close() + tmp_flag_file_2.close() + tmp_flag_file_3.close() + tmp_flag_file_4.close() + + return file_list # these are just the file names + + def _remove_test_files(self): + """Removes the files we just created.""" + if self.tmp_path: + shutil.rmtree(self.tmp_path, ignore_errors=True) + self.tmp_path = None + + def _read_flags_from_files(self, argv, force_gnu): + return argv[:1] + self.flag_values.read_flags_from_files( + argv[1:], force_gnu=force_gnu) + + #### Flagfile Unit Tests #### + def test_method_flagfiles_1(self): + """Test trivial case with no flagfile based options.""" + fake_cmd_line = 'fooScript --unittest_boolflag' + fake_argv = fake_cmd_line.split(' ') + self.flag_values(fake_argv) + self.assertEqual(self.flag_values.unittest_boolflag, 1) + self.assertListEqual(fake_argv, + self._read_flags_from_files(fake_argv, False)) + + def test_method_flagfiles_2(self): + """Tests parsing one file + arguments off simulated argv.""" + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = 'fooScript --q --flagfile=%s' % tmp_files[0] + fake_argv = fake_cmd_line.split(' ') + + # We should see the original cmd line with the file's contents spliced in. + # Flags from the file will appear in the order order they are specified + # in the file, in the same position as the flagfile argument. + expected_results = [ + 'fooScript', '--q', '--unittest_message1=tempFile1!', + '--unittest_number=54321', '--nounittest_boolflag' + ] + test_results = self._read_flags_from_files(fake_argv, False) + self.assertListEqual(expected_results, test_results) + + # end testTwo def + + def test_method_flagfiles_3(self): + """Tests parsing nested files + arguments of simulated argv.""" + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = ('fooScript --unittest_number=77 --flagfile=%s' % + tmp_files[1]) + fake_argv = fake_cmd_line.split(' ') + + expected_results = [ + 'fooScript', '--unittest_number=77', '--unittest_message1=tempFile1!', + '--unittest_number=54321', '--nounittest_boolflag', + '--unittest_message2=setFromTempFile2', '--unittest_number=6789a' + ] + test_results = self._read_flags_from_files(fake_argv, False) + self.assertListEqual(expected_results, test_results) + + # end testThree def + + def test_method_flagfiles_3_spaces(self): + """Tests parsing nested files + arguments of simulated argv. + + The arguments include a pair that is actually an arg with a value, so it + doesn't stop processing. + """ + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = ('fooScript --unittest_number 77 --flagfile=%s' % + tmp_files[1]) + fake_argv = fake_cmd_line.split(' ') + + expected_results = [ + 'fooScript', '--unittest_number', '77', + '--unittest_message1=tempFile1!', '--unittest_number=54321', + '--nounittest_boolflag', '--unittest_message2=setFromTempFile2', + '--unittest_number=6789a' + ] + test_results = self._read_flags_from_files(fake_argv, False) + self.assertListEqual(expected_results, test_results) + + def test_method_flagfiles_3_spaces_boolean(self): + """Tests parsing nested files + arguments of simulated argv. + + The arguments include a pair that looks like a --x y arg with value, but + since the flag is a boolean it's actually not. + """ + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = ('fooScript --unittest_boolflag 77 --flagfile=%s' % + tmp_files[1]) + fake_argv = fake_cmd_line.split(' ') + + expected_results = [ + 'fooScript', '--unittest_boolflag', '77', + '--flagfile=%s' % tmp_files[1] + ] + with _use_gnu_getopt(self.flag_values, False): + test_results = self._read_flags_from_files(fake_argv, False) + self.assertListEqual(expected_results, test_results) + + def test_method_flagfiles_4(self): + """Tests parsing self-referential files + arguments of simulated argv. + + This test should print a warning to stderr of some sort. + """ + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = ('fooScript --flagfile=%s --nounittest_boolflag' % + tmp_files[2]) + fake_argv = fake_cmd_line.split(' ') + expected_results = [ + 'fooScript', '--unittest_message1=setFromTempFile3', + '--unittest_boolflag', '--nounittest_boolflag' + ] + + test_results = self._read_flags_from_files(fake_argv, False) + self.assertListEqual(expected_results, test_results) + + def test_method_flagfiles_5(self): + """Test that --flagfile parsing respects the '--' end-of-options marker.""" + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = 'fooScript --some_flag -- --flagfile=%s' % tmp_files[0] + fake_argv = fake_cmd_line.split(' ') + expected_results = [ + 'fooScript', '--some_flag', '--', + '--flagfile=%s' % tmp_files[0] + ] + + test_results = self._read_flags_from_files(fake_argv, False) + self.assertListEqual(expected_results, test_results) + + def test_method_flagfiles_6(self): + """Test that --flagfile parsing stops at non-options (non-GNU behavior).""" + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' % + tmp_files[0]) + fake_argv = fake_cmd_line.split(' ') + expected_results = [ + 'fooScript', '--some_flag', 'some_arg', + '--flagfile=%s' % tmp_files[0] + ] + + with _use_gnu_getopt(self.flag_values, False): + test_results = self._read_flags_from_files(fake_argv, False) + self.assertListEqual(expected_results, test_results) + + def test_method_flagfiles_7(self): + """Test that --flagfile parsing skips over a non-option (GNU behavior).""" + self.flag_values.set_gnu_getopt() + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' % + tmp_files[0]) + fake_argv = fake_cmd_line.split(' ') + expected_results = [ + 'fooScript', '--some_flag', 'some_arg', + '--unittest_message1=tempFile1!', '--unittest_number=54321', + '--nounittest_boolflag' + ] + + test_results = self._read_flags_from_files(fake_argv, False) + self.assertListEqual(expected_results, test_results) + + def test_method_flagfiles_8(self): + """Test that --flagfile parsing respects force_gnu=True.""" + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' % + tmp_files[0]) + fake_argv = fake_cmd_line.split(' ') + expected_results = [ + 'fooScript', '--some_flag', 'some_arg', + '--unittest_message1=tempFile1!', '--unittest_number=54321', + '--nounittest_boolflag' + ] + + test_results = self._read_flags_from_files(fake_argv, True) + self.assertListEqual(expected_results, test_results) + + def test_method_flagfiles_repeated_non_circular(self): + """Tests that parsing repeated non-circular flagfiles works.""" + tmp_files = self._setup_test_files() + # specify our temp files on the fake cmd line + fake_cmd_line = ('fooScript --flagfile=%s --flagfile=%s' % + (tmp_files[1], tmp_files[0])) + fake_argv = fake_cmd_line.split(' ') + expected_results = [ + 'fooScript', '--unittest_message1=tempFile1!', + '--unittest_number=54321', '--nounittest_boolflag', + '--unittest_message2=setFromTempFile2', '--unittest_number=6789a', + '--unittest_message1=tempFile1!', '--unittest_number=54321', + '--nounittest_boolflag' + ] + + test_results = self._read_flags_from_files(fake_argv, False) + self.assertListEqual(expected_results, test_results) + + @unittest.skipIf( + os.name == 'nt', + 'There is no good way to create an unreadable file on Windows.') + def test_method_flagfiles_no_permissions(self): + """Test that --flagfile raises except on file that is unreadable.""" + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%s' % + tmp_files[3]) + fake_argv = fake_cmd_line.split(' ') + self.assertRaises(flags.CantOpenFlagFileError, self._read_flags_from_files, + fake_argv, True) + + def test_method_flagfiles_not_found(self): + """Test that --flagfile raises except on file that does not exist.""" + tmp_files = self._setup_test_files() + # specify our temp file on the fake cmd line + fake_cmd_line = ('fooScript --some_flag some_arg --flagfile=%sNOTEXIST' % + tmp_files[3]) + fake_argv = fake_cmd_line.split(' ') + self.assertRaises(flags.CantOpenFlagFileError, self._read_flags_from_files, + fake_argv, True) + + def test_flagfiles_user_path_expansion(self): + """Test that user directory referenced paths are correctly expanded. + + Test paths like ~/foo. This test depends on whatever account's running + the unit test to have read/write access to their own home directory, + otherwise it'll FAIL. + """ + fake_flagfile_item_style_1 = '--flagfile=~/foo.file' + fake_flagfile_item_style_2 = '-flagfile=~/foo.file' + + expected_results = os.path.expanduser('~/foo.file') + + test_results = self.flag_values._extract_filename( + fake_flagfile_item_style_1) + self.assertEqual(expected_results, test_results) + + test_results = self.flag_values._extract_filename( + fake_flagfile_item_style_2) + self.assertEqual(expected_results, test_results) + + def test_no_touchy_non_flags(self): + """Test that the flags parser does not mutilate arguments. + + The arguments are not supposed to be flags + """ + fake_argv = [ + 'fooScript', '--unittest_boolflag', 'command', '--command_arg1', + '--UnitTestBoom', '--UnitTestB' + ] + with _use_gnu_getopt(self.flag_values, False): + argv = self.flag_values(fake_argv) + self.assertListEqual(argv, fake_argv[:1] + fake_argv[2:]) + + def test_parse_flags_after_args_if_using_gnugetopt(self): + """Test that flags given after arguments are parsed if using gnu_getopt.""" + self.flag_values.set_gnu_getopt() + fake_argv = [ + 'fooScript', '--unittest_boolflag', 'command', '--unittest_number=54321' + ] + argv = self.flag_values(fake_argv) + self.assertListEqual(argv, ['fooScript', 'command']) + + def test_set_default(self): + """Test changing flag defaults.""" + # Test that set_default changes both the default and the value, + # and that the value is changed when one is given as an option. + self.flag_values.set_default('unittest_message1', 'New value') + self.assertEqual(self.flag_values.unittest_message1, 'New value') + self.assertEqual(self.flag_values['unittest_message1'].default_as_str, + "'New value'") + self.flag_values(['dummyscript', '--unittest_message1=Newer value']) + self.assertEqual(self.flag_values.unittest_message1, 'Newer value') + + # Test that setting the default to None works correctly. + self.flag_values.set_default('unittest_number', None) + self.assertEqual(self.flag_values.unittest_number, None) + self.assertEqual(self.flag_values['unittest_number'].default_as_str, None) + self.flag_values(['dummyscript', '--unittest_number=56']) + self.assertEqual(self.flag_values.unittest_number, 56) + + # Test that setting the default to zero works correctly. + self.flag_values.set_default('unittest_number', 0) + self.assertEqual(self.flag_values['unittest_number'].default, 0) + self.assertEqual(self.flag_values.unittest_number, 56) + self.assertEqual(self.flag_values['unittest_number'].default_as_str, "'0'") + self.flag_values(['dummyscript', '--unittest_number=56']) + self.assertEqual(self.flag_values.unittest_number, 56) + + # Test that setting the default to '' works correctly. + self.flag_values.set_default('unittest_message1', '') + self.assertEqual(self.flag_values['unittest_message1'].default, '') + self.assertEqual(self.flag_values.unittest_message1, 'Newer value') + self.assertEqual(self.flag_values['unittest_message1'].default_as_str, "''") + self.flag_values(['dummyscript', '--unittest_message1=fifty-six']) + self.assertEqual(self.flag_values.unittest_message1, 'fifty-six') + + # Test that setting the default to false works correctly. + self.flag_values.set_default('unittest_boolflag', False) + self.assertEqual(self.flag_values.unittest_boolflag, False) + self.assertEqual(self.flag_values['unittest_boolflag'].default_as_str, + "'false'") + self.flag_values(['dummyscript', '--unittest_boolflag=true']) + self.assertEqual(self.flag_values.unittest_boolflag, True) + + # Test that setting a list default works correctly. + self.flag_values.set_default('UnitTestList', '4,5,6') + self.assertListEqual(self.flag_values.UnitTestList, ['4', '5', '6']) + self.assertEqual(self.flag_values['UnitTestList'].default_as_str, "'4,5,6'") + self.flag_values(['dummyscript', '--UnitTestList=7,8,9']) + self.assertListEqual(self.flag_values.UnitTestList, ['7', '8', '9']) + + # Test that setting invalid defaults raises exceptions + with self.assertRaises(flags.IllegalFlagValueError): + self.flag_values.set_default('unittest_number', 'oops') + with self.assertRaises(flags.IllegalFlagValueError): + self.flag_values.set_default('unittest_number', -1) + + +class FlagsParsingTest(absltest.TestCase): + """Testing different aspects of parsing: '-f' vs '--flag', etc.""" + + def setUp(self): + self.flag_values = flags.FlagValues() + + def test_two_dash_arg_first(self): + flags.DEFINE_string( + 'twodash_name', 'Bob', 'namehelp', flag_values=self.flag_values) + flags.DEFINE_string( + 'twodash_blame', 'Rob', 'blamehelp', flag_values=self.flag_values) + argv = ('./program', '--', '--twodash_name=Harry') + argv = self.flag_values(argv) + self.assertEqual('Bob', self.flag_values.twodash_name) + self.assertEqual(argv[1], '--twodash_name=Harry') + + def test_two_dash_arg_middle(self): + flags.DEFINE_string( + 'twodash2_name', 'Bob', 'namehelp', flag_values=self.flag_values) + flags.DEFINE_string( + 'twodash2_blame', 'Rob', 'blamehelp', flag_values=self.flag_values) + argv = ('./program', '--twodash2_blame=Larry', '--', + '--twodash2_name=Harry') + argv = self.flag_values(argv) + self.assertEqual('Bob', self.flag_values.twodash2_name) + self.assertEqual('Larry', self.flag_values.twodash2_blame) + self.assertEqual(argv[1], '--twodash2_name=Harry') + + def test_one_dash_arg_first(self): + flags.DEFINE_string( + 'onedash_name', 'Bob', 'namehelp', flag_values=self.flag_values) + flags.DEFINE_string( + 'onedash_blame', 'Rob', 'blamehelp', flag_values=self.flag_values) + argv = ('./program', '-', '--onedash_name=Harry') + with _use_gnu_getopt(self.flag_values, False): + argv = self.flag_values(argv) + self.assertEqual(len(argv), 3) + self.assertEqual(argv[1], '-') + self.assertEqual(argv[2], '--onedash_name=Harry') + + def test_required_flag_not_specified(self): + flags.DEFINE_string( + 'str_flag', + default=None, + help='help', + required=True, + flag_values=self.flag_values) + argv = ('./program',) + with _use_gnu_getopt(self.flag_values, False): + with self.assertRaises(flags.IllegalFlagValueError): + self.flag_values(argv) + + def test_required_arg_works_with_other_validators(self): + flags.DEFINE_integer( + 'int_flag', + default=None, + help='help', + required=True, + lower_bound=4, + flag_values=self.flag_values) + argv = ('./program', '--int_flag=2') + with _use_gnu_getopt(self.flag_values, False): + with self.assertRaises(flags.IllegalFlagValueError): + self.flag_values(argv) + + def test_unrecognized_flags(self): + flags.DEFINE_string('name', 'Bob', 'namehelp', flag_values=self.flag_values) + # Unknown flag --nosuchflag + try: + argv = ('./program', '--nosuchflag', '--name=Bob', 'extra') + self.flag_values(argv) + raise AssertionError('Unknown flag exception not raised') + except flags.UnrecognizedFlagError as e: + self.assertEqual(e.flagname, 'nosuchflag') + self.assertEqual(e.flagvalue, '--nosuchflag') + + # Unknown flag -w (short option) + try: + argv = ('./program', '-w', '--name=Bob', 'extra') + self.flag_values(argv) + raise AssertionError('Unknown flag exception not raised') + except flags.UnrecognizedFlagError as e: + self.assertEqual(e.flagname, 'w') + self.assertEqual(e.flagvalue, '-w') + + # Unknown flag --nosuchflagwithparam=foo + try: + argv = ('./program', '--nosuchflagwithparam=foo', '--name=Bob', 'extra') + self.flag_values(argv) + raise AssertionError('Unknown flag exception not raised') + except flags.UnrecognizedFlagError as e: + self.assertEqual(e.flagname, 'nosuchflagwithparam') + self.assertEqual(e.flagvalue, '--nosuchflagwithparam=foo') + + # Allow unknown flag --nosuchflag if specified with undefok + argv = ('./program', '--nosuchflag', '--name=Bob', '--undefok=nosuchflag', + 'extra') + argv = self.flag_values(argv) + self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(argv[1], 'extra', 'extra argument not preserved') + + # Allow unknown flag --noboolflag if undefok=boolflag is specified + argv = ('./program', '--noboolflag', '--name=Bob', '--undefok=boolflag', + 'extra') + argv = self.flag_values(argv) + self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(argv[1], 'extra', 'extra argument not preserved') + + # But not if the flagname is misspelled: + try: + argv = ('./program', '--nosuchflag', '--name=Bob', '--undefok=nosuchfla', + 'extra') + self.flag_values(argv) + raise AssertionError('Unknown flag exception not raised') + except flags.UnrecognizedFlagError as e: + self.assertEqual(e.flagname, 'nosuchflag') + + try: + argv = ('./program', '--nosuchflag', '--name=Bob', + '--undefok=nosuchflagg', 'extra') + self.flag_values(argv) + raise AssertionError('Unknown flag exception not raised') + except flags.UnrecognizedFlagError as e: + self.assertEqual(e.flagname, 'nosuchflag') + + # Allow unknown short flag -w if specified with undefok + argv = ('./program', '-w', '--name=Bob', '--undefok=w', 'extra') + argv = self.flag_values(argv) + self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(argv[1], 'extra', 'extra argument not preserved') + + # Allow unknown flag --nosuchflagwithparam=foo if specified + # with undefok + argv = ('./program', '--nosuchflagwithparam=foo', '--name=Bob', + '--undefok=nosuchflagwithparam', 'extra') + argv = self.flag_values(argv) + self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(argv[1], 'extra', 'extra argument not preserved') + + # Even if undefok specifies multiple flags + argv = ('./program', '--nosuchflag', '-w', '--nosuchflagwithparam=foo', + '--name=Bob', '--undefok=nosuchflag,w,nosuchflagwithparam', 'extra') + argv = self.flag_values(argv) + self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(argv[1], 'extra', 'extra argument not preserved') + + # However, not if undefok doesn't specify the flag + try: + argv = ('./program', '--nosuchflag', '--name=Bob', + '--undefok=another_such', 'extra') + self.flag_values(argv) + raise AssertionError('Unknown flag exception not raised') + except flags.UnrecognizedFlagError as e: + self.assertEqual(e.flagname, 'nosuchflag') + + # Make sure --undefok doesn't mask other option errors. + try: + # Provide an option requiring a parameter but not giving it one. + argv = ('./program', '--undefok=name', '--name') + self.flag_values(argv) + raise AssertionError('Missing option parameter exception not raised') + except flags.UnrecognizedFlagError: + raise AssertionError('Wrong kind of error exception raised') + except flags.Error: + pass + + # Test --undefok <list> + argv = ('./program', '--nosuchflag', '-w', '--nosuchflagwithparam=foo', + '--name=Bob', '--undefok', 'nosuchflag,w,nosuchflagwithparam', + 'extra') + argv = self.flag_values(argv) + self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(argv[1], 'extra', 'extra argument not preserved') + + # Test incorrect --undefok with no value. + argv = ('./program', '--name=Bob', '--undefok') + with self.assertRaises(flags.Error): + self.flag_values(argv) + + +class NonGlobalFlagsTest(absltest.TestCase): + + def test_nonglobal_flags(self): + """Test use of non-global FlagValues.""" + nonglobal_flags = flags.FlagValues() + flags.DEFINE_string('nonglobal_flag', 'Bob', 'flaghelp', nonglobal_flags) + argv = ('./program', '--nonglobal_flag=Mary', 'extra') + argv = nonglobal_flags(argv) + self.assertEqual(len(argv), 2, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + self.assertEqual(argv[1], 'extra', 'extra argument not preserved') + self.assertEqual(nonglobal_flags['nonglobal_flag'].value, 'Mary') + + def test_unrecognized_nonglobal_flags(self): + """Test unrecognized non-global flags.""" + nonglobal_flags = flags.FlagValues() + argv = ('./program', '--nosuchflag') + try: + argv = nonglobal_flags(argv) + raise AssertionError('Unknown flag exception not raised') + except flags.UnrecognizedFlagError as e: + self.assertEqual(e.flagname, 'nosuchflag') + + argv = ('./program', '--nosuchflag', '--undefok=nosuchflag') + + argv = nonglobal_flags(argv) + self.assertEqual(len(argv), 1, 'wrong number of arguments pulled') + self.assertEqual(argv[0], './program', 'program name not preserved') + + def test_create_flag_errors(self): + # Since the exception classes are exposed, nothing stops users + # from creating their own instances. This test makes sure that + # people modifying the flags module understand that the external + # mechanisms for creating the exceptions should continue to work. + _ = flags.Error() + _ = flags.Error('message') + _ = flags.DuplicateFlagError() + _ = flags.DuplicateFlagError('message') + _ = flags.IllegalFlagValueError() + _ = flags.IllegalFlagValueError('message') + + def test_flag_values_del_attr(self): + """Checks that del self.flag_values.flag_id works.""" + default_value = 'default value for test_flag_values_del_attr' + # 1. Declare and delete a flag with no short name. + flag_values = flags.FlagValues() + flags.DEFINE_string( + 'delattr_foo', default_value, 'A simple flag.', flag_values=flag_values) + + flag_values.mark_as_parsed() + self.assertEqual(flag_values.delattr_foo, default_value) + flag_obj = flag_values['delattr_foo'] + # We also check that _FlagIsRegistered works as expected :) + self.assertTrue(flag_values._flag_is_registered(flag_obj)) + del flag_values.delattr_foo + self.assertFalse('delattr_foo' in flag_values._flags()) + self.assertFalse(flag_values._flag_is_registered(flag_obj)) + # If the previous del FLAGS.delattr_foo did not work properly, the + # next definition will trigger a redefinition error. + flags.DEFINE_integer( + 'delattr_foo', 3, 'A simple flag.', flag_values=flag_values) + del flag_values.delattr_foo + + self.assertFalse('delattr_foo' in flag_values) + + # 2. Declare and delete a flag with a short name. + flags.DEFINE_string( + 'delattr_bar', + default_value, + 'flag with short name', + short_name='x5', + flag_values=flag_values) + flag_obj = flag_values['delattr_bar'] + self.assertTrue(flag_values._flag_is_registered(flag_obj)) + del flag_values.x5 + self.assertTrue(flag_values._flag_is_registered(flag_obj)) + del flag_values.delattr_bar + self.assertFalse(flag_values._flag_is_registered(flag_obj)) + + # 3. Just like 2, but del flag_values.name last + flags.DEFINE_string( + 'delattr_bar', + default_value, + 'flag with short name', + short_name='x5', + flag_values=flag_values) + flag_obj = flag_values['delattr_bar'] + self.assertTrue(flag_values._flag_is_registered(flag_obj)) + del flag_values.delattr_bar + self.assertTrue(flag_values._flag_is_registered(flag_obj)) + del flag_values.x5 + self.assertFalse(flag_values._flag_is_registered(flag_obj)) + + self.assertFalse('delattr_bar' in flag_values) + self.assertFalse('x5' in flag_values) + + def test_list_flag_format(self): + """Tests for correctly-formatted list flags.""" + fv = flags.FlagValues() + flags.DEFINE_list('listflag', '', 'A list of arguments', flag_values=fv) + + def _check_parsing(listval): + """Parse a particular value for our test flag, --listflag.""" + argv = fv(['./program', '--listflag=' + listval, 'plain-arg']) + self.assertEqual(['./program', 'plain-arg'], argv) + return fv.listflag + + # Basic success case + self.assertEqual(_check_parsing('foo,bar'), ['foo', 'bar']) + # Success case: newline in argument is quoted. + self.assertEqual(_check_parsing('"foo","bar\nbar"'), ['foo', 'bar\nbar']) + # Failure case: newline in argument is unquoted. + self.assertRaises(flags.IllegalFlagValueError, _check_parsing, + '"foo",bar\nbar') + # Failure case: unmatched ". + self.assertRaises(flags.IllegalFlagValueError, _check_parsing, + '"foo,barbar') + + def test_flag_definition_via_setitem(self): + with self.assertRaises(flags.IllegalFlagValueError): + flag_values = flags.FlagValues() + flag_values['flag_name'] = 'flag_value' + + +class KeyFlagsTest(absltest.TestCase): + + def setUp(self): + self.flag_values = flags.FlagValues() + + def _get_names_of_defined_flags(self, module, flag_values): + """Returns the list of names of flags defined by a module. + + Auxiliary for the test_key_flags* methods. + + Args: + module: A module object or a string module name. + flag_values: A FlagValues object. + + Returns: + A list of strings. + """ + return [f.name for f in flag_values.get_flags_for_module(module)] + + def _get_names_of_key_flags(self, module, flag_values): + """Returns the list of names of key flags for a module. + + Auxiliary for the test_key_flags* methods. + + Args: + module: A module object or a string module name. + flag_values: A FlagValues object. + + Returns: + A list of strings. + """ + return [f.name for f in flag_values.get_key_flags_for_module(module)] + + def _assert_lists_have_same_elements(self, list_1, list_2): + # Checks that two lists have the same elements with the same + # multiplicity, in possibly different order. + list_1 = list(list_1) + list_1.sort() + list_2 = list(list_2) + list_2.sort() + self.assertListEqual(list_1, list_2) + + def test_key_flags(self): + flag_values = flags.FlagValues() + # Before starting any testing, make sure no flags are already + # defined for module_foo and module_bar. + self.assertListEqual( + self._get_names_of_key_flags(module_foo, flag_values), []) + self.assertListEqual( + self._get_names_of_key_flags(module_bar, flag_values), []) + self.assertListEqual( + self._get_names_of_defined_flags(module_foo, flag_values), []) + self.assertListEqual( + self._get_names_of_defined_flags(module_bar, flag_values), []) + + # Defines a few flags in module_foo and module_bar. + module_foo.define_flags(flag_values=flag_values) + + try: + # Part 1. Check that all flags defined by module_foo are key for + # that module, and similarly for module_bar. + for module in [module_foo, module_bar]: + self._assert_lists_have_same_elements( + flag_values.get_flags_for_module(module), + flag_values.get_key_flags_for_module(module)) + # Also check that each module defined the expected flags. + self._assert_lists_have_same_elements( + self._get_names_of_defined_flags(module, flag_values), + module.names_of_defined_flags()) + + # Part 2. Check that flags.declare_key_flag works fine. + # Declare that some flags from module_bar are key for + # module_foo. + module_foo.declare_key_flags(flag_values=flag_values) + + # Check that module_foo has the expected list of defined flags. + self._assert_lists_have_same_elements( + self._get_names_of_defined_flags(module_foo, flag_values), + module_foo.names_of_defined_flags()) + + # Check that module_foo has the expected list of key flags. + self._assert_lists_have_same_elements( + self._get_names_of_key_flags(module_foo, flag_values), + module_foo.names_of_declared_key_flags()) + + # Part 3. Check that flags.adopt_module_key_flags works fine. + # Trigger a call to flags.adopt_module_key_flags(module_bar) + # inside module_foo. This should declare a few more key + # flags in module_foo. + module_foo.declare_extra_key_flags(flag_values=flag_values) + + # Check that module_foo has the expected list of key flags. + self._assert_lists_have_same_elements( + self._get_names_of_key_flags(module_foo, flag_values), + module_foo.names_of_declared_key_flags() + + module_foo.names_of_declared_extra_key_flags()) + finally: + module_foo.remove_flags(flag_values=flag_values) + + def test_key_flags_with_non_default_flag_values_object(self): + # Check that key flags work even when we use a FlagValues object + # that is not the default flags.self.flag_values object. Otherwise, this + # test is similar to test_key_flags, but it uses only module_bar. + # The other test module (module_foo) uses only the default values + # for the flag_values keyword arguments. This way, test_key_flags + # and this method test both the default FlagValues, the explicitly + # specified one, and a mixed usage of the two. + + # A brand-new FlagValues object, to use instead of flags.self.flag_values. + fv = flags.FlagValues() + + # Before starting any testing, make sure no flags are already + # defined for module_foo and module_bar. + self.assertListEqual(self._get_names_of_key_flags(module_bar, fv), []) + self.assertListEqual(self._get_names_of_defined_flags(module_bar, fv), []) + + module_bar.define_flags(flag_values=fv) + + # Check that all flags defined by module_bar are key for that + # module, and that module_bar defined the expected flags. + self._assert_lists_have_same_elements( + fv.get_flags_for_module(module_bar), + fv.get_key_flags_for_module(module_bar)) + self._assert_lists_have_same_elements( + self._get_names_of_defined_flags(module_bar, fv), + module_bar.names_of_defined_flags()) + + # Pick two flags from module_bar, declare them as key for the + # current (i.e., main) module (via flags.declare_key_flag), and + # check that we get the expected effect. The important thing is + # that we always use flags_values=fv (instead of the default + # self.flag_values). + main_module = sys.argv[0] + names_of_flags_defined_by_bar = module_bar.names_of_defined_flags() + flag_name_0 = names_of_flags_defined_by_bar[0] + flag_name_2 = names_of_flags_defined_by_bar[2] + + flags.declare_key_flag(flag_name_0, flag_values=fv) + self._assert_lists_have_same_elements( + self._get_names_of_key_flags(main_module, fv), [flag_name_0]) + + flags.declare_key_flag(flag_name_2, flag_values=fv) + self._assert_lists_have_same_elements( + self._get_names_of_key_flags(main_module, fv), + [flag_name_0, flag_name_2]) + + # Try with a special (not user-defined) flag too: + flags.declare_key_flag('undefok', flag_values=fv) + self._assert_lists_have_same_elements( + self._get_names_of_key_flags(main_module, fv), + [flag_name_0, flag_name_2, 'undefok']) + + flags.adopt_module_key_flags(module_bar, fv) + self._assert_lists_have_same_elements( + self._get_names_of_key_flags(main_module, fv), + names_of_flags_defined_by_bar + ['undefok']) + + # Adopt key flags from the flags module itself. + flags.adopt_module_key_flags(flags, flag_values=fv) + self._assert_lists_have_same_elements( + self._get_names_of_key_flags(main_module, fv), + names_of_flags_defined_by_bar + ['flagfile', 'undefok']) + + def test_main_module_help_with_key_flags(self): + # Similar to test_main_module_help, but this time we make sure to + # declare some key flags. + + # Safety check that the main module does not declare any flags + # at the beginning of this test. + expected_help = '' + self.assertMultiLineEqual(expected_help, + self.flag_values.main_module_help()) + + # Define one flag in this main module and some flags in modules + # a and b. Also declare one flag from module a and one flag + # from module b as key flags for the main module. + flags.DEFINE_integer( + 'main_module_int_fg', + 1, + 'Integer flag in the main module.', + flag_values=self.flag_values) + + try: + main_module_int_fg_help = ( + ' --main_module_int_fg: Integer flag in the main module.\n' + " (default: '1')\n" + ' (an integer)') + + expected_help += '\n%s:\n%s' % (sys.argv[0], main_module_int_fg_help) + self.assertMultiLineEqual(expected_help, + self.flag_values.main_module_help()) + + # The following call should be a no-op: any flag declared by a + # module is automatically key for that module. + flags.declare_key_flag('main_module_int_fg', flag_values=self.flag_values) + self.assertMultiLineEqual(expected_help, + self.flag_values.main_module_help()) + + # The definition of a few flags in an imported module should not + # change the main module help. + module_foo.define_flags(flag_values=self.flag_values) + self.assertMultiLineEqual(expected_help, + self.flag_values.main_module_help()) + + flags.declare_key_flag('tmod_foo_bool', flag_values=self.flag_values) + tmod_foo_bool_help = ( + ' --[no]tmod_foo_bool: Boolean flag from module foo.\n' + " (default: 'true')") + expected_help += '\n' + tmod_foo_bool_help + self.assertMultiLineEqual(expected_help, + self.flag_values.main_module_help()) + + flags.declare_key_flag('tmod_bar_z', flag_values=self.flag_values) + tmod_bar_z_help = ( + ' --[no]tmod_bar_z: Another boolean flag from module bar.\n' + " (default: 'false')") + # Unfortunately, there is some flag sorting inside + # main_module_help, so we can't keep incrementally extending + # the expected_help string ... + expected_help = ('\n%s:\n%s\n%s\n%s' % + (sys.argv[0], main_module_int_fg_help, tmod_bar_z_help, + tmod_foo_bool_help)) + self.assertMultiLineEqual(self.flag_values.main_module_help(), + expected_help) + + finally: + # At the end, delete all the flag information we created. + self.flag_values.__delattr__('main_module_int_fg') + module_foo.remove_flags(flag_values=self.flag_values) + + def test_adoptmodule_key_flags(self): + # Check that adopt_module_key_flags raises an exception when + # called with a module name (as opposed to a module object). + self.assertRaises(flags.Error, flags.adopt_module_key_flags, 'pyglib.app') + + def test_disclaimkey_flags(self): + original_disclaim_module_ids = _helpers.disclaim_module_ids + _helpers.disclaim_module_ids = set(_helpers.disclaim_module_ids) + try: + module_bar.disclaim_key_flags() + module_foo.define_bar_flags(flag_values=self.flag_values) + module_name = self.flag_values.find_module_defining_flag('tmod_bar_x') + self.assertEqual(module_foo.__name__, module_name) + finally: + _helpers.disclaim_module_ids = original_disclaim_module_ids + + +class FindModuleTest(absltest.TestCase): + """Testing methods that find a module that defines a given flag.""" + + def test_find_module_defining_flag(self): + self.assertEqual( + 'default', + FLAGS.find_module_defining_flag('__NON_EXISTENT_FLAG__', 'default')) + self.assertEqual(module_baz.__name__, + FLAGS.find_module_defining_flag('tmod_baz_x')) + + def test_find_module_id_defining_flag(self): + self.assertEqual( + 'default', + FLAGS.find_module_id_defining_flag('__NON_EXISTENT_FLAG__', 'default')) + self.assertEqual( + id(module_baz), FLAGS.find_module_id_defining_flag('tmod_baz_x')) + + def test_find_module_defining_flag_passing_module_name(self): + my_flags = flags.FlagValues() + module_name = sys.__name__ # Must use an existing module. + flags.DEFINE_boolean( + 'flag_name', + True, + 'Flag with a different module name.', + flag_values=my_flags, + module_name=module_name) + self.assertEqual(module_name, + my_flags.find_module_defining_flag('flag_name')) + + def test_find_module_id_defining_flag_passing_module_name(self): + my_flags = flags.FlagValues() + module_name = sys.__name__ # Must use an existing module. + flags.DEFINE_boolean( + 'flag_name', + True, + 'Flag with a different module name.', + flag_values=my_flags, + module_name=module_name) + self.assertEqual( + id(sys), my_flags.find_module_id_defining_flag('flag_name')) + + +class FlagsErrorMessagesTest(absltest.TestCase): + """Testing special cases for integer and float flags error messages.""" + + def setUp(self): + self.flag_values = flags.FlagValues() + + def test_integer_error_text(self): + # Make sure we get proper error text + flags.DEFINE_integer( + 'positive', + 4, + 'non-negative flag', + lower_bound=1, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'non_negative', + 4, + 'positive flag', + lower_bound=0, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'negative', + -4, + 'negative flag', + upper_bound=-1, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'non_positive', + -4, + 'non-positive flag', + upper_bound=0, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'greater', + 19, + 'greater-than flag', + lower_bound=4, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'smaller', + -19, + 'smaller-than flag', + upper_bound=4, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'usual', + 4, + 'usual flag', + lower_bound=0, + upper_bound=10000, + flag_values=self.flag_values) + flags.DEFINE_integer( + 'another_usual', + 0, + 'usual flag', + lower_bound=-1, + upper_bound=1, + flag_values=self.flag_values) + + self._check_error_message('positive', -4, 'a positive integer') + self._check_error_message('non_negative', -4, 'a non-negative integer') + self._check_error_message('negative', 0, 'a negative integer') + self._check_error_message('non_positive', 4, 'a non-positive integer') + self._check_error_message('usual', -4, 'an integer in the range [0, 10000]') + self._check_error_message('another_usual', 4, + 'an integer in the range [-1, 1]') + self._check_error_message('greater', -5, 'integer >= 4') + self._check_error_message('smaller', 5, 'integer <= 4') + + def test_float_error_text(self): + flags.DEFINE_float( + 'positive', + 4, + 'non-negative flag', + lower_bound=1, + flag_values=self.flag_values) + flags.DEFINE_float( + 'non_negative', + 4, + 'positive flag', + lower_bound=0, + flag_values=self.flag_values) + flags.DEFINE_float( + 'negative', + -4, + 'negative flag', + upper_bound=-1, + flag_values=self.flag_values) + flags.DEFINE_float( + 'non_positive', + -4, + 'non-positive flag', + upper_bound=0, + flag_values=self.flag_values) + flags.DEFINE_float( + 'greater', + 19, + 'greater-than flag', + lower_bound=4, + flag_values=self.flag_values) + flags.DEFINE_float( + 'smaller', + -19, + 'smaller-than flag', + upper_bound=4, + flag_values=self.flag_values) + flags.DEFINE_float( + 'usual', + 4, + 'usual flag', + lower_bound=0, + upper_bound=10000, + flag_values=self.flag_values) + flags.DEFINE_float( + 'another_usual', + 0, + 'usual flag', + lower_bound=-1, + upper_bound=1, + flag_values=self.flag_values) + + self._check_error_message('positive', 0.5, 'number >= 1') + self._check_error_message('non_negative', -4.0, 'a non-negative number') + self._check_error_message('negative', 0.5, 'number <= -1') + self._check_error_message('non_positive', 4.0, 'a non-positive number') + self._check_error_message('usual', -4.0, 'a number in the range [0, 10000]') + self._check_error_message('another_usual', 4.0, + 'a number in the range [-1, 1]') + self._check_error_message('smaller', 5.0, 'number <= 4') + + def _check_error_message(self, flag_name, flag_value, + expected_message_suffix): + """Set a flag to a given value and make sure we get expected message.""" + + try: + self.flag_values.__setattr__(flag_name, flag_value) + raise AssertionError('Bounds exception not raised!') + except flags.IllegalFlagValueError as e: + expected = ('flag --%(name)s=%(value)s: %(value)s is not %(suffix)s' % { + 'name': flag_name, + 'value': flag_value, + 'suffix': expected_message_suffix + }) + self.assertEqual(str(e), expected) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/flags_unicode_literals_test.py b/absl/flags/tests/flags_unicode_literals_test.py new file mode 100644 index 0000000..e8ed5bf --- /dev/null +++ b/absl/flags/tests/flags_unicode_literals_test.py @@ -0,0 +1,42 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the use of flags when from __future__ import unicode_literals is on.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from absl import flags +from absl.testing import absltest + + +flags.DEFINE_string('seen_in_crittenden', 'alleged mountain lion', + 'This tests if unicode input to these functions works.') + + +class FlagsUnicodeLiteralsTest(absltest.TestCase): + + def testUnicodeFlagNameAndValueAreGood(self): + alleged_mountain_lion = flags.FLAGS.seen_in_crittenden + self.assertTrue( + isinstance(alleged_mountain_lion, type(u'')), + msg='expected flag value to be a {} not {}'.format( + type(u''), type(alleged_mountain_lion))) + self.assertEqual(alleged_mountain_lion, u'alleged mountain lion') + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/flags/tests/module_bar.py b/absl/flags/tests/module_bar.py new file mode 100644 index 0000000..8714d2e --- /dev/null +++ b/absl/flags/tests/module_bar.py @@ -0,0 +1,121 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Auxiliary module for testing flags.py. + +The purpose of this module is to define a few flags. We want to make +sure the unit tests for flags.py involve more than one module. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import flags +from absl.flags import _helpers + +FLAGS = flags.FLAGS + + +def define_flags(flag_values=FLAGS): + """Defines some flags. + + Args: + flag_values: The FlagValues object we want to register the flags + with. + """ + # The 'tmod_bar_' prefix (short for 'test_module_bar') ensures there + # is no name clash with the existing flags. + flags.DEFINE_boolean('tmod_bar_x', True, 'Boolean flag.', + flag_values=flag_values) + flags.DEFINE_string('tmod_bar_y', 'default', 'String flag.', + flag_values=flag_values) + flags.DEFINE_boolean('tmod_bar_z', False, + 'Another boolean flag from module bar.', + flag_values=flag_values) + flags.DEFINE_integer('tmod_bar_t', 4, 'Sample int flag.', + flag_values=flag_values) + flags.DEFINE_integer('tmod_bar_u', 5, 'Sample int flag.', + flag_values=flag_values) + flags.DEFINE_integer('tmod_bar_v', 6, 'Sample int flag.', + flag_values=flag_values) + + +def remove_one_flag(flag_name, flag_values=FLAGS): + """Removes the definition of one flag from flags.FLAGS. + + Note: if the flag is not defined in flags.FLAGS, this function does + not do anything (in particular, it does not raise any exception). + + Motivation: We use this function for cleanup *after* a test: if + there was a failure during a test and not all flags were declared, + we do not want the cleanup code to crash. + + Args: + flag_name: A string, the name of the flag to delete. + flag_values: The FlagValues object we remove the flag from. + """ + if flag_name in flag_values: + flag_values.__delattr__(flag_name) + + +def names_of_defined_flags(): + """Returns: List of names of the flags declared in this module.""" + return ['tmod_bar_x', + 'tmod_bar_y', + 'tmod_bar_z', + 'tmod_bar_t', + 'tmod_bar_u', + 'tmod_bar_v'] + + +def remove_flags(flag_values=FLAGS): + """Deletes the flag definitions done by the above define_flags(). + + Args: + flag_values: The FlagValues object we remove the flags from. + """ + for flag_name in names_of_defined_flags(): + remove_one_flag(flag_name, flag_values=flag_values) + + +def get_module_name(): + """Uses get_calling_module() to return the name of this module. + + For checking that get_calling_module works as expected. + + Returns: + A string, the name of this module. + """ + return _helpers.get_calling_module() + + +def execute_code(code, global_dict): + """Executes some code in a given global environment. + + For testing of get_calling_module. + + Args: + code: A string, the code to be executed. + global_dict: A dictionary, the global environment that code should + be executed in. + """ + # Indeed, using exec generates a lint warning. But some user code + # actually uses exec, and we have to test for it ... + exec(code, global_dict) # pylint: disable=exec-used + + +def disclaim_key_flags(): + """Disclaims flags declared in this module.""" + flags.disclaim_key_flags() diff --git a/absl/flags/tests/module_baz.py b/absl/flags/tests/module_baz.py new file mode 100644 index 0000000..7199516 --- /dev/null +++ b/absl/flags/tests/module_baz.py @@ -0,0 +1,29 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Auxiliary module for testing flags.py. + +The purpose of this module is to test the behavior of flags that are defined +before main() executes. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_boolean('tmod_baz_x', True, 'Boolean flag.') diff --git a/absl/flags/tests/module_foo.py b/absl/flags/tests/module_foo.py new file mode 100644 index 0000000..a1a2573 --- /dev/null +++ b/absl/flags/tests/module_foo.py @@ -0,0 +1,128 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Auxiliary module for testing flags.py. + +The purpose of this module is to define a few flags, and declare some +other flags as being important. We want to make sure the unit tests +for flags.py involve more than one module. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import flags +from absl.flags import _helpers +from absl.flags.tests import module_bar + +FLAGS = flags.FLAGS + + +DECLARED_KEY_FLAGS = ['tmod_bar_x', 'tmod_bar_z', 'tmod_bar_t', + # Special (not user-defined) flag: + 'flagfile'] + + +def define_flags(flag_values=FLAGS): + """Defines a few flags.""" + module_bar.define_flags(flag_values=flag_values) + # The 'tmod_foo_' prefix (short for 'test_module_foo') ensures that we + # have no name clash with existing flags. + flags.DEFINE_boolean('tmod_foo_bool', True, 'Boolean flag from module foo.', + flag_values=flag_values) + flags.DEFINE_string('tmod_foo_str', 'default', 'String flag.', + flag_values=flag_values) + flags.DEFINE_integer('tmod_foo_int', 3, 'Sample int flag.', + flag_values=flag_values) + + +def declare_key_flags(flag_values=FLAGS): + """Declares a few key flags.""" + for flag_name in DECLARED_KEY_FLAGS: + flags.declare_key_flag(flag_name, flag_values=flag_values) + + +def declare_extra_key_flags(flag_values=FLAGS): + """Declares some extra key flags.""" + flags.adopt_module_key_flags(module_bar, flag_values=flag_values) + + +def names_of_defined_flags(): + """Returns: list of names of flags defined by this module.""" + return ['tmod_foo_bool', 'tmod_foo_str', 'tmod_foo_int'] + + +def names_of_declared_key_flags(): + """Returns: list of names of key flags for this module.""" + return names_of_defined_flags() + DECLARED_KEY_FLAGS + + +def names_of_declared_extra_key_flags(): + """Returns the list of names of additional key flags for this module. + + These are the flags that became key for this module only as a result + of a call to declare_extra_key_flags() above. I.e., the flags declared + by module_bar, that were not already declared as key for this + module. + + Returns: + The list of names of additional key flags for this module. + """ + names_of_extra_key_flags = list(module_bar.names_of_defined_flags()) + for flag_name in names_of_declared_key_flags(): + while flag_name in names_of_extra_key_flags: + names_of_extra_key_flags.remove(flag_name) + return names_of_extra_key_flags + + +def remove_flags(flag_values=FLAGS): + """Deletes the flag definitions done by the above define_flags().""" + for flag_name in names_of_defined_flags(): + module_bar.remove_one_flag(flag_name, flag_values=flag_values) + module_bar.remove_flags(flag_values=flag_values) + + +def get_module_name(): + """Uses get_calling_module() to return the name of this module. + + For checking that _get_calling_module works as expected. + + Returns: + A string, the name of this module. + """ + return _helpers.get_calling_module() + + +def duplicate_flags(flagnames=None): + """Returns a new FlagValues object with the requested flagnames. + + Used to test DuplicateFlagError detection. + + Args: + flagnames: str, A list of flag names to create. + + Returns: + A FlagValues object with one boolean flag for each name in flagnames. + """ + flag_values = flags.FlagValues() + for name in flagnames: + flags.DEFINE_boolean(name, False, 'Flag named %s' % (name,), + flag_values=flag_values) + return flag_values + + +def define_bar_flags(flag_values=FLAGS): + """Defines flags from module_bar.""" + module_bar.define_flags(flag_values) diff --git a/absl/logging/BUILD b/absl/logging/BUILD new file mode 100644 index 0000000..6c0d1bc --- /dev/null +++ b/absl/logging/BUILD @@ -0,0 +1,100 @@ +licenses(["notice"]) + +py_library( + name = "logging", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":converter", + "//absl/flags", + ], +) + +py_library( + name = "converter", + srcs = ["converter.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) + +py_test( + name = "tests/converter_test", + size = "small", + srcs = ["tests/converter_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":converter", + ":logging", + "//absl/testing:absltest", + ], +) + +py_test( + name = "tests/logging_test", + size = "small", + srcs = ["tests/logging_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":logging", + "//absl/flags", + "//absl/testing:absltest", + "//absl/testing:flagsaver", + "//absl/testing:parameterized", + ], +) + +py_test( + name = "tests/log_before_import_test", + srcs = ["tests/log_before_import_test.py"], + main = "tests/log_before_import_test.py", + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":logging", + "//absl/testing:absltest", + ], +) + +py_test( + name = "tests/verbosity_flag_test", + srcs = ["tests/verbosity_flag_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":logging", + "//absl/flags", + "//absl/testing:absltest", + ], +) + +py_binary( + name = "tests/logging_functional_test_helper", + testonly = 1, + srcs = ["tests/logging_functional_test_helper.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":logging", + "//absl:app", + "//absl/flags", + ], +) + +py_test( + name = "tests/logging_functional_test", + size = "large", + srcs = ["tests/logging_functional_test.py"], + data = [":tests/logging_functional_test_helper"], + python_version = "PY3", + shard_count = 50, + srcs_version = "PY3", + deps = [ + ":logging", + "//absl/testing:_bazelize_command", + "//absl/testing:absltest", + "//absl/testing:parameterized", + ], +) diff --git a/absl/logging/__init__.py b/absl/logging/__init__.py new file mode 100644 index 0000000..8804490 --- /dev/null +++ b/absl/logging/__init__.py @@ -0,0 +1,1234 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Abseil Python logging module implemented on top of standard logging. + +Simple usage: + + from absl import logging + + logging.info('Interesting Stuff') + logging.info('Interesting Stuff with Arguments: %d', 42) + + logging.set_verbosity(logging.INFO) + logging.log(logging.DEBUG, 'This will *not* be printed') + logging.set_verbosity(logging.DEBUG) + logging.log(logging.DEBUG, 'This will be printed') + + logging.warning('Worrying Stuff') + logging.error('Alarming Stuff') + logging.fatal('AAAAHHHHH!!!!') # Process exits. + +Usage note: Do not pre-format the strings in your program code. +Instead, let the logging module perform argument interpolation. +This saves cycles because strings that don't need to be printed +are never formatted. Note that this module does not attempt to +interpolate arguments when no arguments are given. In other words + + logging.info('Interesting Stuff: %s') + +does not raise an exception because logging.info() has only one +argument, the message string. + +"Lazy" evaluation for debugging: + +If you do something like this: + logging.debug('Thing: %s', thing.ExpensiveOp()) +then the ExpensiveOp will be evaluated even if nothing +is printed to the log. To avoid this, use the level_debug() function: + if logging.level_debug(): + logging.debug('Thing: %s', thing.ExpensiveOp()) + +Per file level logging is supported by logging.vlog() and +logging.vlog_is_on(). For example: + + if logging.vlog_is_on(2): + logging.vlog(2, very_expensive_debug_message()) + +Notes on Unicode: + +The log output is encoded as UTF-8. Don't pass data in other encodings in +bytes() instances -- instead pass unicode string instances when you need to +(for both the format string and arguments). + +Note on critical and fatal: +Standard logging module defines fatal as an alias to critical, but it's not +documented, and it does NOT actually terminate the program. +This module only defines fatal but not critical, and it DOES terminate the +program. + +The differences in behavior are historical and unfortunate. +""" + +import collections +from collections import abc +import getpass +import io +import itertools +import logging +import os +import socket +import struct +import sys +import threading +import time +import timeit +import traceback +import types +import warnings + +from absl import flags +from absl.logging import converter + +try: + from typing import NoReturn +except ImportError: + pass + + +FLAGS = flags.FLAGS + + +# Logging levels. +FATAL = converter.ABSL_FATAL +ERROR = converter.ABSL_ERROR +WARNING = converter.ABSL_WARNING +WARN = converter.ABSL_WARNING # Deprecated name. +INFO = converter.ABSL_INFO +DEBUG = converter.ABSL_DEBUG + +# Regex to match/parse log line prefixes. +ABSL_LOGGING_PREFIX_REGEX = ( + r'^(?P<severity>[IWEF])' + r'(?P<month>\d\d)(?P<day>\d\d) ' + r'(?P<hour>\d\d):(?P<minute>\d\d):(?P<second>\d\d)' + r'\.(?P<microsecond>\d\d\d\d\d\d) +' + r'(?P<thread_id>-?\d+) ' + r'(?P<filename>[a-zA-Z<][\w._<>-]+):(?P<line>\d+)') + + +# Mask to convert integer thread ids to unsigned quantities for logging purposes +_THREAD_ID_MASK = 2 ** (struct.calcsize('L') * 8) - 1 + +# Extra property set on the LogRecord created by ABSLLogger when its level is +# CRITICAL/FATAL. +_ABSL_LOG_FATAL = '_absl_log_fatal' +# Extra prefix added to the log message when a non-absl logger logs a +# CRITICAL/FATAL message. +_CRITICAL_PREFIX = 'CRITICAL - ' + +# Used by findCaller to skip callers from */logging/__init__.py. +_LOGGING_FILE_PREFIX = os.path.join('logging', '__init__.') + +# The ABSL logger instance, initialized in _initialize(). +_absl_logger = None +# The ABSL handler instance, initialized in _initialize(). +_absl_handler = None + + +_CPP_NAME_TO_LEVELS = { + 'debug': '0', # Abseil C++ has no DEBUG level, mapping it to INFO here. + 'info': '0', + 'warning': '1', + 'warn': '1', + 'error': '2', + 'fatal': '3' +} + +_CPP_LEVEL_TO_NAMES = { + '0': 'info', + '1': 'warning', + '2': 'error', + '3': 'fatal', +} + + +class _VerbosityFlag(flags.Flag): + """Flag class for -v/--verbosity.""" + + def __init__(self, *args, **kwargs): + super(_VerbosityFlag, self).__init__( + flags.IntegerParser(), + flags.ArgumentSerializer(), + *args, **kwargs) + + @property + def value(self): + return self._value + + @value.setter + def value(self, v): + self._value = v + self._update_logging_levels() + + def _update_logging_levels(self): + """Updates absl logging levels to the current verbosity. + + Visibility: module-private + """ + if not _absl_logger: + return + + if self._value <= converter.ABSL_DEBUG: + standard_verbosity = converter.absl_to_standard(self._value) + else: + # --verbosity is set to higher than 1 for vlog. + standard_verbosity = logging.DEBUG - (self._value - 1) + + # Also update root level when absl_handler is used. + if _absl_handler in logging.root.handlers: + # Make absl logger inherit from the root logger. absl logger might have + # a non-NOTSET value if logging.set_verbosity() is called at import time. + _absl_logger.setLevel(logging.NOTSET) + logging.root.setLevel(standard_verbosity) + else: + _absl_logger.setLevel(standard_verbosity) + + +class _LoggerLevelsFlag(flags.Flag): + """Flag class for --logger_levels.""" + + def __init__(self, *args, **kwargs): + super(_LoggerLevelsFlag, self).__init__( + _LoggerLevelsParser(), + _LoggerLevelsSerializer(), + *args, **kwargs) + + @property + def value(self): + # For lack of an immutable type, be defensive and return a copy. + # Modifications to the dict aren't supported and won't have any affect. + # While Py3 could use MappingProxyType, that isn't deepcopy friendly, so + # just return a copy. + return self._value.copy() + + @value.setter + def value(self, v): + self._value = {} if v is None else v + self._update_logger_levels() + + def _update_logger_levels(self): + # Visibility: module-private. + # This is called by absl.app.run() during initialization. + for name, level in self._value.items(): + logging.getLogger(name).setLevel(level) + + +class _LoggerLevelsParser(flags.ArgumentParser): + """Parser for --logger_levels flag.""" + + def parse(self, value): + if isinstance(value, abc.Mapping): + return value + + pairs = [pair.strip() for pair in value.split(',') if pair.strip()] + + # Preserve the order so that serialization is deterministic. + levels = collections.OrderedDict() + for name_level in pairs: + name, level = name_level.split(':', 1) + name = name.strip() + level = level.strip() + levels[name] = level + return levels + + +class _LoggerLevelsSerializer(object): + """Serializer for --logger_levels flag.""" + + def serialize(self, value): + if isinstance(value, str): + return value + return ','.join( + '{}:{}'.format(name, level) for name, level in value.items()) + + +class _StderrthresholdFlag(flags.Flag): + """Flag class for --stderrthreshold.""" + + def __init__(self, *args, **kwargs): + super(_StderrthresholdFlag, self).__init__( + flags.ArgumentParser(), + flags.ArgumentSerializer(), + *args, **kwargs) + + @property + def value(self): + return self._value + + @value.setter + def value(self, v): + if v in _CPP_LEVEL_TO_NAMES: + # --stderrthreshold also accepts numeric strings whose values are + # Abseil C++ log levels. + cpp_value = int(v) + v = _CPP_LEVEL_TO_NAMES[v] # Normalize to strings. + elif v.lower() in _CPP_NAME_TO_LEVELS: + v = v.lower() + if v == 'warn': + v = 'warning' # Use 'warning' as the canonical name. + cpp_value = int(_CPP_NAME_TO_LEVELS[v]) + else: + raise ValueError( + '--stderrthreshold must be one of (case-insensitive) ' + "'debug', 'info', 'warning', 'error', 'fatal', " + "or '0', '1', '2', '3', not '%s'" % v) + + self._value = v + + +flags.DEFINE_boolean('logtostderr', + False, + 'Should only log to stderr?', allow_override_cpp=True) +flags.DEFINE_boolean('alsologtostderr', + False, + 'also log to stderr?', allow_override_cpp=True) +flags.DEFINE_string('log_dir', + os.getenv('TEST_TMPDIR', ''), + 'directory to write logfiles into', + allow_override_cpp=True) +flags.DEFINE_flag(_VerbosityFlag( + 'verbosity', -1, + 'Logging verbosity level. Messages logged at this level or lower will ' + 'be included. Set to 1 for debug logging. If the flag was not set or ' + 'supplied, the value will be changed from the default of -1 (warning) to ' + '0 (info) after flags are parsed.', + short_name='v', allow_hide_cpp=True)) +flags.DEFINE_flag( + _LoggerLevelsFlag( + 'logger_levels', {}, + 'Specify log level of loggers. The format is a CSV list of ' + '`name:level`. Where `name` is the logger name used with ' + '`logging.getLogger()`, and `level` is a level name (INFO, DEBUG, ' + 'etc). e.g. `myapp.foo:INFO,other.logger:DEBUG`')) +flags.DEFINE_flag(_StderrthresholdFlag( + 'stderrthreshold', 'fatal', + 'log messages at this level, or more severe, to stderr in ' + 'addition to the logfile. Possible values are ' + "'debug', 'info', 'warning', 'error', and 'fatal'. " + 'Obsoletes --alsologtostderr. Using --alsologtostderr ' + 'cancels the effect of this flag. Please also note that ' + 'this flag is subject to --verbosity and requires logfile ' + 'not be stderr.', allow_hide_cpp=True)) +flags.DEFINE_boolean('showprefixforinfo', True, + 'If False, do not prepend prefix to info messages ' + 'when it\'s logged to stderr, ' + '--verbosity is set to INFO level, ' + 'and python logging is used.') + + +def get_verbosity(): + """Returns the logging verbosity.""" + return FLAGS['verbosity'].value + + +def set_verbosity(v): + """Sets the logging verbosity. + + Causes all messages of level <= v to be logged, + and all messages of level > v to be silently discarded. + + Args: + v: int|str, the verbosity level as an integer or string. Legal string values + are those that can be coerced to an integer as well as case-insensitive + 'debug', 'info', 'warning', 'error', and 'fatal'. + """ + try: + new_level = int(v) + except ValueError: + new_level = converter.ABSL_NAMES[v.upper()] + FLAGS.verbosity = new_level + + +def set_stderrthreshold(s): + """Sets the stderr threshold to the value passed in. + + Args: + s: str|int, valid strings values are case-insensitive 'debug', + 'info', 'warning', 'error', and 'fatal'; valid integer values are + logging.DEBUG|INFO|WARNING|ERROR|FATAL. + + Raises: + ValueError: Raised when s is an invalid value. + """ + if s in converter.ABSL_LEVELS: + FLAGS.stderrthreshold = converter.ABSL_LEVELS[s] + elif isinstance(s, str) and s.upper() in converter.ABSL_NAMES: + FLAGS.stderrthreshold = s + else: + raise ValueError( + 'set_stderrthreshold only accepts integer absl logging level ' + 'from -3 to 1, or case-insensitive string values ' + "'debug', 'info', 'warning', 'error', and 'fatal'. " + 'But found "{}" ({}).'.format(s, type(s))) + + +def fatal(msg, *args, **kwargs): + # type: (Any, Any, Any) -> NoReturn + """Logs a fatal message.""" + log(FATAL, msg, *args, **kwargs) + + +def error(msg, *args, **kwargs): + """Logs an error message.""" + log(ERROR, msg, *args, **kwargs) + + +def warning(msg, *args, **kwargs): + """Logs a warning message.""" + log(WARNING, msg, *args, **kwargs) + + +def warn(msg, *args, **kwargs): + """Deprecated, use 'warning' instead.""" + warnings.warn("The 'warn' function is deprecated, use 'warning' instead", + DeprecationWarning, 2) + log(WARNING, msg, *args, **kwargs) + + +def info(msg, *args, **kwargs): + """Logs an info message.""" + log(INFO, msg, *args, **kwargs) + + +def debug(msg, *args, **kwargs): + """Logs a debug message.""" + log(DEBUG, msg, *args, **kwargs) + + +def exception(msg, *args): + """Logs an exception, with traceback and message.""" + error(msg, *args, exc_info=True) + + +# Counter to keep track of number of log entries per token. +_log_counter_per_token = {} + + +def _get_next_log_count_per_token(token): + """Wrapper for _log_counter_per_token. Thread-safe. + + Args: + token: The token for which to look up the count. + + Returns: + The number of times this function has been called with + *token* as an argument (starting at 0). + """ + # Can't use a defaultdict because defaultdict isn't atomic, whereas + # setdefault is. + return next(_log_counter_per_token.setdefault(token, itertools.count())) + + +def log_every_n(level, msg, n, *args): + """Logs 'msg % args' at level 'level' once per 'n' times. + + Logs the 1st call, (N+1)st call, (2N+1)st call, etc. + Not threadsafe. + + Args: + level: int, the absl logging level at which to log. + msg: str, the message to be logged. + n: int, the number of times this should be called before it is logged. + *args: The args to be substituted into the msg. + """ + count = _get_next_log_count_per_token(get_absl_logger().findCaller()) + log_if(level, msg, not (count % n), *args) + + +# Keeps track of the last log time of the given token. +# Note: must be a dict since set/get is atomic in CPython. +# Note: entries are never released as their number is expected to be low. +_log_timer_per_token = {} + + +def _seconds_have_elapsed(token, num_seconds): + """Tests if 'num_seconds' have passed since 'token' was requested. + + Not strictly thread-safe - may log with the wrong frequency if called + concurrently from multiple threads. Accuracy depends on resolution of + 'timeit.default_timer()'. + + Always returns True on the first call for a given 'token'. + + Args: + token: The token for which to look up the count. + num_seconds: The number of seconds to test for. + + Returns: + Whether it has been >= 'num_seconds' since 'token' was last requested. + """ + now = timeit.default_timer() + then = _log_timer_per_token.get(token, None) + if then is None or (now - then) >= num_seconds: + _log_timer_per_token[token] = now + return True + else: + return False + + +def log_every_n_seconds(level, msg, n_seconds, *args): + """Logs 'msg % args' at level 'level' iff 'n_seconds' elapsed since last call. + + Logs the first call, logs subsequent calls if 'n' seconds have elapsed since + the last logging call from the same call site (file + line). Not thread-safe. + + Args: + level: int, the absl logging level at which to log. + msg: str, the message to be logged. + n_seconds: float or int, seconds which should elapse before logging again. + *args: The args to be substituted into the msg. + """ + should_log = _seconds_have_elapsed(get_absl_logger().findCaller(), n_seconds) + log_if(level, msg, should_log, *args) + + +def log_first_n(level, msg, n, *args): + """Logs 'msg % args' at level 'level' only first 'n' times. + + Not threadsafe. + + Args: + level: int, the absl logging level at which to log. + msg: str, the message to be logged. + n: int, the maximal number of times the message is logged. + *args: The args to be substituted into the msg. + """ + count = _get_next_log_count_per_token(get_absl_logger().findCaller()) + log_if(level, msg, count < n, *args) + + +def log_if(level, msg, condition, *args): + """Logs 'msg % args' at level 'level' only if condition is fulfilled.""" + if condition: + log(level, msg, *args) + + +def log(level, msg, *args, **kwargs): + """Logs 'msg % args' at absl logging level 'level'. + + If no args are given just print msg, ignoring any interpolation specifiers. + + Args: + level: int, the absl logging level at which to log the message + (logging.DEBUG|INFO|WARNING|ERROR|FATAL). While some C++ verbose logging + level constants are also supported, callers should prefer explicit + logging.vlog() calls for such purpose. + + msg: str, the message to be logged. + *args: The args to be substituted into the msg. + **kwargs: May contain exc_info to add exception traceback to message. + """ + if level > converter.ABSL_DEBUG: + # Even though this function supports level that is greater than 1, users + # should use logging.vlog instead for such cases. + # Treat this as vlog, 1 is equivalent to DEBUG. + standard_level = converter.STANDARD_DEBUG - (level - 1) + else: + if level < converter.ABSL_FATAL: + level = converter.ABSL_FATAL + standard_level = converter.absl_to_standard(level) + + # Match standard logging's behavior. Before use_absl_handler() and + # logging is configured, there is no handler attached on _absl_logger nor + # logging.root. So logs go no where. + if not logging.root.handlers: + logging.basicConfig() + + _absl_logger.log(standard_level, msg, *args, **kwargs) + + +def vlog(level, msg, *args, **kwargs): + """Log 'msg % args' at C++ vlog level 'level'. + + Args: + level: int, the C++ verbose logging level at which to log the message, + e.g. 1, 2, 3, 4... While absl level constants are also supported, + callers should prefer logging.log|debug|info|... calls for such purpose. + msg: str, the message to be logged. + *args: The args to be substituted into the msg. + **kwargs: May contain exc_info to add exception traceback to message. + """ + log(level, msg, *args, **kwargs) + + +def vlog_is_on(level): + """Checks if vlog is enabled for the given level in caller's source file. + + Args: + level: int, the C++ verbose logging level at which to log the message, + e.g. 1, 2, 3, 4... While absl level constants are also supported, + callers should prefer level_debug|level_info|... calls for + checking those. + + Returns: + True if logging is turned on for that level. + """ + + if level > converter.ABSL_DEBUG: + # Even though this function supports level that is greater than 1, users + # should use logging.vlog instead for such cases. + # Treat this as vlog, 1 is equivalent to DEBUG. + standard_level = converter.STANDARD_DEBUG - (level - 1) + else: + if level < converter.ABSL_FATAL: + level = converter.ABSL_FATAL + standard_level = converter.absl_to_standard(level) + return _absl_logger.isEnabledFor(standard_level) + + +def flush(): + """Flushes all log files.""" + get_absl_handler().flush() + + +def level_debug(): + """Returns True if debug logging is turned on.""" + return get_verbosity() >= DEBUG + + +def level_info(): + """Returns True if info logging is turned on.""" + return get_verbosity() >= INFO + + +def level_warning(): + """Returns True if warning logging is turned on.""" + return get_verbosity() >= WARNING + + +level_warn = level_warning # Deprecated function. + + +def level_error(): + """Returns True if error logging is turned on.""" + return get_verbosity() >= ERROR + + +def get_log_file_name(level=INFO): + """Returns the name of the log file. + + For Python logging, only one file is used and level is ignored. And it returns + empty string if it logs to stderr/stdout or the log stream has no `name` + attribute. + + Args: + level: int, the absl.logging level. + + Raises: + ValueError: Raised when `level` has an invalid value. + """ + if level not in converter.ABSL_LEVELS: + raise ValueError('Invalid absl.logging level {}'.format(level)) + stream = get_absl_handler().python_handler.stream + if (stream == sys.stderr or stream == sys.stdout or + not hasattr(stream, 'name')): + return '' + else: + return stream.name + + +def find_log_dir_and_names(program_name=None, log_dir=None): + """Computes the directory and filename prefix for log file. + + Args: + program_name: str|None, the filename part of the path to the program that + is running without its extension. e.g: if your program is called + 'usr/bin/foobar.py' this method should probably be called with + program_name='foobar' However, this is just a convention, you can + pass in any string you want, and it will be used as part of the + log filename. If you don't pass in anything, the default behavior + is as described in the example. In python standard logging mode, + the program_name will be prepended with py_ if it is the program_name + argument is omitted. + log_dir: str|None, the desired log directory. + + Returns: + (log_dir, file_prefix, symlink_prefix) + + Raises: + FileNotFoundError: raised in Python 3 when it cannot find a log directory. + OSError: raised in Python 2 when it cannot find a log directory. + """ + if not program_name: + # Strip the extension (foobar.par becomes foobar, and + # fubar.py becomes fubar). We do this so that the log + # file names are similar to C++ log file names. + program_name = os.path.splitext(os.path.basename(sys.argv[0]))[0] + + # Prepend py_ to files so that python code gets a unique file, and + # so that C++ libraries do not try to write to the same log files as us. + program_name = 'py_%s' % program_name + + actual_log_dir = find_log_dir(log_dir=log_dir) + + try: + username = getpass.getuser() + except KeyError: + # This can happen, e.g. when running under docker w/o passwd file. + if hasattr(os, 'getuid'): + # Windows doesn't have os.getuid + username = str(os.getuid()) + else: + username = 'unknown' + hostname = socket.gethostname() + file_prefix = '%s.%s.%s.log' % (program_name, hostname, username) + + return actual_log_dir, file_prefix, program_name + + +def find_log_dir(log_dir=None): + """Returns the most suitable directory to put log files into. + + Args: + log_dir: str|None, if specified, the logfile(s) will be created in that + directory. Otherwise if the --log_dir command-line flag is provided, + the logfile will be created in that directory. Otherwise the logfile + will be created in a standard location. + + Raises: + FileNotFoundError: raised in Python 3 when it cannot find a log directory. + OSError: raised in Python 2 when it cannot find a log directory. + """ + # Get a list of possible log dirs (will try to use them in order). + if log_dir: + # log_dir was explicitly specified as an arg, so use it and it alone. + dirs = [log_dir] + elif FLAGS['log_dir'].value: + # log_dir flag was provided, so use it and it alone (this mimics the + # behavior of the same flag in logging.cc). + dirs = [FLAGS['log_dir'].value] + else: + dirs = ['/tmp/', './'] + + # Find the first usable log dir. + for d in dirs: + if os.path.isdir(d) and os.access(d, os.W_OK): + return d + raise FileNotFoundError( + "Can't find a writable directory for logs, tried %s" % dirs) + + +def get_absl_log_prefix(record): + """Returns the absl log prefix for the log record. + + Args: + record: logging.LogRecord, the record to get prefix for. + """ + created_tuple = time.localtime(record.created) + created_microsecond = int(record.created % 1.0 * 1e6) + + critical_prefix = '' + level = record.levelno + if _is_non_absl_fatal_record(record): + # When the level is FATAL, but not logged from absl, lower the level so + # it's treated as ERROR. + level = logging.ERROR + critical_prefix = _CRITICAL_PREFIX + severity = converter.get_initial_for_level(level) + + return '%c%02d%02d %02d:%02d:%02d.%06d %5d %s:%d] %s' % ( + severity, + created_tuple.tm_mon, + created_tuple.tm_mday, + created_tuple.tm_hour, + created_tuple.tm_min, + created_tuple.tm_sec, + created_microsecond, + _get_thread_id(), + record.filename, + record.lineno, + critical_prefix) + + +def skip_log_prefix(func): + """Skips reporting the prefix of a given function or name by ABSLLogger. + + This is a convenience wrapper function / decorator for + `ABSLLogger.register_frame_to_skip`. + + If a callable function is provided, only that function will be skipped. + If a function name is provided, all functions with the same name in the + file that this is called in will be skipped. + + This can be used as a decorator of the intended function to be skipped. + + Args: + func: Callable function or its name as a string. + + Returns: + func (the input, unchanged). + + Raises: + ValueError: The input is callable but does not have a function code object. + TypeError: The input is neither callable nor a string. + """ + if callable(func): + func_code = getattr(func, '__code__', None) + if func_code is None: + raise ValueError('Input callable does not have a function code object.') + file_name = func_code.co_filename + func_name = func_code.co_name + func_lineno = func_code.co_firstlineno + elif isinstance(func, str): + file_name = get_absl_logger().findCaller()[0] + func_name = func + func_lineno = None + else: + raise TypeError('Input is neither callable nor a string.') + ABSLLogger.register_frame_to_skip(file_name, func_name, func_lineno) + return func + + +def _is_non_absl_fatal_record(log_record): + return (log_record.levelno >= logging.FATAL and + not log_record.__dict__.get(_ABSL_LOG_FATAL, False)) + + +def _is_absl_fatal_record(log_record): + return (log_record.levelno >= logging.FATAL and + log_record.__dict__.get(_ABSL_LOG_FATAL, False)) + + +# Indicates if we still need to warn about pre-init logs going to stderr. +_warn_preinit_stderr = True + + +class PythonHandler(logging.StreamHandler): + """The handler class used by Abseil Python logging implementation.""" + + def __init__(self, stream=None, formatter=None): + super(PythonHandler, self).__init__(stream) + self.setFormatter(formatter or PythonFormatter()) + + def start_logging_to_file(self, program_name=None, log_dir=None): + """Starts logging messages to files instead of standard error.""" + FLAGS.logtostderr = False + + actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names( + program_name=program_name, log_dir=log_dir) + + basename = '%s.INFO.%s.%d' % ( + file_prefix, + time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time())), + os.getpid()) + filename = os.path.join(actual_log_dir, basename) + + self.stream = open(filename, 'a', encoding='utf-8') + + # os.symlink is not available on Windows Python 2. + if getattr(os, 'symlink', None): + # Create a symlink to the log file with a canonical name. + symlink = os.path.join(actual_log_dir, symlink_prefix + '.INFO') + try: + if os.path.islink(symlink): + os.unlink(symlink) + os.symlink(os.path.basename(filename), symlink) + except EnvironmentError: + # If it fails, we're sad but it's no error. Commonly, this + # fails because the symlink was created by another user and so + # we can't modify it + pass + + def use_absl_log_file(self, program_name=None, log_dir=None): + """Conditionally logs to files, based on --logtostderr.""" + if FLAGS['logtostderr'].value: + self.stream = sys.stderr + else: + self.start_logging_to_file(program_name=program_name, log_dir=log_dir) + + def flush(self): + """Flushes all log files.""" + self.acquire() + try: + self.stream.flush() + except (EnvironmentError, ValueError): + # A ValueError is thrown if we try to flush a closed file. + pass + finally: + self.release() + + def _log_to_stderr(self, record): + """Emits the record to stderr. + + This temporarily sets the handler stream to stderr, calls + StreamHandler.emit, then reverts the stream back. + + Args: + record: logging.LogRecord, the record to log. + """ + # emit() is protected by a lock in logging.Handler, so we don't need to + # protect here again. + old_stream = self.stream + self.stream = sys.stderr + try: + super(PythonHandler, self).emit(record) + finally: + self.stream = old_stream + + def emit(self, record): + """Prints a record out to some streams. + + If FLAGS.logtostderr is set, it will print to sys.stderr ONLY. + If FLAGS.alsologtostderr is set, it will print to sys.stderr. + If FLAGS.logtostderr is not set, it will log to the stream + associated with the current thread. + + Args: + record: logging.LogRecord, the record to emit. + """ + # People occasionally call logging functions at import time before + # our flags may have even been defined yet, let alone even parsed, as we + # rely on the C++ side to define some flags for us and app init to + # deal with parsing. Match the C++ library behavior of notify and emit + # such messages to stderr. It encourages people to clean-up and does + # not hide the message. + level = record.levelno + if not FLAGS.is_parsed(): # Also implies "before flag has been defined". + global _warn_preinit_stderr + if _warn_preinit_stderr: + sys.stderr.write( + 'WARNING: Logging before flag parsing goes to stderr.\n') + _warn_preinit_stderr = False + self._log_to_stderr(record) + elif FLAGS['logtostderr'].value: + self._log_to_stderr(record) + else: + super(PythonHandler, self).emit(record) + stderr_threshold = converter.string_to_standard( + FLAGS['stderrthreshold'].value) + if ((FLAGS['alsologtostderr'].value or level >= stderr_threshold) and + self.stream != sys.stderr): + self._log_to_stderr(record) + # Die when the record is created from ABSLLogger and level is FATAL. + if _is_absl_fatal_record(record): + self.flush() # Flush the log before dying. + + # In threaded python, sys.exit() from a non-main thread only + # exits the thread in question. + os.abort() + + def close(self): + """Closes the stream to which we are writing.""" + self.acquire() + try: + self.flush() + try: + # Do not close the stream if it's sys.stderr|stdout. They may be + # redirected or overridden to files, which should be managed by users + # explicitly. + user_managed = sys.stderr, sys.stdout, sys.__stderr__, sys.__stdout__ + if self.stream not in user_managed and ( + not hasattr(self.stream, 'isatty') or not self.stream.isatty()): + self.stream.close() + except ValueError: + # A ValueError is thrown if we try to run isatty() on a closed file. + pass + super(PythonHandler, self).close() + finally: + self.release() + + +class ABSLHandler(logging.Handler): + """Abseil Python logging module's log handler.""" + + def __init__(self, python_logging_formatter): + super(ABSLHandler, self).__init__() + + self._python_handler = PythonHandler(formatter=python_logging_formatter) + self.activate_python_handler() + + def format(self, record): + return self._current_handler.format(record) + + def setFormatter(self, fmt): + self._current_handler.setFormatter(fmt) + + def emit(self, record): + self._current_handler.emit(record) + + def flush(self): + self._current_handler.flush() + + def close(self): + super(ABSLHandler, self).close() + self._current_handler.close() + + def handle(self, record): + rv = self.filter(record) + if rv: + return self._current_handler.handle(record) + return rv + + @property + def python_handler(self): + return self._python_handler + + def activate_python_handler(self): + """Uses the Python logging handler as the current logging handler.""" + self._current_handler = self._python_handler + + def use_absl_log_file(self, program_name=None, log_dir=None): + self._current_handler.use_absl_log_file(program_name, log_dir) + + def start_logging_to_file(self, program_name=None, log_dir=None): + self._current_handler.start_logging_to_file(program_name, log_dir) + + +class PythonFormatter(logging.Formatter): + """Formatter class used by PythonHandler.""" + + def format(self, record): + """Appends the message from the record to the results of the prefix. + + Args: + record: logging.LogRecord, the record to be formatted. + + Returns: + The formatted string representing the record. + """ + if (not FLAGS['showprefixforinfo'].value and + FLAGS['verbosity'].value == converter.ABSL_INFO and + record.levelno == logging.INFO and + _absl_handler.python_handler.stream == sys.stderr): + prefix = '' + else: + prefix = get_absl_log_prefix(record) + return prefix + super(PythonFormatter, self).format(record) + + +class ABSLLogger(logging.getLoggerClass()): + """A logger that will create LogRecords while skipping some stack frames. + + This class maintains an internal list of filenames and method names + for use when determining who called the currently executing stack + frame. Any method names from specific source files are skipped when + walking backwards through the stack. + + Client code should use the register_frame_to_skip method to let the + ABSLLogger know which method from which file should be + excluded from the walk backwards through the stack. + """ + _frames_to_skip = set() + + def findCaller(self, stack_info=False, stacklevel=1): + """Finds the frame of the calling method on the stack. + + This method skips any frames registered with the + ABSLLogger and any methods from this file, and whatever + method is currently being used to generate the prefix for the log + line. Then it returns the file name, line number, and method name + of the calling method. An optional fourth item may be returned, + callers who only need things from the first three are advised to + always slice or index the result rather than using direct unpacking + assignment. + + Args: + stack_info: bool, when True, include the stack trace as a fourth item + returned. On Python 3 there are always four items returned - the + fourth will be None when this is False. On Python 2 the stdlib + base class API only returns three items. We do the same when this + new parameter is unspecified or False for compatibility. + + Returns: + (filename, lineno, methodname[, sinfo]) of the calling method. + """ + f_to_skip = ABSLLogger._frames_to_skip + # Use sys._getframe(2) instead of logging.currentframe(), it's slightly + # faster because there is one less frame to traverse. + frame = sys._getframe(2) # pylint: disable=protected-access + + while frame: + code = frame.f_code + if (_LOGGING_FILE_PREFIX not in code.co_filename and + (code.co_filename, code.co_name, + code.co_firstlineno) not in f_to_skip and + (code.co_filename, code.co_name) not in f_to_skip): + sinfo = None + if stack_info: + out = io.StringIO() + out.write(u'Stack (most recent call last):\n') + traceback.print_stack(frame, file=out) + sinfo = out.getvalue().rstrip(u'\n') + return (code.co_filename, frame.f_lineno, code.co_name, sinfo) + frame = frame.f_back + + def critical(self, msg, *args, **kwargs): + """Logs 'msg % args' with severity 'CRITICAL'.""" + self.log(logging.CRITICAL, msg, *args, **kwargs) + + def fatal(self, msg, *args, **kwargs): + """Logs 'msg % args' with severity 'FATAL'.""" + self.log(logging.FATAL, msg, *args, **kwargs) + + def error(self, msg, *args, **kwargs): + """Logs 'msg % args' with severity 'ERROR'.""" + self.log(logging.ERROR, msg, *args, **kwargs) + + def warn(self, msg, *args, **kwargs): + """Logs 'msg % args' with severity 'WARN'.""" + warnings.warn("The 'warn' method is deprecated, use 'warning' instead", + DeprecationWarning, 2) + self.log(logging.WARN, msg, *args, **kwargs) + + def warning(self, msg, *args, **kwargs): + """Logs 'msg % args' with severity 'WARNING'.""" + self.log(logging.WARNING, msg, *args, **kwargs) + + def info(self, msg, *args, **kwargs): + """Logs 'msg % args' with severity 'INFO'.""" + self.log(logging.INFO, msg, *args, **kwargs) + + def debug(self, msg, *args, **kwargs): + """Logs 'msg % args' with severity 'DEBUG'.""" + self.log(logging.DEBUG, msg, *args, **kwargs) + + def log(self, level, msg, *args, **kwargs): + """Logs a message at a cetain level substituting in the supplied arguments. + + This method behaves differently in python and c++ modes. + + Args: + level: int, the standard logging level at which to log the message. + msg: str, the text of the message to log. + *args: The arguments to substitute in the message. + **kwargs: The keyword arguments to substitute in the message. + """ + if level >= logging.FATAL: + # Add property to the LogRecord created by this logger. + # This will be used by the ABSLHandler to determine whether it should + # treat CRITICAL/FATAL logs as really FATAL. + extra = kwargs.setdefault('extra', {}) + extra[_ABSL_LOG_FATAL] = True + super(ABSLLogger, self).log(level, msg, *args, **kwargs) + + def handle(self, record): + """Calls handlers without checking Logger.disabled. + + Non-root loggers are set to disabled after setup with logging.config if + it's not explicitly specified. Historically, absl logging will not be + disabled by that. To maintaining this behavior, this function skips + checking the Logger.disabled bit. + + This logger can still be disabled by adding a filter that filters out + everything. + + Args: + record: logging.LogRecord, the record to handle. + """ + if self.filter(record): + self.callHandlers(record) + + @classmethod + def register_frame_to_skip(cls, file_name, function_name, line_number=None): + """Registers a function name to skip when walking the stack. + + The ABSLLogger sometimes skips method calls on the stack + to make the log messages meaningful in their appropriate context. + This method registers a function from a particular file as one + which should be skipped. + + Args: + file_name: str, the name of the file that contains the function. + function_name: str, the name of the function to skip. + line_number: int, if provided, only the function with this starting line + number will be skipped. Otherwise, all functions with the same name + in the file will be skipped. + """ + if line_number is not None: + cls._frames_to_skip.add((file_name, function_name, line_number)) + else: + cls._frames_to_skip.add((file_name, function_name)) + + +def _get_thread_id(): + """Gets id of current thread, suitable for logging as an unsigned quantity. + + If pywrapbase is linked, returns GetTID() for the thread ID to be + consistent with C++ logging. Otherwise, returns the numeric thread id. + The quantities are made unsigned by masking with 2*sys.maxint + 1. + + Returns: + Thread ID unique to this process (unsigned) + """ + thread_id = threading.get_ident() + return thread_id & _THREAD_ID_MASK + + +def get_absl_logger(): + """Returns the absl logger instance.""" + return _absl_logger + + +def get_absl_handler(): + """Returns the absl handler instance.""" + return _absl_handler + + +def use_python_logging(quiet=False): + """Uses the python implementation of the logging code. + + Args: + quiet: No logging message about switching logging type. + """ + get_absl_handler().activate_python_handler() + if not quiet: + info('Restoring pure python logging') + + +_attempted_to_remove_stderr_stream_handlers = False + + +def use_absl_handler(): + """Uses the ABSL logging handler for logging. + + This method is called in app.run() so the absl handler is used in absl apps. + """ + global _attempted_to_remove_stderr_stream_handlers + if not _attempted_to_remove_stderr_stream_handlers: + # The absl handler logs to stderr by default. To prevent double logging to + # stderr, the following code tries its best to remove other handlers that + # emit to stderr. Those handlers are most commonly added when + # logging.info/debug is called before calling use_absl_handler(). + handlers = [ + h for h in logging.root.handlers + if isinstance(h, logging.StreamHandler) and h.stream == sys.stderr] + for h in handlers: + logging.root.removeHandler(h) + _attempted_to_remove_stderr_stream_handlers = True + + absl_handler = get_absl_handler() + if absl_handler not in logging.root.handlers: + logging.root.addHandler(absl_handler) + FLAGS['verbosity']._update_logging_levels() # pylint: disable=protected-access + FLAGS['logger_levels']._update_logger_levels() # pylint: disable=protected-access + + +def _initialize(): + """Initializes loggers and handlers.""" + global _absl_logger, _absl_handler + + if _absl_logger: + return + + original_logger_class = logging.getLoggerClass() + logging.setLoggerClass(ABSLLogger) + _absl_logger = logging.getLogger('absl') + logging.setLoggerClass(original_logger_class) + + python_logging_formatter = PythonFormatter() + _absl_handler = ABSLHandler(python_logging_formatter) + + +_initialize() diff --git a/absl/logging/converter.py b/absl/logging/converter.py new file mode 100644 index 0000000..53dd46d --- /dev/null +++ b/absl/logging/converter.py @@ -0,0 +1,211 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module to convert log levels between Abseil Python, C++, and Python standard. + +This converter has to convert (best effort) between three different +logging level schemes: + cpp = The C++ logging level scheme used in Abseil C++. + absl = The absl.logging level scheme used in Abseil Python. + standard = The python standard library logging level scheme. + +Here is a handy ascii chart for easy mental mapping. + + LEVEL | cpp | absl | standard | + ---------+-----+--------+----------+ + DEBUG | 0 | 1 | 10 | + INFO | 0 | 0 | 20 | + WARNING | 1 | -1 | 30 | + ERROR | 2 | -2 | 40 | + CRITICAL | 3 | -3 | 50 | + FATAL | 3 | -3 | 50 | + +Note: standard logging CRITICAL is mapped to absl/cpp FATAL. +However, only CRITICAL logs from the absl logger (or absl.logging.fatal) will +terminate the program. CRITICAL logs from non-absl loggers are treated as +error logs with a message prefix "CRITICAL - ". + +Converting from standard to absl or cpp is a lossy conversion. +Converting back to standard will lose granularity. For this reason, +users should always try to convert to standard, the richest +representation, before manipulating the levels, and then only to cpp +or absl if those level schemes are absolutely necessary. +""" + +import logging + +STANDARD_CRITICAL = logging.CRITICAL +STANDARD_ERROR = logging.ERROR +STANDARD_WARNING = logging.WARNING +STANDARD_INFO = logging.INFO +STANDARD_DEBUG = logging.DEBUG + +# These levels are also used to define the constants +# FATAL, ERROR, WARNING, INFO, and DEBUG in the +# absl.logging module. +ABSL_FATAL = -3 +ABSL_ERROR = -2 +ABSL_WARNING = -1 +ABSL_WARN = -1 # Deprecated name. +ABSL_INFO = 0 +ABSL_DEBUG = 1 + +ABSL_LEVELS = {ABSL_FATAL: 'FATAL', + ABSL_ERROR: 'ERROR', + ABSL_WARNING: 'WARNING', + ABSL_INFO: 'INFO', + ABSL_DEBUG: 'DEBUG'} + +# Inverts the ABSL_LEVELS dictionary +ABSL_NAMES = {'FATAL': ABSL_FATAL, + 'ERROR': ABSL_ERROR, + 'WARNING': ABSL_WARNING, + 'WARN': ABSL_WARNING, # Deprecated name. + 'INFO': ABSL_INFO, + 'DEBUG': ABSL_DEBUG} + +ABSL_TO_STANDARD = {ABSL_FATAL: STANDARD_CRITICAL, + ABSL_ERROR: STANDARD_ERROR, + ABSL_WARNING: STANDARD_WARNING, + ABSL_INFO: STANDARD_INFO, + ABSL_DEBUG: STANDARD_DEBUG} + +# Inverts the ABSL_TO_STANDARD +STANDARD_TO_ABSL = dict((v, k) for (k, v) in ABSL_TO_STANDARD.items()) + + +def get_initial_for_level(level): + """Gets the initial that should start the log line for the given level. + + It returns: + - 'I' when: level < STANDARD_WARNING. + - 'W' when: STANDARD_WARNING <= level < STANDARD_ERROR. + - 'E' when: STANDARD_ERROR <= level < STANDARD_CRITICAL. + - 'F' when: level >= STANDARD_CRITICAL. + + Args: + level: int, a Python standard logging level. + + Returns: + The first initial as it would be logged by the C++ logging module. + """ + if level < STANDARD_WARNING: + return 'I' + elif level < STANDARD_ERROR: + return 'W' + elif level < STANDARD_CRITICAL: + return 'E' + else: + return 'F' + + +def absl_to_cpp(level): + """Converts an absl log level to a cpp log level. + + Args: + level: int, an absl.logging level. + + Raises: + TypeError: Raised when level is not an integer. + + Returns: + The corresponding integer level for use in Abseil C++. + """ + if not isinstance(level, int): + raise TypeError('Expect an int level, found {}'.format(type(level))) + if level >= 0: + # C++ log levels must be >= 0 + return 0 + else: + return -level + + +def absl_to_standard(level): + """Converts an integer level from the absl value to the standard value. + + Args: + level: int, an absl.logging level. + + Raises: + TypeError: Raised when level is not an integer. + + Returns: + The corresponding integer level for use in standard logging. + """ + if not isinstance(level, int): + raise TypeError('Expect an int level, found {}'.format(type(level))) + if level < ABSL_FATAL: + level = ABSL_FATAL + if level <= ABSL_DEBUG: + return ABSL_TO_STANDARD[level] + # Maps to vlog levels. + return STANDARD_DEBUG - level + 1 + + +def string_to_standard(level): + """Converts a string level to standard logging level value. + + Args: + level: str, case-insensitive 'debug', 'info', 'warning', 'error', 'fatal'. + + Returns: + The corresponding integer level for use in standard logging. + """ + return absl_to_standard(ABSL_NAMES.get(level.upper())) + + +def standard_to_absl(level): + """Converts an integer level from the standard value to the absl value. + + Args: + level: int, a Python standard logging level. + + Raises: + TypeError: Raised when level is not an integer. + + Returns: + The corresponding integer level for use in absl logging. + """ + if not isinstance(level, int): + raise TypeError('Expect an int level, found {}'.format(type(level))) + if level < 0: + level = 0 + if level < STANDARD_DEBUG: + # Maps to vlog levels. + return STANDARD_DEBUG - level + 1 + elif level < STANDARD_INFO: + return ABSL_DEBUG + elif level < STANDARD_WARNING: + return ABSL_INFO + elif level < STANDARD_ERROR: + return ABSL_WARNING + elif level < STANDARD_CRITICAL: + return ABSL_ERROR + else: + return ABSL_FATAL + + +def standard_to_cpp(level): + """Converts an integer level from the standard value to the cpp value. + + Args: + level: int, a Python standard logging level. + + Raises: + TypeError: Raised when level is not an integer. + + Returns: + The corresponding integer level for use in cpp logging. + """ + return absl_to_cpp(standard_to_absl(level)) diff --git a/absl/logging/tests/__init__.py b/absl/logging/tests/__init__.py new file mode 100644 index 0000000..a3bd1cd --- /dev/null +++ b/absl/logging/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/absl/logging/tests/converter_test.py b/absl/logging/tests/converter_test.py new file mode 100644 index 0000000..bdc893a --- /dev/null +++ b/absl/logging/tests/converter_test.py @@ -0,0 +1,135 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for converter.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +from absl import logging as absl_logging +from absl.logging import converter +from absl.testing import absltest + + +class ConverterTest(absltest.TestCase): + """Tests the converter module.""" + + def test_absl_to_cpp(self): + self.assertEqual(0, converter.absl_to_cpp(absl_logging.DEBUG)) + self.assertEqual(0, converter.absl_to_cpp(absl_logging.INFO)) + self.assertEqual(1, converter.absl_to_cpp(absl_logging.WARN)) + self.assertEqual(2, converter.absl_to_cpp(absl_logging.ERROR)) + self.assertEqual(3, converter.absl_to_cpp(absl_logging.FATAL)) + + with self.assertRaises(TypeError): + converter.absl_to_cpp('') + + def test_absl_to_standard(self): + self.assertEqual( + logging.DEBUG, converter.absl_to_standard(absl_logging.DEBUG)) + self.assertEqual( + logging.INFO, converter.absl_to_standard(absl_logging.INFO)) + self.assertEqual( + logging.WARNING, converter.absl_to_standard(absl_logging.WARN)) + self.assertEqual( + logging.WARN, converter.absl_to_standard(absl_logging.WARN)) + self.assertEqual( + logging.ERROR, converter.absl_to_standard(absl_logging.ERROR)) + self.assertEqual( + logging.FATAL, converter.absl_to_standard(absl_logging.FATAL)) + self.assertEqual( + logging.CRITICAL, converter.absl_to_standard(absl_logging.FATAL)) + # vlog levels. + self.assertEqual(9, converter.absl_to_standard(2)) + self.assertEqual(8, converter.absl_to_standard(3)) + + with self.assertRaises(TypeError): + converter.absl_to_standard('') + + def test_standard_to_absl(self): + self.assertEqual( + absl_logging.DEBUG, converter.standard_to_absl(logging.DEBUG)) + self.assertEqual( + absl_logging.INFO, converter.standard_to_absl(logging.INFO)) + self.assertEqual( + absl_logging.WARN, converter.standard_to_absl(logging.WARN)) + self.assertEqual( + absl_logging.WARN, converter.standard_to_absl(logging.WARNING)) + self.assertEqual( + absl_logging.ERROR, converter.standard_to_absl(logging.ERROR)) + self.assertEqual( + absl_logging.FATAL, converter.standard_to_absl(logging.FATAL)) + self.assertEqual( + absl_logging.FATAL, converter.standard_to_absl(logging.CRITICAL)) + # vlog levels. + self.assertEqual(2, converter.standard_to_absl(logging.DEBUG - 1)) + self.assertEqual(3, converter.standard_to_absl(logging.DEBUG - 2)) + + with self.assertRaises(TypeError): + converter.standard_to_absl('') + + def test_standard_to_cpp(self): + self.assertEqual(0, converter.standard_to_cpp(logging.DEBUG)) + self.assertEqual(0, converter.standard_to_cpp(logging.INFO)) + self.assertEqual(1, converter.standard_to_cpp(logging.WARN)) + self.assertEqual(1, converter.standard_to_cpp(logging.WARNING)) + self.assertEqual(2, converter.standard_to_cpp(logging.ERROR)) + self.assertEqual(3, converter.standard_to_cpp(logging.FATAL)) + self.assertEqual(3, converter.standard_to_cpp(logging.CRITICAL)) + + with self.assertRaises(TypeError): + converter.standard_to_cpp('') + + def test_get_initial_for_level(self): + self.assertEqual('F', converter.get_initial_for_level(logging.CRITICAL)) + self.assertEqual('E', converter.get_initial_for_level(logging.ERROR)) + self.assertEqual('W', converter.get_initial_for_level(logging.WARNING)) + self.assertEqual('I', converter.get_initial_for_level(logging.INFO)) + self.assertEqual('I', converter.get_initial_for_level(logging.DEBUG)) + self.assertEqual('I', converter.get_initial_for_level(logging.NOTSET)) + + self.assertEqual('F', converter.get_initial_for_level(51)) + self.assertEqual('E', converter.get_initial_for_level(49)) + self.assertEqual('E', converter.get_initial_for_level(41)) + self.assertEqual('W', converter.get_initial_for_level(39)) + self.assertEqual('W', converter.get_initial_for_level(31)) + self.assertEqual('I', converter.get_initial_for_level(29)) + self.assertEqual('I', converter.get_initial_for_level(21)) + self.assertEqual('I', converter.get_initial_for_level(19)) + self.assertEqual('I', converter.get_initial_for_level(11)) + self.assertEqual('I', converter.get_initial_for_level(9)) + self.assertEqual('I', converter.get_initial_for_level(1)) + self.assertEqual('I', converter.get_initial_for_level(-1)) + + def test_string_to_standard(self): + self.assertEqual(logging.DEBUG, converter.string_to_standard('debug')) + self.assertEqual(logging.INFO, converter.string_to_standard('info')) + self.assertEqual(logging.WARNING, converter.string_to_standard('warn')) + self.assertEqual(logging.WARNING, converter.string_to_standard('warning')) + self.assertEqual(logging.ERROR, converter.string_to_standard('error')) + self.assertEqual(logging.CRITICAL, converter.string_to_standard('fatal')) + + self.assertEqual(logging.DEBUG, converter.string_to_standard('DEBUG')) + self.assertEqual(logging.INFO, converter.string_to_standard('INFO')) + self.assertEqual(logging.WARNING, converter.string_to_standard('WARN')) + self.assertEqual(logging.WARNING, converter.string_to_standard('WARNING')) + self.assertEqual(logging.ERROR, converter.string_to_standard('ERROR')) + self.assertEqual(logging.CRITICAL, converter.string_to_standard('FATAL')) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/logging/tests/log_before_import_test.py b/absl/logging/tests/log_before_import_test.py new file mode 100644 index 0000000..903af16 --- /dev/null +++ b/absl/logging/tests/log_before_import_test.py @@ -0,0 +1,127 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test of logging behavior before app.run(), aka flag and logging init().""" + +import contextlib +import io +import os +import re +import sys +import tempfile +from unittest import mock + +from absl import logging +from absl.testing import absltest + +logging.get_verbosity() # Access --verbosity before flag parsing. +# Access --logtostderr before flag parsing. +logging.get_absl_handler().use_absl_log_file() + + +class Error(Exception): + pass + + +@contextlib.contextmanager +def captured_stderr_filename(): + """Captures stderr and writes them to a temporary file. + + This uses os.dup/os.dup2 to redirect the stderr fd for capturing standard + error of logging at import-time. We cannot mock sys.stderr because on the + first log call, a default log handler writing to the mock sys.stderr is + registered, and it will never be removed and subsequent logs go to the mock + in addition to the real stder. + + Yields: + The filename of captured stderr. + """ + stderr_capture_file_fd, stderr_capture_file_name = tempfile.mkstemp() + original_stderr_fd = os.dup(sys.stderr.fileno()) + os.dup2(stderr_capture_file_fd, sys.stderr.fileno()) + try: + yield stderr_capture_file_name + finally: + os.close(stderr_capture_file_fd) + os.dup2(original_stderr_fd, sys.stderr.fileno()) + + +# Pre-initialization (aka "import" / __main__ time) test. +with captured_stderr_filename() as before_set_verbosity_filename: + # Warnings and above go to stderr. + logging.debug('Debug message at parse time.') + logging.info('Info message at parse time.') + logging.error('Error message at parse time.') + logging.warning('Warning message at parse time.') + try: + raise Error('Exception reason.') + except Error: + logging.exception('Exception message at parse time.') + + +logging.set_verbosity(logging.ERROR) +with captured_stderr_filename() as after_set_verbosity_filename: + # Verbosity is set to ERROR, errors and above go to stderr. + logging.debug('Debug message at parse time.') + logging.info('Info message at parse time.') + logging.warning('Warning message at parse time.') + logging.error('Error message at parse time.') + + +class LoggingInitWarningTest(absltest.TestCase): + + def test_captured_pre_init_warnings(self): + with open(before_set_verbosity_filename) as stderr_capture_file: + captured_stderr = stderr_capture_file.read() + self.assertNotIn('Debug message at parse time.', captured_stderr) + self.assertNotIn('Info message at parse time.', captured_stderr) + + traceback_re = re.compile( + r'\nTraceback \(most recent call last\):.*?Error: Exception reason.', + re.MULTILINE | re.DOTALL) + if not traceback_re.search(captured_stderr): + self.fail( + 'Cannot find traceback message from logging.exception ' + 'in stderr:\n{}'.format(captured_stderr)) + # Remove the traceback so the rest of the stderr is deterministic. + captured_stderr = traceback_re.sub('', captured_stderr) + captured_stderr_lines = captured_stderr.splitlines() + self.assertLen(captured_stderr_lines, 3) + self.assertIn('Error message at parse time.', captured_stderr_lines[0]) + self.assertIn('Warning message at parse time.', captured_stderr_lines[1]) + self.assertIn('Exception message at parse time.', captured_stderr_lines[2]) + + def test_set_verbosity_pre_init(self): + with open(after_set_verbosity_filename) as stderr_capture_file: + captured_stderr = stderr_capture_file.read() + captured_stderr_lines = captured_stderr.splitlines() + + self.assertNotIn('Debug message at parse time.', captured_stderr) + self.assertNotIn('Info message at parse time.', captured_stderr) + self.assertNotIn('Warning message at parse time.', captured_stderr) + self.assertLen(captured_stderr_lines, 1) + self.assertIn('Error message at parse time.', captured_stderr_lines[0]) + + def test_no_more_warnings(self): + fake_stderr_type = io.BytesIO if bytes is str else io.StringIO + with mock.patch('sys.stderr', new=fake_stderr_type()) as mock_stderr: + self.assertMultiLineEqual('', mock_stderr.getvalue()) + logging.warning('Hello. hello. hello. Is there anybody out there?') + self.assertNotIn('Logging before flag parsing goes to stderr', + mock_stderr.getvalue()) + logging.info('A major purpose of this executable is merely not to crash.') + + +if __name__ == '__main__': + absltest.main() # This calls the app.run() init equivalent. diff --git a/absl/logging/tests/logging_functional_test.py b/absl/logging/tests/logging_functional_test.py new file mode 100644 index 0000000..98a2fab --- /dev/null +++ b/absl/logging/tests/logging_functional_test.py @@ -0,0 +1,732 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functional tests for absl.logging.""" + +import fnmatch +import os +import re +import shutil +import subprocess +import sys +import tempfile + +from absl import logging +from absl.testing import _bazelize_command +from absl.testing import absltest +from absl.testing import parameterized + + +_PY_VLOG3_LOG_MESSAGE = """\ +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:62] This line is VLOG level 3 +""" + +_PY_VLOG2_LOG_MESSAGE = """\ +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:64] This line is VLOG level 2 +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:64] This line is log level 2 +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:64] VLOG level 1, but only if VLOG level 2 is active +""" + +# VLOG1 is the same as DEBUG logs. +_PY_DEBUG_LOG_MESSAGE = """\ +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level 1 +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level 1 +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:66] This line is DEBUG +""" + +_PY_INFO_LOG_MESSAGE = """\ +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level 0 +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level 0 +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:70] Interesting Stuff\0 +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:71] Interesting Stuff with Arguments: 42 +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:73] Interesting Stuff with Dictionary +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times. +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times. +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times. +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times. +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times. +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:76] Info first 1 of 2 +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:77] Info 1 (every 3) +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:76] Info first 2 of 2 +I1231 23:59:59.000000 12345 logging_functional_test_helper.py:77] Info 4 (every 3) +""" + +_PY_INFO_LOG_MESSAGE_NOPREFIX = """\ +This line is VLOG level 0 +This line is log level 0 +Interesting Stuff\0 +Interesting Stuff with Arguments: 42 +Interesting Stuff with Dictionary +This should appear 5 times. +This should appear 5 times. +This should appear 5 times. +This should appear 5 times. +This should appear 5 times. +Info first 1 of 2 +Info 1 (every 3) +Info first 2 of 2 +Info 4 (every 3) +""" + +_PY_WARNING_LOG_MESSAGE = """\ +W1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level -1 +W1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level -1 +W1231 23:59:59.000000 12345 logging_functional_test_helper.py:79] Worrying Stuff +W0000 23:59:59.000000 12345 logging_functional_test_helper.py:81] Warn first 1 of 2 +W0000 23:59:59.000000 12345 logging_functional_test_helper.py:82] Warn 1 (every 3) +W0000 23:59:59.000000 12345 logging_functional_test_helper.py:81] Warn first 2 of 2 +W0000 23:59:59.000000 12345 logging_functional_test_helper.py:82] Warn 4 (every 3) +""" + +if sys.version_info[0:2] == (3, 4): + _FAKE_ERROR_EXTRA_MESSAGE = """\ +Traceback (most recent call last): + File "logging_functional_test_helper.py", line 456, in _test_do_logging + raise OSError('Fake Error') +""" +else: + _FAKE_ERROR_EXTRA_MESSAGE = '' + +_PY_ERROR_LOG_MESSAGE = """\ +E1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level -2 +E1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level -2 +E1231 23:59:59.000000 12345 logging_functional_test_helper.py:87] An Exception %s +Traceback (most recent call last): + File "logging_functional_test_helper.py", line 456, in _test_do_logging + raise OSError('Fake Error') +OSError: Fake Error +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Once more, just because +Traceback (most recent call last): + File "./logging_functional_test_helper.py", line 78, in _test_do_logging + raise OSError('Fake Error') +OSError: Fake Error +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Exception 2 %s +Traceback (most recent call last): + File "logging_functional_test_helper.py", line 456, in _test_do_logging + raise OSError('Fake Error') +OSError: Fake Error +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Non-exception +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Exception 3 +Traceback (most recent call last): + File "logging_functional_test_helper.py", line 456, in _test_do_logging + raise OSError('Fake Error') +OSError: Fake Error +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] No traceback +{fake_error_extra}OSError: Fake Error +E1231 23:59:59.000000 12345 logging_functional_test_helper.py:90] Alarming Stuff +E0000 23:59:59.000000 12345 logging_functional_test_helper.py:92] Error first 1 of 2 +E0000 23:59:59.000000 12345 logging_functional_test_helper.py:93] Error 1 (every 3) +E0000 23:59:59.000000 12345 logging_functional_test_helper.py:92] Error first 2 of 2 +E0000 23:59:59.000000 12345 logging_functional_test_helper.py:93] Error 4 (every 3) +""".format(fake_error_extra=_FAKE_ERROR_EXTRA_MESSAGE) + + +_CRITICAL_DOWNGRADE_TO_ERROR_MESSAGE = """\ +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] CRITICAL - A critical message +""" + + +_VERBOSITY_FLAG_TEST_PARAMETERS = ( + ('fatal', logging.FATAL), + ('error', logging.ERROR), + ('warning', logging.WARN), + ('info', logging.INFO), + ('debug', logging.DEBUG), + ('vlog1', 1), + ('vlog2', 2), + ('vlog3', 3)) + + +def _get_fatal_log_expectation(testcase, message, include_stacktrace): + """Returns the expectation for fatal logging tests. + + Args: + testcase: The TestCase instance. + message: The extra fatal logging message. + include_stacktrace: Whether or not to include stacktrace. + + Returns: + A callable, the expectation for fatal logging tests. It will be passed to + FunctionalTest._exec_test as third items in the expected_logs list. + See _exec_test's docstring for more information. + """ + def assert_logs(logs): + if os.name == 'nt': + # On Windows, it also dumps extra information at the end, something like: + # This application has requested the Runtime to terminate it in an + # unusual way. Please contact the application's support team for more + # information. + logs = '\n'.join(logs.split('\n')[:-3]) + format_string = ( + 'F1231 23:59:59.000000 12345 logging_functional_test_helper.py:175] ' + '%s message\n') + expected_logs = format_string % message + if include_stacktrace: + expected_logs += 'Stack trace:\n' + faulthandler_start = 'Fatal Python error: Aborted' + testcase.assertIn(faulthandler_start, logs) + log_message = logs.split(faulthandler_start)[0] + testcase.assertEqual(_munge_log(expected_logs), _munge_log(log_message)) + + return assert_logs + + +def _munge_log(buf): + """Remove timestamps, thread ids, filenames and line numbers from logs.""" + + # Remove all messages produced before the output to be tested. + buf = re.sub(r'(?:.|\n)*START OF TEST HELPER LOGS: IGNORE PREVIOUS.\n', + r'', + buf) + + # Greeting + buf = re.sub(r'(?m)^Log file created at: .*\n', + '', + buf) + buf = re.sub(r'(?m)^Running on machine: .*\n', + '', + buf) + buf = re.sub(r'(?m)^Binary: .*\n', + '', + buf) + buf = re.sub(r'(?m)^Log line format: .*\n', + '', + buf) + + # Verify thread id is logged as a non-negative quantity. + matched = re.match(r'(?m)^(\w)(\d\d\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d) ' + r'([ ]*-?[0-9a-fA-f]+ )?([a-zA-Z<][\w._<>-]+):(\d+)', + buf) + if matched: + threadid = matched.group(3) + if int(threadid) < 0: + raise AssertionError("Negative threadid '%s' in '%s'" % (threadid, buf)) + + # Timestamp + buf = re.sub(r'(?m)' + logging.ABSL_LOGGING_PREFIX_REGEX, + r'\g<severity>0000 00:00:00.000000 12345 \g<filename>:123', + buf) + + # Traceback + buf = re.sub(r'(?m)^ File "(.*/)?([^"/]+)", line (\d+),', + r' File "\g<2>", line 456,', + buf) + + # Stack trace is too complicated for re, just assume it extends to end of + # output + buf = re.sub(r'(?sm)^Stack trace:\n.*', + r'Stack trace:\n', + buf) + buf = re.sub(r'(?sm)^\*\*\* Signal 6 received by PID.*\n.*', + r'Stack trace:\n', + buf) + buf = re.sub((r'(?sm)^\*\*\* ([A-Z]+) received by PID (\d+) ' + r'\(TID 0x([0-9a-f]+)\)' + r'( from PID \d+)?; stack trace: \*\*\*\n.*'), + r'Stack trace:\n', + buf) + buf = re.sub(r'(?sm)^\*\*\* Check failure stack trace: \*\*\*\n.*', + r'Stack trace:\n', + buf) + + if os.name == 'nt': + # On windows, we calls Python interpreter explicitly, so the file names + # include the full path. Strip them. + buf = re.sub(r'( File ").*(logging_functional_test_helper\.py", line )', + r'\1\2', + buf) + + return buf + + +def _verify_status(expected, actual, output): + if expected != actual: + raise AssertionError( + 'Test exited with unexpected status code %d (expected %d). ' + 'Output was:\n%s' % (actual, expected, output)) + + +def _verify_ok(status, output): + """Check that helper exited with no errors.""" + _verify_status(0, status, output) + + +def _verify_fatal(status, output): + """Check that helper died as expected.""" + # os.abort generates a SIGABRT signal (-6). On Windows, the process + # immediately returns an exit code of 3. + # See https://docs.python.org/3.6/library/os.html#os.abort. + expected_exit_code = 3 if os.name == 'nt' else -6 + _verify_status(expected_exit_code, status, output) + + +def _verify_assert(status, output): + """.Check that helper failed with assertion.""" + _verify_status(1, status, output) + + +class FunctionalTest(parameterized.TestCase): + """Functional tests using the logging_functional_test_helper script.""" + + def _get_helper(self): + helper_name = 'absl/logging/tests/logging_functional_test_helper' + return _bazelize_command.get_executable_path(helper_name) + + def _get_logs(self, + verbosity, + include_info_prefix=True): + logs = [] + if verbosity >= 3: + logs.append(_PY_VLOG3_LOG_MESSAGE) + if verbosity >= 2: + logs.append(_PY_VLOG2_LOG_MESSAGE) + if verbosity >= logging.DEBUG: + logs.append(_PY_DEBUG_LOG_MESSAGE) + + if verbosity >= logging.INFO: + if include_info_prefix: + logs.append(_PY_INFO_LOG_MESSAGE) + else: + logs.append(_PY_INFO_LOG_MESSAGE_NOPREFIX) + if verbosity >= logging.WARN: + logs.append(_PY_WARNING_LOG_MESSAGE) + if verbosity >= logging.ERROR: + logs.append(_PY_ERROR_LOG_MESSAGE) + + expected_logs = ''.join(logs) + expected_logs = expected_logs.replace( + "<type 'exceptions.OSError'>", "<class 'OSError'>") + return expected_logs + + def setUp(self): + super(FunctionalTest, self).setUp() + self._log_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + + def tearDown(self): + shutil.rmtree(self._log_dir) + super(FunctionalTest, self).tearDown() + + def _exec_test(self, + verify_exit_fn, + expected_logs, + test_name='do_logging', + pass_logtostderr=False, + use_absl_log_file=False, + show_info_prefix=1, + call_dict_config=False, + extra_args=()): + """Execute the helper script and verify its output. + + Args: + verify_exit_fn: A function taking (status, output). + expected_logs: List of tuples, or None if output shouldn't be checked. + Tuple is (log prefix, log type, expected contents): + - log prefix: A program name, or 'stderr'. + - log type: 'INFO', 'ERROR', etc. + - expected: Can be the following: + - A string + - A callable, called with the logs as a single argument + - None, means don't check contents of log file + test_name: Name to pass to helper. + pass_logtostderr: Pass --logtostderr to the helper script if True. + use_absl_log_file: If True, call + logging.get_absl_handler().use_absl_log_file() before test_fn in + logging_functional_test_helper. + show_info_prefix: --showprefixforinfo value passed to the helper script. + call_dict_config: True if helper script should call + logging.config.dictConfig. + extra_args: Iterable of str (optional, defaults to ()) - extra arguments + to pass to the helper script. + + Raises: + AssertionError: Assertion error when test fails. + """ + args = ['--log_dir=%s' % self._log_dir] + if pass_logtostderr: + args.append('--logtostderr') + if not show_info_prefix: + args.append('--noshowprefixforinfo') + args += extra_args + + # Execute helper in subprocess. + env = os.environ.copy() + env.update({ + 'TEST_NAME': test_name, + 'USE_ABSL_LOG_FILE': '%d' % (use_absl_log_file,), + 'CALL_DICT_CONFIG': '%d' % (call_dict_config,), + }) + cmd = [self._get_helper()] + args + + print('env: %s' % env, file=sys.stderr) + print('cmd: %s' % cmd, file=sys.stderr) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, + universal_newlines=True) + output, _ = process.communicate() + status = process.returncode + + # Verify exit status. + verify_exit_fn(status, output) + + # Check outputs? + if expected_logs is None: + return + + # Get list of log files. + logs = os.listdir(self._log_dir) + logs = fnmatch.filter(logs, '*.log.*') + logs.append('stderr') + + # Look for a log matching each expected pattern. + matched = [] + unmatched = [] + unexpected = logs[:] + for log_prefix, log_type, expected in expected_logs: + # What pattern? + if log_prefix == 'stderr': + assert log_type is None + pattern = 'stderr' + else: + pattern = r'%s[.].*[.]log[.]%s[.][\d.-]*$' % (log_prefix, log_type) + + # Is it there + for basename in logs: + if re.match(pattern, basename): + matched.append([expected, basename]) + unexpected.remove(basename) + break + else: + unmatched.append(pattern) + + # Mismatch? + errors = '' + if unmatched: + errors += 'The following log files were expected but not found: %s' % ( + '\n '.join(unmatched)) + if unexpected: + if errors: + errors += '\n' + errors += 'The following log files were not expected: %s' % ( + '\n '.join(unexpected)) + if errors: + raise AssertionError(errors) + + # Compare contents of matches. + for (expected, basename) in matched: + if expected is None: + continue + + if basename == 'stderr': + actual = output + else: + path = os.path.join(self._log_dir, basename) + with open(path, encoding='utf-8') as f: + actual = f.read() + + if callable(expected): + try: + expected(actual) + except AssertionError: + print('expected_logs assertion failed, actual {} log:\n{}'.format( + basename, actual), file=sys.stderr) + raise + elif isinstance(expected, str): + self.assertMultiLineEqual(_munge_log(expected), _munge_log(actual), + '%s differs' % basename) + else: + self.fail( + 'Invalid value found for expected logs: {}, type: {}'.format( + expected, type(expected))) + + @parameterized.named_parameters( + ('', False), + ('logtostderr', True)) + def test_py_logging(self, logtostderr): + # Python logging by default logs to stderr. + self._exec_test( + _verify_ok, + [['stderr', None, self._get_logs(logging.INFO)]], + pass_logtostderr=logtostderr) + + def test_py_logging_use_absl_log_file(self): + # Python logging calling use_absl_log_file causes also log to files. + self._exec_test( + _verify_ok, + [['stderr', None, ''], + ['absl_log_file', 'INFO', self._get_logs(logging.INFO)]], + use_absl_log_file=True) + + def test_py_logging_use_absl_log_file_logtostderr(self): + # Python logging asked to log to stderr even though use_absl_log_file + # is called. + self._exec_test( + _verify_ok, + [['stderr', None, self._get_logs(logging.INFO)]], + pass_logtostderr=True, + use_absl_log_file=True) + + @parameterized.named_parameters( + ('', False), + ('logtostderr', True)) + def test_py_logging_noshowprefixforinfo(self, logtostderr): + self._exec_test( + _verify_ok, + [['stderr', None, self._get_logs(logging.INFO, + include_info_prefix=False)]], + pass_logtostderr=logtostderr, + show_info_prefix=0) + + def test_py_logging_noshowprefixforinfo_use_absl_log_file(self): + self._exec_test( + _verify_ok, + [['stderr', None, ''], + ['absl_log_file', 'INFO', self._get_logs(logging.INFO)]], + show_info_prefix=0, + use_absl_log_file=True) + + def test_py_logging_noshowprefixforinfo_use_absl_log_file_logtostderr(self): + self._exec_test( + _verify_ok, + [['stderr', None, self._get_logs(logging.INFO, + include_info_prefix=False)]], + pass_logtostderr=True, + show_info_prefix=0, + use_absl_log_file=True) + + def test_py_logging_noshowprefixforinfo_verbosity(self): + self._exec_test( + _verify_ok, + [['stderr', None, self._get_logs(logging.DEBUG)]], + pass_logtostderr=True, + show_info_prefix=0, + use_absl_log_file=True, + extra_args=['-v=1']) + + def test_py_logging_fatal_main_thread_only(self): + self._exec_test( + _verify_fatal, + [['stderr', None, _get_fatal_log_expectation( + self, 'fatal_main_thread_only', False)]], + test_name='fatal_main_thread_only') + + def test_py_logging_fatal_with_other_threads(self): + self._exec_test( + _verify_fatal, + [['stderr', None, _get_fatal_log_expectation( + self, 'fatal_with_other_threads', False)]], + test_name='fatal_with_other_threads') + + def test_py_logging_fatal_non_main_thread(self): + self._exec_test( + _verify_fatal, + [['stderr', None, _get_fatal_log_expectation( + self, 'fatal_non_main_thread', False)]], + test_name='fatal_non_main_thread') + + def test_py_logging_critical_non_absl(self): + self._exec_test( + _verify_ok, + [['stderr', None, _CRITICAL_DOWNGRADE_TO_ERROR_MESSAGE]], + test_name='critical_from_non_absl_logger') + + def test_py_logging_skip_log_prefix(self): + self._exec_test( + _verify_ok, + [['stderr', None, '']], + test_name='register_frame_to_skip') + + def test_py_logging_flush(self): + self._exec_test( + _verify_ok, + [['stderr', None, '']], + test_name='flush') + + @parameterized.named_parameters(*_VERBOSITY_FLAG_TEST_PARAMETERS) + def test_py_logging_verbosity_stderr(self, verbosity): + """Tests -v/--verbosity flag with python logging to stderr.""" + v_flag = '-v=%d' % verbosity + self._exec_test( + _verify_ok, + [['stderr', None, self._get_logs(verbosity)]], + extra_args=[v_flag]) + + @parameterized.named_parameters(*_VERBOSITY_FLAG_TEST_PARAMETERS) + def test_py_logging_verbosity_file(self, verbosity): + """Tests -v/--verbosity flag with Python logging to stderr.""" + v_flag = '-v=%d' % verbosity + self._exec_test( + _verify_ok, + [['stderr', None, ''], + # When using python logging, it only creates a file named INFO, + # unlike C++ it also creates WARNING and ERROR files. + ['absl_log_file', 'INFO', self._get_logs(verbosity)]], + use_absl_log_file=True, + extra_args=[v_flag]) + + def test_stderrthreshold_py_logging(self): + """Tests --stderrthreshold.""" + + stderr_logs = '''\ +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, debug log +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, info log +W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, warning log +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, error log +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=info, info log +W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=info, warning log +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=info, error log +W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=warning, warning log +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=warning, error log +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=error, error log +''' + + expected_logs = [ + ['stderr', None, stderr_logs], + ['absl_log_file', 'INFO', None], + ] + # Set verbosity to debug to test stderrthreshold == debug. + extra_args = ['-v=1'] + + self._exec_test( + _verify_ok, + expected_logs, + test_name='stderrthreshold', + extra_args=extra_args, + use_absl_log_file=True) + + def test_std_logging_py_logging(self): + """Tests logs from std logging.""" + stderr_logs = '''\ +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std debug log +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std info log +W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std warning log +E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std error log +''' + expected_logs = [['stderr', None, stderr_logs]] + + extra_args = ['-v=1', '--logtostderr'] + self._exec_test( + _verify_ok, + expected_logs, + test_name='std_logging', + extra_args=extra_args) + + def test_bad_exc_info_py_logging(self): + + def assert_stderr(stderr): + # The exact message differs among different Python versions. So it just + # asserts some certain information is there. + self.assertIn('Traceback (most recent call last):', stderr) + self.assertIn('IndexError', stderr) + + expected_logs = [ + ['stderr', None, assert_stderr], + ['absl_log_file', 'INFO', '']] + + self._exec_test( + _verify_ok, + expected_logs, + test_name='bad_exc_info', + use_absl_log_file=True) + + def test_verbosity_logger_levels_flag_ordering(self): + """Make sure last-specified flag wins.""" + + def assert_error_level_logged(stderr): + lines = stderr.splitlines() + for line in lines: + self.assertIn('std error log', line) + + self._exec_test( + _verify_ok, + test_name='std_logging', + expected_logs=[('stderr', None, assert_error_level_logged)], + extra_args=['-v=1', '--logger_levels=:ERROR']) + + def assert_debug_level_logged(stderr): + lines = stderr.splitlines() + for line in lines: + self.assertRegex(line, 'std (debug|info|warning|error) log') + + self._exec_test( + _verify_ok, + test_name='std_logging', + expected_logs=[('stderr', None, assert_debug_level_logged)], + extra_args=['--logger_levels=:ERROR', '-v=1']) + + def test_none_exc_info_py_logging(self): + + expected_stderr = '' + expected_info = '''\ +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] None exc_info +''' + expected_info += 'NoneType: None\n' + + expected_logs = [ + ['stderr', None, expected_stderr], + ['absl_log_file', 'INFO', expected_info]] + + self._exec_test( + _verify_ok, + expected_logs, + test_name='none_exc_info', + use_absl_log_file=True) + + def test_unicode_py_logging(self): + + def get_stderr_message(stderr, name): + match = re.search( + '-- begin {} --\n(.*)-- end {} --'.format(name, name), + stderr, re.MULTILINE | re.DOTALL) + self.assertTrue( + match, 'Cannot find stderr message for test {}'.format(name)) + return match.group(1) + + def assert_stderr(stderr): + """Verifies that it writes correct information to stderr for Python 3. + + There are no unicode errors in Python 3. + + Args: + stderr: the message from stderr. + """ + # Successful logs: + for name in ( + 'unicode', 'unicode % unicode', 'bytes % bytes', 'unicode % bytes', + 'bytes % unicode', 'unicode % iso8859-15', 'str % exception', + 'str % exception'): + logging.info('name = %s', name) + self.assertEqual('', get_stderr_message(stderr, name)) + + expected_logs = [['stderr', None, assert_stderr]] + + info_log = u'''\ +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: Ch\u00e2tonnaye +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: Ch\u00e2tonnaye +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] b'G\\xc3\\xaete: b'Ch\\xc3\\xa2tonnaye'' +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: b'Ch\\xc3\\xa2tonnaye' +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] b'G\\xc3\\xaete: Ch\u00e2tonnaye' +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: b'Ch\\xe2tonnaye' +I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] exception: Ch\u00e2tonnaye +''' + expected_logs.append(['absl_log_file', 'INFO', info_log]) + + self._exec_test( + _verify_ok, + expected_logs, + test_name='unicode', + use_absl_log_file=True) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/logging/tests/logging_functional_test_helper.py b/absl/logging/tests/logging_functional_test_helper.py new file mode 100644 index 0000000..b95647b --- /dev/null +++ b/absl/logging/tests/logging_functional_test_helper.py @@ -0,0 +1,312 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper script for logging_functional_test.""" + +import logging as std_logging +import logging.config as std_logging_config +import os +import sys +import threading +import time +import timeit +from unittest import mock + +from absl import app +from absl import flags +from absl import logging + +FLAGS = flags.FLAGS + + +class VerboseDel(object): + """Dummy class to test __del__ running.""" + + def __init__(self, msg): + self._msg = msg + + def __del__(self): + sys.stderr.write(self._msg) + sys.stderr.flush() + + +def _test_do_logging(): + """Do some log operations.""" + logging.vlog(3, 'This line is VLOG level 3') + logging.vlog(2, 'This line is VLOG level 2') + logging.log(2, 'This line is log level 2') + if logging.vlog_is_on(2): + logging.log(1, 'VLOG level 1, but only if VLOG level 2 is active') + + logging.vlog(1, 'This line is VLOG level 1') + logging.log(1, 'This line is log level 1') + logging.debug('This line is DEBUG') + + logging.vlog(0, 'This line is VLOG level 0') + logging.log(0, 'This line is log level 0') + logging.info('Interesting Stuff\0') + logging.info('Interesting Stuff with Arguments: %d', 42) + logging.info('%(a)s Stuff with %(b)s', + {'a': 'Interesting', 'b': 'Dictionary'}) + + with mock.patch.object(timeit, 'default_timer') as mock_timer: + mock_timer.return_value = 0 + while timeit.default_timer() < 9: + logging.log_every_n_seconds(logging.INFO, 'This should appear 5 times.', + 2) + mock_timer.return_value = mock_timer() + .2 + + for i in range(1, 5): + logging.log_first_n(logging.INFO, 'Info first %d of %d', 2, i, 2) + logging.log_every_n(logging.INFO, 'Info %d (every %d)', 3, i, 3) + + logging.vlog(-1, 'This line is VLOG level -1') + logging.log(-1, 'This line is log level -1') + logging.warning('Worrying Stuff') + for i in range(1, 5): + logging.log_first_n(logging.WARNING, 'Warn first %d of %d', 2, i, 2) + logging.log_every_n(logging.WARNING, 'Warn %d (every %d)', 3, i, 3) + + logging.vlog(-2, 'This line is VLOG level -2') + logging.log(-2, 'This line is log level -2') + try: + raise OSError('Fake Error') + except OSError: + saved_exc_info = sys.exc_info() + logging.exception('An Exception %s') + logging.exception('Once more, %(reason)s', {'reason': 'just because'}) + logging.error('Exception 2 %s', exc_info=True) + logging.error('Non-exception', exc_info=False) + + try: + sys.exc_clear() + except AttributeError: + # No sys.exc_clear() in Python 3, but this will clear sys.exc_info() too. + pass + + logging.error('Exception %s', '3', exc_info=saved_exc_info) + logging.error('No traceback', exc_info=saved_exc_info[:2] + (None,)) + + logging.error('Alarming Stuff') + for i in range(1, 5): + logging.log_first_n(logging.ERROR, 'Error first %d of %d', 2, i, 2) + logging.log_every_n(logging.ERROR, 'Error %d (every %d)', 3, i, 3) + logging.flush() + + +def _test_fatal_main_thread_only(): + """Test logging.fatal from main thread, no other threads running.""" + v = VerboseDel('fatal_main_thread_only main del called\n') + try: + logging.fatal('fatal_main_thread_only message') + finally: + del v + + +def _test_fatal_with_other_threads(): + """Test logging.fatal from main thread, other threads running.""" + + lock = threading.Lock() + lock.acquire() + + def sleep_forever(lock=lock): + v = VerboseDel('fatal_with_other_threads non-main del called\n') + try: + lock.release() + while True: + time.sleep(10000) + finally: + del v + + v = VerboseDel('fatal_with_other_threads main del called\n') + try: + # Start new thread + t = threading.Thread(target=sleep_forever) + t.start() + + # Wait for other thread + lock.acquire() + lock.release() + + # Die + logging.fatal('fatal_with_other_threads message') + while True: + time.sleep(10000) + finally: + del v + + +def _test_fatal_non_main_thread(): + """Test logging.fatal from non main thread.""" + + lock = threading.Lock() + lock.acquire() + + def die_soon(lock=lock): + v = VerboseDel('fatal_non_main_thread non-main del called\n') + try: + # Wait for signal from other thread + lock.acquire() + lock.release() + logging.fatal('fatal_non_main_thread message') + while True: + time.sleep(10000) + finally: + del v + + v = VerboseDel('fatal_non_main_thread main del called\n') + try: + # Start new thread + t = threading.Thread(target=die_soon) + t.start() + + # Signal other thread + lock.release() + + # Wait for it to die + while True: + time.sleep(10000) + finally: + del v + + +def _test_critical_from_non_absl_logger(): + """Test CRITICAL logs from non-absl loggers.""" + + std_logging.critical('A critical message') + + +def _test_register_frame_to_skip(): + """Test skipping frames for line number reporting.""" + + def _getline(): + + def _getline_inner(): + return logging.get_absl_logger().findCaller()[1] + + return _getline_inner() + + # Check register_frame_to_skip function to see if log frame skipping works. + line1 = _getline() + line2 = _getline() + logging.get_absl_logger().register_frame_to_skip(__file__, '_getline') + line3 = _getline() + # Both should be line number of the _getline_inner() call. + assert (line1 == line2), (line1, line2) + # line3 should be a line number in this function. + assert (line2 != line3), (line2, line3) + + +def _test_flush(): + """Test flush in various difficult cases.""" + # Flush, but one of the logfiles is closed + log_filename = os.path.join(FLAGS.log_dir, 'a_thread_with_logfile.txt') + with open(log_filename, 'w') as log_file: + logging.get_absl_handler().python_handler.stream = log_file + logging.flush() + + +def _test_stderrthreshold(): + """Tests modifying --stderrthreshold after flag parsing will work.""" + + def log_things(): + logging.debug('FLAGS.stderrthreshold=%s, debug log', FLAGS.stderrthreshold) + logging.info('FLAGS.stderrthreshold=%s, info log', FLAGS.stderrthreshold) + logging.warning('FLAGS.stderrthreshold=%s, warning log', + FLAGS.stderrthreshold) + logging.error('FLAGS.stderrthreshold=%s, error log', FLAGS.stderrthreshold) + + FLAGS.stderrthreshold = 'debug' + log_things() + FLAGS.stderrthreshold = 'info' + log_things() + FLAGS.stderrthreshold = 'warning' + log_things() + FLAGS.stderrthreshold = 'error' + log_things() + + +def _test_std_logging(): + """Tests logs from std logging.""" + std_logging.debug('std debug log') + std_logging.info('std info log') + std_logging.warning('std warning log') + std_logging.error('std error log') + + +def _test_bad_exc_info(): + """Tests when a bad exc_info valud is provided.""" + logging.info('Bad exc_info', exc_info=(None, None)) + + +def _test_none_exc_info(): + """Tests when exc_info is requested but not available.""" + # Clear exc_info first. + try: + sys.exc_clear() + except AttributeError: + # No sys.exc_clear() in Python 3, but this will clear sys.exc_info() too. + pass + logging.info('None exc_info', exc_info=True) + + +def _test_unicode(): + """Tests unicode handling.""" + + test_names = [] + + def log(name, msg, *args): + """Logs the message, and ensures the same name is not logged again.""" + assert name not in test_names, ('test_unicode expects unique names to work,' + ' found existing name {}').format(name) + test_names.append(name) + + # Add line separators so that tests can verify the output for each log + # message. + sys.stderr.write('-- begin {} --\n'.format(name)) + logging.info(msg, *args) + sys.stderr.write('-- end {} --\n'.format(name)) + + log('unicode', u'G\u00eete: Ch\u00e2tonnaye') + log('unicode % unicode', u'G\u00eete: %s', u'Ch\u00e2tonnaye') + log('bytes % bytes', u'G\u00eete: %s'.encode('utf-8'), + u'Ch\u00e2tonnaye'.encode('utf-8')) + log('unicode % bytes', u'G\u00eete: %s', u'Ch\u00e2tonnaye'.encode('utf-8')) + log('bytes % unicode', u'G\u00eete: %s'.encode('utf-8'), u'Ch\u00e2tonnaye') + log('unicode % iso8859-15', u'G\u00eete: %s', + u'Ch\u00e2tonnaye'.encode('iso-8859-15')) + log('str % exception', 'exception: %s', Exception(u'Ch\u00e2tonnaye')) + + +def main(argv): + del argv # Unused. + + test_name = os.environ.get('TEST_NAME', None) + test_fn = globals().get('_test_%s' % test_name) + if test_fn is None: + raise AssertionError('TEST_NAME must be set to a valid value') + # Flush so previous messages are written to file before we switch to a new + # file with use_absl_log_file. + logging.flush() + if os.environ.get('USE_ABSL_LOG_FILE') == '1': + logging.get_absl_handler().use_absl_log_file('absl_log_file', FLAGS.log_dir) + + test_fn() + + +if __name__ == '__main__': + sys.argv[0] = 'py_argv_0' + if os.environ.get('CALL_DICT_CONFIG') == '1': + std_logging_config.dictConfig({'version': 1}) + app.run(main) diff --git a/absl/logging/tests/logging_test.py b/absl/logging/tests/logging_test.py new file mode 100644 index 0000000..e5c4fcc --- /dev/null +++ b/absl/logging/tests/logging_test.py @@ -0,0 +1,1002 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for absl.logging.""" + +import contextlib +import functools +import getpass +import io +import logging as std_logging +import os +import re +import socket +import sys +import tempfile +import threading +import time +import traceback +import unittest +from unittest import mock + +from absl import flags +from absl import logging +from absl.testing import absltest +from absl.testing import flagsaver +from absl.testing import parameterized + +FLAGS = flags.FLAGS + + +class ConfigurationTest(absltest.TestCase): + """Tests the initial logging configuration.""" + + def test_logger_and_handler(self): + absl_logger = std_logging.getLogger('absl') + self.assertIs(absl_logger, logging.get_absl_logger()) + self.assertIsInstance(absl_logger, logging.ABSLLogger) + self.assertIsInstance( + logging.get_absl_handler().python_handler.formatter, + logging.PythonFormatter) + + +class LoggerLevelsTest(parameterized.TestCase): + + def setUp(self): + super(LoggerLevelsTest, self).setUp() + # Since these tests muck with the flag, always save/restore in case the + # tests forget to clean up properly. + # enter_context() is py3-only, but manually enter/exit should suffice. + cm = self.set_logger_levels({}) + cm.__enter__() + self.addCleanup(lambda: cm.__exit__(None, None, None)) + + @contextlib.contextmanager + def set_logger_levels(self, levels): + original_levels = { + name: std_logging.getLogger(name).level for name in levels + } + + try: + with flagsaver.flagsaver(logger_levels=levels): + yield + finally: + for name, level in original_levels.items(): + std_logging.getLogger(name).setLevel(level) + + def assert_logger_level(self, name, expected_level): + logger = std_logging.getLogger(name) + self.assertEqual(logger.level, expected_level) + + def assert_logged(self, logger_name, expected_msgs): + logger = std_logging.getLogger(logger_name) + # NOTE: assertLogs() sets the logger to INFO if not specified. + with self.assertLogs(logger, logger.level) as cm: + logger.debug('debug') + logger.info('info') + logger.warning('warning') + logger.error('error') + logger.critical('critical') + + actual = {r.getMessage() for r in cm.records} + self.assertEqual(set(expected_msgs), actual) + + def test_setting_levels(self): + # Other tests change the root logging level, so we can't + # assume it's the default. + orig_root_level = std_logging.root.getEffectiveLevel() + with self.set_logger_levels({'foo': 'ERROR', 'bar': 'DEBUG'}): + + self.assert_logger_level('foo', std_logging.ERROR) + self.assert_logger_level('bar', std_logging.DEBUG) + self.assert_logger_level('', orig_root_level) + + self.assert_logged('foo', {'error', 'critical'}) + self.assert_logged('bar', + {'debug', 'info', 'warning', 'error', 'critical'}) + + @parameterized.named_parameters( + ('empty', ''), + ('one_value', 'one:INFO'), + ('two_values', 'one.a:INFO,two.b:ERROR'), + ('whitespace_ignored', ' one : DEBUG , two : INFO'), + ) + def test_serialize_parse(self, levels_str): + fl = FLAGS['logger_levels'] + fl.parse(levels_str) + expected = levels_str.replace(' ', '') + actual = fl.serialize() + self.assertEqual('--logger_levels={}'.format(expected), actual) + + def test_invalid_value(self): + with self.assertRaisesRegex(ValueError, 'Unknown level.*10'): + FLAGS['logger_levels'].parse('foo:10') + + +class PythonHandlerTest(absltest.TestCase): + """Tests the PythonHandler class.""" + + def setUp(self): + super().setUp() + (year, month, day, hour, minute, sec, + dunno, dayofyear, dst_flag) = (1979, 10, 21, 18, 17, 16, 3, 15, 0) + self.now_tuple = (year, month, day, hour, minute, sec, + dunno, dayofyear, dst_flag) + self.python_handler = logging.PythonHandler() + + def tearDown(self): + mock.patch.stopall() + super().tearDown() + + @flagsaver.flagsaver(logtostderr=False) + def test_set_google_log_file_no_log_to_stderr(self): + with mock.patch.object(self.python_handler, 'start_logging_to_file'): + self.python_handler.use_absl_log_file() + self.python_handler.start_logging_to_file.assert_called_once_with( + program_name=None, log_dir=None) + + @flagsaver.flagsaver(logtostderr=True) + def test_set_google_log_file_with_log_to_stderr(self): + self.python_handler.stream = None + self.python_handler.use_absl_log_file() + self.assertEqual(sys.stderr, self.python_handler.stream) + + @mock.patch.object(logging, 'find_log_dir_and_names') + @mock.patch.object(logging.time, 'localtime') + @mock.patch.object(logging.time, 'time') + @mock.patch.object(os.path, 'islink') + @mock.patch.object(os, 'unlink') + @mock.patch.object(os, 'getpid') + def test_start_logging_to_file( + self, mock_getpid, mock_unlink, mock_islink, mock_time, + mock_localtime, mock_find_log_dir_and_names): + mock_find_log_dir_and_names.return_value = ('here', 'prog1', 'prog1') + mock_time.return_value = '12345' + mock_localtime.return_value = self.now_tuple + mock_getpid.return_value = 4321 + symlink = os.path.join('here', 'prog1.INFO') + mock_islink.return_value = True + with mock.patch.object( + logging, 'open', return_value=sys.stdout, create=True): + if getattr(os, 'symlink', None): + with mock.patch.object(os, 'symlink'): + self.python_handler.start_logging_to_file() + mock_unlink.assert_called_once_with(symlink) + os.symlink.assert_called_once_with( + 'prog1.INFO.19791021-181716.4321', symlink) + else: + self.python_handler.start_logging_to_file() + + def test_log_file(self): + handler = logging.PythonHandler() + self.assertEqual(sys.stderr, handler.stream) + + stream = mock.Mock() + handler = logging.PythonHandler(stream) + self.assertEqual(stream, handler.stream) + + def test_flush(self): + stream = mock.Mock() + handler = logging.PythonHandler(stream) + handler.flush() + stream.flush.assert_called_once() + + def test_flush_with_value_error(self): + stream = mock.Mock() + stream.flush.side_effect = ValueError + handler = logging.PythonHandler(stream) + handler.flush() + stream.flush.assert_called_once() + + def test_flush_with_environment_error(self): + stream = mock.Mock() + stream.flush.side_effect = EnvironmentError + handler = logging.PythonHandler(stream) + handler.flush() + stream.flush.assert_called_once() + + def test_flush_with_assertion_error(self): + stream = mock.Mock() + stream.flush.side_effect = AssertionError + handler = logging.PythonHandler(stream) + with self.assertRaises(AssertionError): + handler.flush() + + def test_log_to_std_err(self): + record = std_logging.LogRecord( + 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False) + with mock.patch.object(std_logging.StreamHandler, 'emit'): + self.python_handler._log_to_stderr(record) + std_logging.StreamHandler.emit.assert_called_once_with(record) + + @flagsaver.flagsaver(logtostderr=True) + def test_emit_log_to_stderr(self): + record = std_logging.LogRecord( + 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False) + with mock.patch.object(self.python_handler, '_log_to_stderr'): + self.python_handler.emit(record) + self.python_handler._log_to_stderr.assert_called_once_with(record) + + def test_emit(self): + stream = io.StringIO() + handler = logging.PythonHandler(stream) + handler.stderr_threshold = std_logging.FATAL + record = std_logging.LogRecord( + 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False) + handler.emit(record) + self.assertEqual(1, stream.getvalue().count('logging_msg')) + + @flagsaver.flagsaver(stderrthreshold='debug') + def test_emit_and_stderr_threshold(self): + mock_stderr = io.StringIO() + stream = io.StringIO() + handler = logging.PythonHandler(stream) + record = std_logging.LogRecord( + 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False) + with mock.patch.object(sys, 'stderr', new=mock_stderr) as mock_stderr: + handler.emit(record) + self.assertEqual(1, stream.getvalue().count('logging_msg')) + self.assertEqual(1, mock_stderr.getvalue().count('logging_msg')) + + @flagsaver.flagsaver(alsologtostderr=True) + def test_emit_also_log_to_stderr(self): + mock_stderr = io.StringIO() + stream = io.StringIO() + handler = logging.PythonHandler(stream) + handler.stderr_threshold = std_logging.FATAL + record = std_logging.LogRecord( + 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False) + with mock.patch.object(sys, 'stderr', new=mock_stderr) as mock_stderr: + handler.emit(record) + self.assertEqual(1, stream.getvalue().count('logging_msg')) + self.assertEqual(1, mock_stderr.getvalue().count('logging_msg')) + + def test_emit_on_stderr(self): + mock_stderr = io.StringIO() + with mock.patch.object(sys, 'stderr', new=mock_stderr) as mock_stderr: + handler = logging.PythonHandler() + handler.stderr_threshold = std_logging.INFO + record = std_logging.LogRecord( + 'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False) + handler.emit(record) + self.assertEqual(1, mock_stderr.getvalue().count('logging_msg')) + + def test_emit_fatal_absl(self): + stream = io.StringIO() + handler = logging.PythonHandler(stream) + record = std_logging.LogRecord( + 'name', std_logging.FATAL, 'path', 12, 'logging_msg', [], False) + record.__dict__[logging._ABSL_LOG_FATAL] = True + with mock.patch.object(handler, 'flush') as mock_flush: + with mock.patch.object(os, 'abort') as mock_abort: + handler.emit(record) + mock_abort.assert_called_once() + mock_flush.assert_called() # flush is also called by super class. + + def test_emit_fatal_non_absl(self): + stream = io.StringIO() + handler = logging.PythonHandler(stream) + record = std_logging.LogRecord( + 'name', std_logging.FATAL, 'path', 12, 'logging_msg', [], False) + with mock.patch.object(os, 'abort') as mock_abort: + handler.emit(record) + mock_abort.assert_not_called() + + def test_close(self): + stream = mock.Mock() + stream.isatty.return_value = True + handler = logging.PythonHandler(stream) + with mock.patch.object(handler, 'flush') as mock_flush: + with mock.patch.object(std_logging.StreamHandler, 'close') as super_close: + handler.close() + mock_flush.assert_called_once() + super_close.assert_called_once() + stream.close.assert_not_called() + + def test_close_afile(self): + stream = mock.Mock() + stream.isatty.return_value = False + stream.close.side_effect = ValueError + handler = logging.PythonHandler(stream) + with mock.patch.object(handler, 'flush') as mock_flush: + with mock.patch.object(std_logging.StreamHandler, 'close') as super_close: + handler.close() + mock_flush.assert_called_once() + super_close.assert_called_once() + + def test_close_stderr(self): + with mock.patch.object(sys, 'stderr') as mock_stderr: + mock_stderr.isatty.return_value = False + handler = logging.PythonHandler(sys.stderr) + handler.close() + mock_stderr.close.assert_not_called() + + def test_close_stdout(self): + with mock.patch.object(sys, 'stdout') as mock_stdout: + mock_stdout.isatty.return_value = False + handler = logging.PythonHandler(sys.stdout) + handler.close() + mock_stdout.close.assert_not_called() + + def test_close_original_stderr(self): + with mock.patch.object(sys, '__stderr__') as mock_original_stderr: + mock_original_stderr.isatty.return_value = False + handler = logging.PythonHandler(sys.__stderr__) + handler.close() + mock_original_stderr.close.assert_not_called() + + def test_close_original_stdout(self): + with mock.patch.object(sys, '__stdout__') as mock_original_stdout: + mock_original_stdout.isatty.return_value = False + handler = logging.PythonHandler(sys.__stdout__) + handler.close() + mock_original_stdout.close.assert_not_called() + + def test_close_fake_file(self): + + class FakeFile(object): + """A file-like object that does not implement "isatty".""" + + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + def flush(self): + pass + + fake_file = FakeFile() + handler = logging.PythonHandler(fake_file) + handler.close() + self.assertTrue(fake_file.closed) + + +class ABSLHandlerTest(absltest.TestCase): + + def setUp(self): + super().setUp() + formatter = logging.PythonFormatter() + self.absl_handler = logging.ABSLHandler(formatter) + + def test_activate_python_handler(self): + self.absl_handler.activate_python_handler() + self.assertEqual( + self.absl_handler._current_handler, self.absl_handler.python_handler) + + +class ABSLLoggerTest(absltest.TestCase): + """Tests the ABSLLogger class.""" + + def set_up_mock_frames(self): + """Sets up mock frames for use with the testFindCaller methods.""" + logging_file = os.path.join('absl', 'logging', '__init__.py') + + # Set up mock frame 0 + mock_frame_0 = mock.Mock() + mock_code_0 = mock.Mock() + mock_code_0.co_filename = logging_file + mock_code_0.co_name = 'LoggingLog' + mock_code_0.co_firstlineno = 124 + mock_frame_0.f_code = mock_code_0 + mock_frame_0.f_lineno = 125 + + # Set up mock frame 1 + mock_frame_1 = mock.Mock() + mock_code_1 = mock.Mock() + mock_code_1.co_filename = 'myfile.py' + mock_code_1.co_name = 'Method1' + mock_code_1.co_firstlineno = 124 + mock_frame_1.f_code = mock_code_1 + mock_frame_1.f_lineno = 125 + + # Set up mock frame 2 + mock_frame_2 = mock.Mock() + mock_code_2 = mock.Mock() + mock_code_2.co_filename = 'myfile.py' + mock_code_2.co_name = 'Method2' + mock_code_2.co_firstlineno = 124 + mock_frame_2.f_code = mock_code_2 + mock_frame_2.f_lineno = 125 + + # Set up mock frame 3 + mock_frame_3 = mock.Mock() + mock_code_3 = mock.Mock() + mock_code_3.co_filename = 'myfile.py' + mock_code_3.co_name = 'Method3' + mock_code_3.co_firstlineno = 124 + mock_frame_3.f_code = mock_code_3 + mock_frame_3.f_lineno = 125 + + # Set up mock frame 4 that has the same function name as frame 2. + mock_frame_4 = mock.Mock() + mock_code_4 = mock.Mock() + mock_code_4.co_filename = 'myfile.py' + mock_code_4.co_name = 'Method2' + mock_code_4.co_firstlineno = 248 + mock_frame_4.f_code = mock_code_4 + mock_frame_4.f_lineno = 249 + + # Tie them together. + mock_frame_4.f_back = None + mock_frame_3.f_back = mock_frame_4 + mock_frame_2.f_back = mock_frame_3 + mock_frame_1.f_back = mock_frame_2 + mock_frame_0.f_back = mock_frame_1 + + mock.patch.object(sys, '_getframe').start() + sys._getframe.return_value = mock_frame_0 + + def setUp(self): + super().setUp() + self.message = 'Hello Nurse' + self.logger = logging.ABSLLogger('') + + def tearDown(self): + mock.patch.stopall() + self.logger._frames_to_skip.clear() + super().tearDown() + + def test_constructor_without_level(self): + self.logger = logging.ABSLLogger('') + self.assertEqual(std_logging.NOTSET, self.logger.getEffectiveLevel()) + + def test_constructor_with_level(self): + self.logger = logging.ABSLLogger('', std_logging.DEBUG) + self.assertEqual(std_logging.DEBUG, self.logger.getEffectiveLevel()) + + def test_find_caller_normal(self): + self.set_up_mock_frames() + expected_name = 'Method1' + self.assertEqual(expected_name, self.logger.findCaller()[2]) + + def test_find_caller_skip_method1(self): + self.set_up_mock_frames() + self.logger.register_frame_to_skip('myfile.py', 'Method1') + expected_name = 'Method2' + self.assertEqual(expected_name, self.logger.findCaller()[2]) + + def test_find_caller_skip_method1_and_method2(self): + self.set_up_mock_frames() + self.logger.register_frame_to_skip('myfile.py', 'Method1') + self.logger.register_frame_to_skip('myfile.py', 'Method2') + expected_name = 'Method3' + self.assertEqual(expected_name, self.logger.findCaller()[2]) + + def test_find_caller_skip_method1_and_method3(self): + self.set_up_mock_frames() + self.logger.register_frame_to_skip('myfile.py', 'Method1') + # Skipping Method3 should change nothing since Method2 should be hit. + self.logger.register_frame_to_skip('myfile.py', 'Method3') + expected_name = 'Method2' + self.assertEqual(expected_name, self.logger.findCaller()[2]) + + def test_find_caller_skip_method1_and_method4(self): + self.set_up_mock_frames() + self.logger.register_frame_to_skip('myfile.py', 'Method1') + # Skipping frame 4's Method2 should change nothing for frame 2's Method2. + self.logger.register_frame_to_skip('myfile.py', 'Method2', 248) + expected_name = 'Method2' + expected_frame_lineno = 125 + self.assertEqual(expected_name, self.logger.findCaller()[2]) + self.assertEqual(expected_frame_lineno, self.logger.findCaller()[1]) + + def test_find_caller_skip_method1_method2_and_method3(self): + self.set_up_mock_frames() + self.logger.register_frame_to_skip('myfile.py', 'Method1') + self.logger.register_frame_to_skip('myfile.py', 'Method2', 124) + self.logger.register_frame_to_skip('myfile.py', 'Method3') + expected_name = 'Method2' + expected_frame_lineno = 249 + self.assertEqual(expected_name, self.logger.findCaller()[2]) + self.assertEqual(expected_frame_lineno, self.logger.findCaller()[1]) + + def test_find_caller_stack_info(self): + self.set_up_mock_frames() + self.logger.register_frame_to_skip('myfile.py', 'Method1') + with mock.patch.object(traceback, 'print_stack') as print_stack: + self.assertEqual( + ('myfile.py', 125, 'Method2', 'Stack (most recent call last):'), + self.logger.findCaller(stack_info=True)) + print_stack.assert_called_once() + + def test_critical(self): + with mock.patch.object(self.logger, 'log'): + self.logger.critical(self.message) + self.logger.log.assert_called_once_with( + std_logging.CRITICAL, self.message) + + def test_fatal(self): + with mock.patch.object(self.logger, 'log'): + self.logger.fatal(self.message) + self.logger.log.assert_called_once_with(std_logging.FATAL, self.message) + + def test_error(self): + with mock.patch.object(self.logger, 'log'): + self.logger.error(self.message) + self.logger.log.assert_called_once_with(std_logging.ERROR, self.message) + + def test_warn(self): + with mock.patch.object(self.logger, 'log'): + self.logger.warn(self.message) + self.logger.log.assert_called_once_with(std_logging.WARN, self.message) + + def test_warning(self): + with mock.patch.object(self.logger, 'log'): + self.logger.warning(self.message) + self.logger.log.assert_called_once_with(std_logging.WARNING, self.message) + + def test_info(self): + with mock.patch.object(self.logger, 'log'): + self.logger.info(self.message) + self.logger.log.assert_called_once_with(std_logging.INFO, self.message) + + def test_debug(self): + with mock.patch.object(self.logger, 'log'): + self.logger.debug(self.message) + self.logger.log.assert_called_once_with(std_logging.DEBUG, self.message) + + def test_log_debug_with_python(self): + with mock.patch.object(self.logger, 'log'): + FLAGS.verbosity = 1 + self.logger.debug(self.message) + self.logger.log.assert_called_once_with(std_logging.DEBUG, self.message) + + def test_log_fatal_with_python(self): + with mock.patch.object(self.logger, 'log'): + self.logger.fatal(self.message) + self.logger.log.assert_called_once_with(std_logging.FATAL, self.message) + + def test_register_frame_to_skip(self): + # This is basically just making sure that if I put something in a + # list, it actually appears in that list. + frame_tuple = ('file', 'method') + self.logger.register_frame_to_skip(*frame_tuple) + self.assertIn(frame_tuple, self.logger._frames_to_skip) + + def test_register_frame_to_skip_with_lineno(self): + frame_tuple = ('file', 'method', 123) + self.logger.register_frame_to_skip(*frame_tuple) + self.assertIn(frame_tuple, self.logger._frames_to_skip) + + def test_logger_cannot_be_disabled(self): + self.logger.disabled = True + record = self.logger.makeRecord( + 'name', std_logging.INFO, 'fn', 20, 'msg', [], False) + with mock.patch.object(self.logger, 'callHandlers') as mock_call_handlers: + self.logger.handle(record) + mock_call_handlers.assert_called_once() + + +class ABSLLogPrefixTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.record = std_logging.LogRecord( + 'name', std_logging.INFO, 'path/to/source.py', 13, 'log message', + None, None) + + @parameterized.named_parameters( + ('debug', std_logging.DEBUG, 'I'), + ('info', std_logging.INFO, 'I'), + ('warning', std_logging.WARNING, 'W'), + ('error', std_logging.ERROR, 'E'), + ) + def test_default_prefixes(self, levelno, level_prefix): + self.record.levelno = levelno + self.record.created = 1494293880.378885 + thread_id = '{: >5}'.format(logging._get_thread_id()) + # Use UTC so the test passes regardless of the local time zone. + with mock.patch.object(time, 'localtime', side_effect=time.gmtime): + self.assertEqual( + '{}0509 01:38:00.378885 {} source.py:13] '.format( + level_prefix, thread_id), + logging.get_absl_log_prefix(self.record)) + time.localtime.assert_called_once_with(self.record.created) + + def test_absl_prefix_regex(self): + self.record.created = 1226888258.0521369 + # Use UTC so the test passes regardless of the local time zone. + with mock.patch.object(time, 'localtime', side_effect=time.gmtime): + prefix = logging.get_absl_log_prefix(self.record) + + match = re.search(logging.ABSL_LOGGING_PREFIX_REGEX, prefix) + self.assertTrue(match) + + expect = {'severity': 'I', + 'month': '11', + 'day': '17', + 'hour': '02', + 'minute': '17', + 'second': '38', + 'microsecond': '052136', + 'thread_id': str(logging._get_thread_id()), + 'filename': 'source.py', + 'line': '13', + } + actual = {name: match.group(name) for name in expect} + self.assertEqual(expect, actual) + + def test_critical_absl(self): + self.record.levelno = std_logging.CRITICAL + self.record.created = 1494293880.378885 + self.record._absl_log_fatal = True + thread_id = '{: >5}'.format(logging._get_thread_id()) + # Use UTC so the test passes regardless of the local time zone. + with mock.patch.object(time, 'localtime', side_effect=time.gmtime): + self.assertEqual( + 'F0509 01:38:00.378885 {} source.py:13] '.format(thread_id), + logging.get_absl_log_prefix(self.record)) + time.localtime.assert_called_once_with(self.record.created) + + def test_critical_non_absl(self): + self.record.levelno = std_logging.CRITICAL + self.record.created = 1494293880.378885 + thread_id = '{: >5}'.format(logging._get_thread_id()) + # Use UTC so the test passes regardless of the local time zone. + with mock.patch.object(time, 'localtime', side_effect=time.gmtime): + self.assertEqual( + 'E0509 01:38:00.378885 {} source.py:13] CRITICAL - '.format( + thread_id), + logging.get_absl_log_prefix(self.record)) + time.localtime.assert_called_once_with(self.record.created) + + +class LogCountTest(absltest.TestCase): + + def test_counter_threadsafe(self): + threads_start = threading.Event() + counts = set() + k = object() + + def t(): + threads_start.wait() + counts.add(logging._get_next_log_count_per_token(k)) + + threads = [threading.Thread(target=t) for _ in range(100)] + for thread in threads: + thread.start() + threads_start.set() + for thread in threads: + thread.join() + self.assertEqual(counts, {i for i in range(100)}) + + +class LoggingTest(absltest.TestCase): + + def test_fatal(self): + with mock.patch.object(os, 'abort') as mock_abort: + logging.fatal('Die!') + mock_abort.assert_called_once() + + def test_find_log_dir_with_arg(self): + with mock.patch.object(os, 'access'), \ + mock.patch.object(os.path, 'isdir'): + os.path.isdir.return_value = True + os.access.return_value = True + log_dir = logging.find_log_dir(log_dir='./') + self.assertEqual('./', log_dir) + + @flagsaver.flagsaver(log_dir='./') + def test_find_log_dir_with_flag(self): + with mock.patch.object(os, 'access'), \ + mock.patch.object(os.path, 'isdir'): + os.path.isdir.return_value = True + os.access.return_value = True + log_dir = logging.find_log_dir() + self.assertEqual('./', log_dir) + + @flagsaver.flagsaver(log_dir='') + def test_find_log_dir_with_hda_tmp(self): + with mock.patch.object(os, 'access'), \ + mock.patch.object(os.path, 'exists'), \ + mock.patch.object(os.path, 'isdir'): + os.path.exists.return_value = True + os.path.isdir.return_value = True + os.access.return_value = True + log_dir = logging.find_log_dir() + self.assertEqual('/tmp/', log_dir) + + @flagsaver.flagsaver(log_dir='') + def test_find_log_dir_with_tmp(self): + with mock.patch.object(os, 'access'), \ + mock.patch.object(os.path, 'exists'), \ + mock.patch.object(os.path, 'isdir'): + os.path.exists.return_value = False + os.path.isdir.side_effect = lambda path: path == '/tmp/' + os.access.return_value = True + log_dir = logging.find_log_dir() + self.assertEqual('/tmp/', log_dir) + + def test_find_log_dir_with_nothing(self): + with mock.patch.object(os.path, 'exists'), \ + mock.patch.object(os.path, 'isdir'): + os.path.exists.return_value = False + os.path.isdir.return_value = False + with self.assertRaises(FileNotFoundError): + logging.find_log_dir() + + def test_find_log_dir_and_names_with_args(self): + user = 'test_user' + host = 'test_host' + log_dir = 'here' + program_name = 'prog1' + with mock.patch.object(getpass, 'getuser'), \ + mock.patch.object(logging, 'find_log_dir') as mock_find_log_dir, \ + mock.patch.object(socket, 'gethostname') as mock_gethostname: + getpass.getuser.return_value = user + mock_gethostname.return_value = host + mock_find_log_dir.return_value = log_dir + + prefix = '%s.%s.%s.log' % (program_name, host, user) + self.assertEqual((log_dir, prefix, program_name), + logging.find_log_dir_and_names( + program_name=program_name, log_dir=log_dir)) + + def test_find_log_dir_and_names_without_args(self): + user = 'test_user' + host = 'test_host' + log_dir = 'here' + py_program_name = 'py_prog1' + sys.argv[0] = 'path/to/prog1' + with mock.patch.object(getpass, 'getuser'), \ + mock.patch.object(logging, 'find_log_dir') as mock_find_log_dir, \ + mock.patch.object(socket, 'gethostname'): + getpass.getuser.return_value = user + socket.gethostname.return_value = host + mock_find_log_dir.return_value = log_dir + prefix = '%s.%s.%s.log' % (py_program_name, host, user) + self.assertEqual((log_dir, prefix, py_program_name), + logging.find_log_dir_and_names()) + + def test_find_log_dir_and_names_wo_username(self): + # Windows doesn't have os.getuid at all + if hasattr(os, 'getuid'): + mock_getuid = mock.patch.object(os, 'getuid') + uid = 100 + logged_uid = '100' + else: + # The function doesn't exist, but our test code still tries to mock + # it, so just use a fake thing. + mock_getuid = _mock_windows_os_getuid() + uid = -1 + logged_uid = 'unknown' + + host = 'test_host' + log_dir = 'here' + program_name = 'prog1' + with mock.patch.object(getpass, 'getuser'), \ + mock_getuid as getuid, \ + mock.patch.object(logging, 'find_log_dir') as mock_find_log_dir, \ + mock.patch.object(socket, 'gethostname') as mock_gethostname: + getpass.getuser.side_effect = KeyError() + getuid.return_value = uid + mock_gethostname.return_value = host + mock_find_log_dir.return_value = log_dir + + prefix = '%s.%s.%s.log' % (program_name, host, logged_uid) + self.assertEqual((log_dir, prefix, program_name), + logging.find_log_dir_and_names( + program_name=program_name, log_dir=log_dir)) + + def test_errors_in_logging(self): + with mock.patch.object(sys, 'stderr', new=io.StringIO()) as stderr: + logging.info('not enough args: %s %s', 'foo') # pylint: disable=logging-too-few-args + self.assertIn('Traceback (most recent call last):', stderr.getvalue()) + self.assertIn('TypeError', stderr.getvalue()) + + def test_dict_arg(self): + # Tests that passing a dictionary as a single argument does not crash. + logging.info('%(test)s', {'test': 'Hello world!'}) + + def test_exception_dict_format(self): + # Just verify that this doesn't raise a TypeError. + logging.exception('%(test)s', {'test': 'Hello world!'}) + + def test_logging_levels(self): + old_level = logging.get_verbosity() + + logging.set_verbosity(logging.DEBUG) + self.assertEqual(logging.get_verbosity(), logging.DEBUG) + self.assertTrue(logging.level_debug()) + self.assertTrue(logging.level_info()) + self.assertTrue(logging.level_warning()) + self.assertTrue(logging.level_error()) + + logging.set_verbosity(logging.INFO) + self.assertEqual(logging.get_verbosity(), logging.INFO) + self.assertFalse(logging.level_debug()) + self.assertTrue(logging.level_info()) + self.assertTrue(logging.level_warning()) + self.assertTrue(logging.level_error()) + + logging.set_verbosity(logging.WARNING) + self.assertEqual(logging.get_verbosity(), logging.WARNING) + self.assertFalse(logging.level_debug()) + self.assertFalse(logging.level_info()) + self.assertTrue(logging.level_warning()) + self.assertTrue(logging.level_error()) + + logging.set_verbosity(logging.ERROR) + self.assertEqual(logging.get_verbosity(), logging.ERROR) + self.assertFalse(logging.level_debug()) + self.assertFalse(logging.level_info()) + self.assertTrue(logging.level_error()) + + logging.set_verbosity(old_level) + + def test_set_verbosity_strings(self): + old_level = logging.get_verbosity() + + # Lowercase names. + logging.set_verbosity('debug') + self.assertEqual(logging.get_verbosity(), logging.DEBUG) + logging.set_verbosity('info') + self.assertEqual(logging.get_verbosity(), logging.INFO) + logging.set_verbosity('warning') + self.assertEqual(logging.get_verbosity(), logging.WARNING) + logging.set_verbosity('warn') + self.assertEqual(logging.get_verbosity(), logging.WARNING) + logging.set_verbosity('error') + self.assertEqual(logging.get_verbosity(), logging.ERROR) + logging.set_verbosity('fatal') + + # Uppercase names. + self.assertEqual(logging.get_verbosity(), logging.FATAL) + logging.set_verbosity('DEBUG') + self.assertEqual(logging.get_verbosity(), logging.DEBUG) + logging.set_verbosity('INFO') + self.assertEqual(logging.get_verbosity(), logging.INFO) + logging.set_verbosity('WARNING') + self.assertEqual(logging.get_verbosity(), logging.WARNING) + logging.set_verbosity('WARN') + self.assertEqual(logging.get_verbosity(), logging.WARNING) + logging.set_verbosity('ERROR') + self.assertEqual(logging.get_verbosity(), logging.ERROR) + logging.set_verbosity('FATAL') + self.assertEqual(logging.get_verbosity(), logging.FATAL) + + # Integers as strings. + logging.set_verbosity(str(logging.DEBUG)) + self.assertEqual(logging.get_verbosity(), logging.DEBUG) + logging.set_verbosity(str(logging.INFO)) + self.assertEqual(logging.get_verbosity(), logging.INFO) + logging.set_verbosity(str(logging.WARNING)) + self.assertEqual(logging.get_verbosity(), logging.WARNING) + logging.set_verbosity(str(logging.ERROR)) + self.assertEqual(logging.get_verbosity(), logging.ERROR) + logging.set_verbosity(str(logging.FATAL)) + self.assertEqual(logging.get_verbosity(), logging.FATAL) + + logging.set_verbosity(old_level) + + def test_key_flags(self): + key_flags = FLAGS.get_key_flags_for_module(logging) + key_flag_names = [flag.name for flag in key_flags] + self.assertIn('stderrthreshold', key_flag_names) + self.assertIn('verbosity', key_flag_names) + + def test_get_absl_logger(self): + self.assertIsInstance( + logging.get_absl_logger(), logging.ABSLLogger) + + def test_get_absl_handler(self): + self.assertIsInstance( + logging.get_absl_handler(), logging.ABSLHandler) + + +@mock.patch.object(logging.ABSLLogger, 'register_frame_to_skip') +class LogSkipPrefixTest(absltest.TestCase): + """Tests for logging.skip_log_prefix.""" + + def _log_some_info(self): + """Logging helper function for LogSkipPrefixTest.""" + logging.info('info') + + def _log_nested_outer(self): + """Nested logging helper functions for LogSkipPrefixTest.""" + def _log_nested_inner(): + logging.info('info nested') + return _log_nested_inner + + def test_skip_log_prefix_with_name(self, mock_skip_register): + retval = logging.skip_log_prefix('_log_some_info') + mock_skip_register.assert_called_once_with(__file__, '_log_some_info', None) + self.assertEqual(retval, '_log_some_info') + + def test_skip_log_prefix_with_func(self, mock_skip_register): + retval = logging.skip_log_prefix(self._log_some_info) + mock_skip_register.assert_called_once_with( + __file__, '_log_some_info', mock.ANY) + self.assertEqual(retval, self._log_some_info) + + def test_skip_log_prefix_with_functools_partial(self, mock_skip_register): + partial_input = functools.partial(self._log_some_info) + with self.assertRaises(ValueError): + _ = logging.skip_log_prefix(partial_input) + mock_skip_register.assert_not_called() + + def test_skip_log_prefix_with_lambda(self, mock_skip_register): + lambda_input = lambda _: self._log_some_info() + retval = logging.skip_log_prefix(lambda_input) + mock_skip_register.assert_called_once_with(__file__, '<lambda>', mock.ANY) + self.assertEqual(retval, lambda_input) + + def test_skip_log_prefix_with_bad_input(self, mock_skip_register): + dict_input = {1: 2, 2: 3} + with self.assertRaises(TypeError): + _ = logging.skip_log_prefix(dict_input) + mock_skip_register.assert_not_called() + + def test_skip_log_prefix_with_nested_func(self, mock_skip_register): + nested_input = self._log_nested_outer() + retval = logging.skip_log_prefix(nested_input) + mock_skip_register.assert_called_once_with( + __file__, '_log_nested_inner', mock.ANY) + self.assertEqual(retval, nested_input) + + def test_skip_log_prefix_decorator(self, mock_skip_register): + + @logging.skip_log_prefix + def _log_decorated(): + logging.info('decorated') + + del _log_decorated + mock_skip_register.assert_called_once_with( + __file__, '_log_decorated', mock.ANY) + + +@contextlib.contextmanager +def override_python_handler_stream(stream): + handler = logging.get_absl_handler().python_handler + old_stream = handler.stream + handler.stream = stream + try: + yield + finally: + handler.stream = old_stream + + +class GetLogFileNameTest(parameterized.TestCase): + + @parameterized.named_parameters( + ('err', sys.stderr), + ('out', sys.stdout), + ) + def test_get_log_file_name_py_std(self, stream): + with override_python_handler_stream(stream): + self.assertEqual('', logging.get_log_file_name()) + + def test_get_log_file_name_py_no_name(self): + + class FakeFile(object): + pass + + with override_python_handler_stream(FakeFile()): + self.assertEqual('', logging.get_log_file_name()) + + def test_get_log_file_name_py_file(self): + _, filename = tempfile.mkstemp(dir=absltest.TEST_TMPDIR.value) + with open(filename, 'a') as stream: + with override_python_handler_stream(stream): + self.assertEqual(filename, logging.get_log_file_name()) + + +@contextlib.contextmanager +def _mock_windows_os_getuid(): + yield mock.MagicMock() + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/logging/tests/verbosity_flag_test.py b/absl/logging/tests/verbosity_flag_test.py new file mode 100644 index 0000000..4609e64 --- /dev/null +++ b/absl/logging/tests/verbosity_flag_test.py @@ -0,0 +1,56 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests -v/--verbosity flag and logging.root level's sync behavior.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +assert logging.root.getEffectiveLevel() == logging.WARN, ( + 'default logging.root level should be WARN, but found {}'.format( + logging.root.getEffectiveLevel())) + +# This is here to test importing logging won't change the level. +logging.root.setLevel(logging.ERROR) + +assert logging.root.getEffectiveLevel() == logging.ERROR, ( + 'logging.root level should be changed to ERROR, but found {}'.format( + logging.root.getEffectiveLevel())) + +from absl import flags +from absl import logging as _ # pylint: disable=unused-import +from absl.testing import absltest + +FLAGS = flags.FLAGS + +assert FLAGS['verbosity'].value == -1, ( + '-v/--verbosity should be -1 before flags are parsed.') + +assert logging.root.getEffectiveLevel() == logging.ERROR, ( + 'logging.root level should be kept to ERROR, but found {}'.format( + logging.root.getEffectiveLevel())) + + +class VerbosityFlagTest(absltest.TestCase): + + def test_default_value_after_init(self): + self.assertEqual(0, FLAGS.verbosity) + self.assertEqual(logging.INFO, logging.root.getEffectiveLevel()) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/BUILD b/absl/testing/BUILD new file mode 100644 index 0000000..b608c8c --- /dev/null +++ b/absl/testing/BUILD @@ -0,0 +1,254 @@ +licenses(["notice"]) + +py_library( + name = "absltest", + srcs = ["absltest.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":_pretty_print_reporter", + ":xml_reporter", + "//absl:app", + "//absl/flags", + "//absl/logging", + ], +) + +py_library( + name = "flagsaver", + srcs = ["flagsaver.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//absl/flags", + ], +) + +py_library( + name = "parameterized", + srcs = [ + "parameterized.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":absltest", + ], +) + +py_library( + name = "xml_reporter", + srcs = ["xml_reporter.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":_pretty_print_reporter", + ], +) + +py_library( + name = "_bazelize_command", + testonly = 1, + srcs = ["_bazelize_command.py"], + srcs_version = "PY2AND3", + visibility = ["//:__subpackages__"], + deps = [ + "//absl/flags", + ], +) + +py_library( + name = "_pretty_print_reporter", + srcs = ["_pretty_print_reporter.py"], + srcs_version = "PY2AND3", +) + +py_library( + name = "tests/absltest_env", + testonly = True, + srcs = ["tests/absltest_env.py"], +) + +py_test( + name = "tests/absltest_filtering_test", + size = "medium", + srcs = ["tests/absltest_filtering_test.py"], + data = [":tests/absltest_filtering_test_helper"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_bazelize_command", + ":absltest", + ":parameterized", + ":tests/absltest_env", + "//absl/logging", + ], +) + +py_binary( + name = "tests/absltest_filtering_test_helper", + testonly = 1, + srcs = ["tests/absltest_filtering_test_helper.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":absltest", + ":parameterized", + "//absl:app", + ], +) + +py_test( + name = "tests/absltest_fail_fast_test", + size = "small", + srcs = ["tests/absltest_fail_fast_test.py"], + data = [":tests/absltest_fail_fast_test_helper"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_bazelize_command", + ":absltest", + ":parameterized", + ":tests/absltest_env", + "//absl/logging", + ], +) + +py_binary( + name = "tests/absltest_fail_fast_test_helper", + testonly = 1, + srcs = ["tests/absltest_fail_fast_test_helper.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":absltest", + "//absl:app", + ], +) + +py_test( + name = "tests/absltest_randomization_test", + size = "medium", + srcs = ["tests/absltest_randomization_test.py"], + data = [":tests/absltest_randomization_testcase"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_bazelize_command", + ":absltest", + ":parameterized", + ":tests/absltest_env", + "//absl/flags", + ], +) + +py_binary( + name = "tests/absltest_randomization_testcase", + testonly = 1, + srcs = ["tests/absltest_randomization_testcase.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":absltest", + ], +) + +py_test( + name = "tests/absltest_sharding_test", + size = "small", + srcs = ["tests/absltest_sharding_test.py"], + data = [":tests/absltest_sharding_test_helper"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_bazelize_command", + ":absltest", + ":tests/absltest_env", + ], +) + +py_binary( + name = "tests/absltest_sharding_test_helper", + testonly = 1, + srcs = ["tests/absltest_sharding_test_helper.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":absltest"], +) + +py_test( + name = "tests/absltest_test", + size = "small", + srcs = ["tests/absltest_test.py"], + data = [":tests/absltest_test_helper"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_bazelize_command", + ":absltest", + ":parameterized", + ":tests/absltest_env", + ], +) + +py_binary( + name = "tests/absltest_test_helper", + testonly = 1, + srcs = ["tests/absltest_test_helper.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":absltest", + "//absl/flags", + ], +) + +py_test( + name = "tests/flagsaver_test", + srcs = ["tests/flagsaver_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":absltest", + ":flagsaver", + "//absl/flags", + ], +) + +py_test( + name = "tests/parameterized_test", + srcs = ["tests/parameterized_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":absltest", + ":parameterized", + ], +) + +py_test( + name = "tests/xml_reporter_test", + srcs = ["tests/xml_reporter_test.py"], + data = [":tests/xml_reporter_helper_test"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_bazelize_command", + ":absltest", + ":parameterized", + ":xml_reporter", + "//absl/logging", + ], +) + +py_binary( + name = "tests/xml_reporter_helper_test", + testonly = 1, + srcs = ["tests/xml_reporter_helper_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":absltest", + "//absl/flags", + ], +) diff --git a/absl/testing/__init__.py b/absl/testing/__init__.py new file mode 100644 index 0000000..a3bd1cd --- /dev/null +++ b/absl/testing/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/absl/testing/_bazelize_command.py b/absl/testing/_bazelize_command.py new file mode 100644 index 0000000..fdf6eb6 --- /dev/null +++ b/absl/testing/_bazelize_command.py @@ -0,0 +1,72 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal helper for running tests on Windows Bazel.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import flags + +FLAGS = flags.FLAGS + + +def get_executable_path(py_binary_name): + """Returns the executable path of a py_binary. + + This returns the executable path of a py_binary that is in another Bazel + target's data dependencies. + + On Linux/macOS, the path and __file__ has the same root directory. + On Windows, bazel builds an .exe file and we need to use the MANIFEST file + the location the actual binary. + + Args: + py_binary_name: string, the name of a py_binary that is in another Bazel + target's data dependencies. + + Raises: + RuntimeError: Raised when it cannot locate the executable path. + """ + + if os.name == 'nt': + py_binary_name += '.exe' + manifest_file = os.path.join(FLAGS.test_srcdir, 'MANIFEST') + workspace_name = os.environ['TEST_WORKSPACE'] + manifest_entry = '{}/{}'.format(workspace_name, py_binary_name) + with open(manifest_file, 'r') as manifest_fd: + for line in manifest_fd: + tokens = line.strip().split(' ') + if len(tokens) != 2: + continue + if manifest_entry == tokens[0]: + return tokens[1] + raise RuntimeError( + 'Cannot locate executable path for {}, MANIFEST file: {}.'.format( + py_binary_name, manifest_file)) + else: + # NOTE: __file__ may be .py or .pyc, depending on how the module was + # loaded and executed. + path = __file__ + + # Use the package name to find the root directory: every dot is + # a directory, plus one for ourselves. + for _ in range(__name__.count('.') + 1): + path = os.path.dirname(path) + + root_directory = path + return os.path.join(root_directory, py_binary_name) diff --git a/absl/testing/_pretty_print_reporter.py b/absl/testing/_pretty_print_reporter.py new file mode 100644 index 0000000..ef03934 --- /dev/null +++ b/absl/testing/_pretty_print_reporter.py @@ -0,0 +1,95 @@ +# Copyright 2018 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TestResult implementing default output for test execution status.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + + +class TextTestResult(unittest.TextTestResult): + """TestResult class that provides the default text result formatting.""" + + def __init__(self, stream, descriptions, verbosity): + # Disable the verbose per-test output from the superclass, since it would + # conflict with our customized output. + super(TextTestResult, self).__init__(stream, descriptions, 0) + self._per_test_output = verbosity > 0 + + def _print_status(self, tag, test): + if self._per_test_output: + test_id = test.id() + if test_id.startswith('__main__.'): + test_id = test_id[len('__main__.'):] + print('[%s] %s' % (tag, test_id), file=self.stream) + self.stream.flush() + + def startTest(self, test): + super(TextTestResult, self).startTest(test) + self._print_status(' RUN ', test) + + def addSuccess(self, test): + super(TextTestResult, self).addSuccess(test) + self._print_status(' OK ', test) + + def addError(self, test, err): + super(TextTestResult, self).addError(test, err) + self._print_status(' FAILED ', test) + + def addFailure(self, test, err): + super(TextTestResult, self).addFailure(test, err) + self._print_status(' FAILED ', test) + + def addSkip(self, test, reason): + super(TextTestResult, self).addSkip(test, reason) + self._print_status(' SKIPPED ', test) + + def addExpectedFailure(self, test, err): + super(TextTestResult, self).addExpectedFailure(test, err) + self._print_status(' OK ', test) + + def addUnexpectedSuccess(self, test): + super(TextTestResult, self).addUnexpectedSuccess(test) + self._print_status(' FAILED ', test) + + +class TextTestRunner(unittest.TextTestRunner): + """A test runner that produces formatted text results.""" + + _TEST_RESULT_CLASS = TextTestResult + + # Set this to true at the class or instance level to run tests using a + # debug-friendly method (e.g, one that doesn't catch exceptions and interacts + # better with debuggers). + # Usually this is set using --pdb_post_mortem. + run_for_debugging = False + + def run(self, test): + # type: (TestCase) -> TestResult + if self.run_for_debugging: + return self._run_debug(test) + else: + return super(TextTestRunner, self).run(test) + + def _run_debug(self, test): + # type: (TestCase) -> TestResult + test.debug() + # Return an empty result to indicate success. + return self._makeResult() + + def _makeResult(self): + return TextTestResult(self.stream, self.descriptions, self.verbosity) diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py new file mode 100644 index 0000000..cebb1ca --- /dev/null +++ b/absl/testing/absltest.py @@ -0,0 +1,2554 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base functionality for Abseil Python tests. + +This module contains base classes and high-level functions for Abseil-style +tests. +""" + +from collections import abc +import contextlib +import difflib +import enum +import errno +import getpass +import inspect +import io +import itertools +import json +import os +import random +import re +import shlex +import shutil +import signal +import stat +import subprocess +import sys +import tempfile +import textwrap +import unittest +from unittest import mock # pylint: disable=unused-import Allow absltest.mock. +from urllib import parse + +try: + # The faulthandler module isn't always available, and pytype doesn't + # understand that we're catching ImportError, so suppress the error. + # pytype: disable=import-error + import faulthandler + # pytype: enable=import-error +except ImportError: + # We use faulthandler if it is available. + faulthandler = None + +from absl import app +from absl import flags +from absl import logging +from absl.testing import _pretty_print_reporter +from absl.testing import xml_reporter + +# Make typing an optional import to avoid it being a required dependency +# in Python 2. Type checkers will still understand the imports. +try: + # pylint: disable=unused-import + import typing + from typing import Any, AnyStr, BinaryIO, Callable, ContextManager, IO, Iterator, List, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Text, TextIO, Tuple, Type, Union + # pylint: enable=unused-import +except ImportError: + pass +else: + # Use an if-type-checking block to prevent leakage of type-checking only + # symbols. We don't want people relying on these at runtime. + if typing.TYPE_CHECKING: + # Unbounded TypeVar for general usage + _T = typing.TypeVar('_T') + + import unittest.case + _OutcomeType = unittest.case._Outcome # pytype: disable=module-attr + + + +# Re-export a bunch of unittest functions we support so that people don't +# have to import unittest to get them +# pylint: disable=invalid-name +skip = unittest.skip +skipIf = unittest.skipIf +skipUnless = unittest.skipUnless +SkipTest = unittest.SkipTest +expectedFailure = unittest.expectedFailure +# pylint: enable=invalid-name + +# End unittest re-exports + +FLAGS = flags.FLAGS + +_TEXT_OR_BINARY_TYPES = (str, bytes) + +# Suppress surplus entries in AssertionError stack traces. +__unittest = True # pylint: disable=invalid-name + + +def expectedFailureIf(condition, reason): # pylint: disable=invalid-name + """Expects the test to fail if the run condition is True. + + Example usage: + @expectedFailureIf(sys.version.major == 2, "Not yet working in py2") + def test_foo(self): + ... + + Args: + condition: bool, whether to expect failure or not. + reason: Text, the reason to expect failure. + Returns: + Decorator function + """ + del reason # Unused + if condition: + return unittest.expectedFailure + else: + return lambda f: f + + +class TempFileCleanup(enum.Enum): + # Always cleanup temp files when the test completes. + ALWAYS = 'always' + # Only cleanup temp file if the test passes. This allows easier inspection + # of tempfile contents on test failure. absltest.TEST_TMPDIR.value determines + # where tempfiles are created. + SUCCESS = 'success' + # Never cleanup temp files. + OFF = 'never' + + +# Many of the methods in this module have names like assertSameElements. +# This kind of name does not comply with PEP8 style, +# but it is consistent with the naming of methods in unittest.py. +# pylint: disable=invalid-name + + +def _get_default_test_random_seed(): + # type: () -> int + random_seed = 301 + value = os.environ.get('TEST_RANDOM_SEED', '') + try: + random_seed = int(value) + except ValueError: + pass + return random_seed + + +def get_default_test_srcdir(): + # type: () -> Text + """Returns default test source dir.""" + return os.environ.get('TEST_SRCDIR', '') + + +def get_default_test_tmpdir(): + # type: () -> Text + """Returns default test temp dir.""" + tmpdir = os.environ.get('TEST_TMPDIR', '') + if not tmpdir: + tmpdir = os.path.join(tempfile.gettempdir(), 'absl_testing') + + return tmpdir + + +def _get_default_randomize_ordering_seed(): + # type: () -> int + """Returns default seed to use for randomizing test order. + + This function first checks the --test_randomize_ordering_seed flag, and then + the TEST_RANDOMIZE_ORDERING_SEED environment variable. If the first value + we find is: + * (not set): disable test randomization + * 0: disable test randomization + * 'random': choose a random seed in [1, 4294967295] for test order + randomization + * positive integer: use this seed for test order randomization + + (The values used are patterned after + https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED). + + In principle, it would be simpler to return None if no override is provided; + however, the python random module has no `get_seed()`, only `getstate()`, + which returns far more data than we want to pass via an environment variable + or flag. + + Returns: + A default value for test case randomization (int). 0 means do not randomize. + + Raises: + ValueError: Raised when the flag or env value is not one of the options + above. + """ + if FLAGS['test_randomize_ordering_seed'].present: + randomize = FLAGS.test_randomize_ordering_seed + elif 'TEST_RANDOMIZE_ORDERING_SEED' in os.environ: + randomize = os.environ['TEST_RANDOMIZE_ORDERING_SEED'] + else: + randomize = '' + if not randomize: + return 0 + if randomize == 'random': + return random.Random().randint(1, 4294967295) + if randomize == '0': + return 0 + try: + seed = int(randomize) + if seed > 0: + return seed + except ValueError: + pass + raise ValueError( + 'Unknown test randomization seed value: {}'.format(randomize)) + + +TEST_SRCDIR = flags.DEFINE_string( + 'test_srcdir', + get_default_test_srcdir(), + 'Root of directory tree where source files live', + allow_override_cpp=True) +TEST_TMPDIR = flags.DEFINE_string( + 'test_tmpdir', + get_default_test_tmpdir(), + 'Directory for temporary testing files', + allow_override_cpp=True) + +flags.DEFINE_integer( + 'test_random_seed', + _get_default_test_random_seed(), + 'Random seed for testing. Some test frameworks may ' + 'change the default value of this flag between runs, so ' + 'it is not appropriate for seeding probabilistic tests.', + allow_override_cpp=True) +flags.DEFINE_string( + 'test_randomize_ordering_seed', + '', + 'If positive, use this as a seed to randomize the ' + 'execution order for test cases. If "random", pick a ' + 'random seed to use. If 0 or not set, do not randomize ' + 'test case execution order. This flag also overrides ' + 'the TEST_RANDOMIZE_ORDERING_SEED environment variable.', + allow_override_cpp=True) +flags.DEFINE_string('xml_output_file', '', 'File to store XML test results') + + +# We might need to monkey-patch TestResult so that it stops considering an +# unexpected pass as a as a "successful result". For details, see +# http://bugs.python.org/issue20165 +def _monkey_patch_test_result_for_unexpected_passes(): + # type: () -> None + """Workaround for <http://bugs.python.org/issue20165>.""" + + def wasSuccessful(self): + # type: () -> bool + """Tells whether or not this result was a success. + + Any unexpected pass is to be counted as a non-success. + + Args: + self: The TestResult instance. + + Returns: + Whether or not this result was a success. + """ + return (len(self.failures) == len(self.errors) == + len(self.unexpectedSuccesses) == 0) + + test_result = unittest.TestResult() + test_result.addUnexpectedSuccess(unittest.FunctionTestCase(lambda: None)) + if test_result.wasSuccessful(): # The bug is present. + unittest.TestResult.wasSuccessful = wasSuccessful + if test_result.wasSuccessful(): # Warn the user if our hot-fix failed. + sys.stderr.write('unittest.result.TestResult monkey patch to report' + ' unexpected passes as failures did not work.\n') + + +_monkey_patch_test_result_for_unexpected_passes() + + +def _open(filepath, mode, _open_func=open): + # type: (Text, Text, Callable[..., IO]) -> IO + """Opens a file. + + Like open(), but ensure that we can open real files even if tests stub out + open(). + + Args: + filepath: A filepath. + mode: A mode. + _open_func: A built-in open() function. + + Returns: + The opened file object. + """ + return _open_func(filepath, mode, encoding='utf-8') + + +class _TempDir(object): + """Represents a temporary directory for tests. + + Creation of this class is internal. Using its public methods is OK. + + This class implements the `os.PathLike` interface (specifically, + `os.PathLike[str]`). This means, in Python 3, it can be directly passed + to e.g. `os.path.join()`. + """ + + def __init__(self, path): + # type: (Text) -> None + """Module-private: do not instantiate outside module.""" + self._path = path + + @property + def full_path(self): + # type: () -> Text + """Returns the path, as a string, for the directory. + + TIP: Instead of e.g. `os.path.join(temp_dir.full_path)`, you can simply + do `os.path.join(temp_dir)` because `__fspath__()` is implemented. + """ + return self._path + + def __fspath__(self): + # type: () -> Text + """See os.PathLike.""" + return self.full_path + + def create_file(self, file_path=None, content=None, mode='w', encoding='utf8', + errors='strict'): + # type: (Optional[Text], Optional[AnyStr], Text, Text, Text) -> _TempFile + """Create a file in the directory. + + NOTE: If the file already exists, it will be made writable and overwritten. + + Args: + file_path: Optional file path for the temp file. If not given, a unique + file name will be generated and used. Slashes are allowed in the name; + any missing intermediate directories will be created. NOTE: This path + is the path that will be cleaned up, including any directories in the + path, e.g., 'foo/bar/baz.txt' will `rm -r foo` + content: Optional string or bytes to initially write to the file. If not + specified, then an empty file is created. + mode: Mode string to use when writing content. Only used if `content` is + non-empty. + encoding: Encoding to use when writing string content. Only used if + `content` is text. + errors: How to handle text to bytes encoding errors. Only used if + `content` is text. + + Returns: + A _TempFile representing the created file. + """ + tf, _ = _TempFile._create(self._path, file_path, content, mode, encoding, + errors) + return tf + + def mkdir(self, dir_path=None): + # type: (Optional[Text]) -> _TempDir + """Create a directory in the directory. + + Args: + dir_path: Optional path to the directory to create. If not given, + a unique name will be generated and used. + + Returns: + A _TempDir representing the created directory. + """ + if dir_path: + path = os.path.join(self._path, dir_path) + else: + path = tempfile.mkdtemp(dir=self._path) + + # Note: there's no need to clear the directory since the containing + # dir was cleared by the tempdir() function. + os.makedirs(path, exist_ok=True) + return _TempDir(path) + + +class _TempFile(object): + """Represents a tempfile for tests. + + Creation of this class is internal. Using its public methods is OK. + + This class implements the `os.PathLike` interface (specifically, + `os.PathLike[str]`). This means, in Python 3, it can be directly passed + to e.g. `os.path.join()`. + """ + + def __init__(self, path): + # type: (Text) -> None + """Private: use _create instead.""" + self._path = path + + # pylint: disable=line-too-long + @classmethod + def _create(cls, base_path, file_path, content, mode, encoding, errors): + # type: (Text, Optional[Text], AnyStr, Text, Text, Text) -> Tuple[_TempFile, Text] + # pylint: enable=line-too-long + """Module-private: create a tempfile instance.""" + if file_path: + cleanup_path = os.path.join(base_path, _get_first_part(file_path)) + path = os.path.join(base_path, file_path) + os.makedirs(os.path.dirname(path), exist_ok=True) + # The file may already exist, in which case, ensure it's writable so that + # it can be truncated. + if os.path.exists(path) and not os.access(path, os.W_OK): + stat_info = os.stat(path) + os.chmod(path, stat_info.st_mode | stat.S_IWUSR) + else: + os.makedirs(base_path, exist_ok=True) + fd, path = tempfile.mkstemp(dir=str(base_path)) + os.close(fd) + cleanup_path = path + + tf = cls(path) + + if content: + if isinstance(content, str): + tf.write_text(content, mode=mode, encoding=encoding, errors=errors) + else: + tf.write_bytes(content, mode) + + else: + tf.write_bytes(b'') + + return tf, cleanup_path + + @property + def full_path(self): + # type: () -> Text + """Returns the path, as a string, for the file. + + TIP: Instead of e.g. `os.path.join(temp_file.full_path)`, you can simply + do `os.path.join(temp_file)` because `__fspath__()` is implemented. + """ + return self._path + + def __fspath__(self): + # type: () -> Text + """See os.PathLike.""" + return self.full_path + + def read_text(self, encoding='utf8', errors='strict'): + # type: (Text, Text) -> Text + """Return the contents of the file as text.""" + with self.open_text(encoding=encoding, errors=errors) as fp: + return fp.read() + + def read_bytes(self): + # type: () -> bytes + """Return the content of the file as bytes.""" + with self.open_bytes() as fp: + return fp.read() + + def write_text(self, text, mode='w', encoding='utf8', errors='strict'): + # type: (Text, Text, Text, Text) -> None + """Write text to the file. + + Args: + text: Text to write. In Python 2, it can be bytes, which will be + decoded using the `encoding` arg (this is as an aid for code that + is 2 and 3 compatible). + mode: The mode to open the file for writing. + encoding: The encoding to use when writing the text to the file. + errors: The error handling strategy to use when converting text to bytes. + """ + with self.open_text(mode, encoding=encoding, errors=errors) as fp: + fp.write(text) + + def write_bytes(self, data, mode='wb'): + # type: (bytes, Text) -> None + """Write bytes to the file. + + Args: + data: bytes to write. + mode: Mode to open the file for writing. The "b" flag is implicit if + not already present. It must not have the "t" flag. + """ + with self.open_bytes(mode) as fp: + fp.write(data) + + def open_text(self, mode='rt', encoding='utf8', errors='strict'): + # type: (Text, Text, Text) -> ContextManager[TextIO] + """Return a context manager for opening the file in text mode. + + Args: + mode: The mode to open the file in. The "t" flag is implicit if not + already present. It must not have the "b" flag. + encoding: The encoding to use when opening the file. + errors: How to handle decoding errors. + + Returns: + Context manager that yields an open file. + + Raises: + ValueError: if invalid inputs are provided. + """ + if 'b' in mode: + raise ValueError('Invalid mode {!r}: "b" flag not allowed when opening ' + 'file in text mode'.format(mode)) + if 't' not in mode: + mode += 't' + cm = self._open(mode, encoding, errors) + return cm + + def open_bytes(self, mode='rb'): + # type: (Text) -> ContextManager[BinaryIO] + """Return a context manager for opening the file in binary mode. + + Args: + mode: The mode to open the file in. The "b" mode is implicit if not + already present. It must not have the "t" flag. + + Returns: + Context manager that yields an open file. + + Raises: + ValueError: if invalid inputs are provided. + """ + if 't' in mode: + raise ValueError('Invalid mode {!r}: "t" flag not allowed when opening ' + 'file in binary mode'.format(mode)) + if 'b' not in mode: + mode += 'b' + cm = self._open(mode, encoding=None, errors=None) + return cm + + # TODO(b/123775699): Once pytype supports typing.Literal, use overload and + # Literal to express more precise return types. The contained type is + # currently `Any` to avoid [bad-return-type] errors in the open_* methods. + @contextlib.contextmanager + def _open(self, mode, encoding='utf8', errors='strict'): + # type: (Text, Text, Text) -> Iterator[Any] + with io.open( + self.full_path, mode=mode, encoding=encoding, errors=errors) as fp: + yield fp + + +class _method(object): + """A decorator that supports both instance and classmethod invocations. + + Using similar semantics to the @property builtin, this decorator can augment + an instance method to support conditional logic when invoked on a class + object. This breaks support for invoking an instance method via the class + (e.g. Cls.method(self, ...)) but is still situationally useful. + """ + + def __init__(self, finstancemethod): + # type: (Callable[..., Any]) -> None + self._finstancemethod = finstancemethod + self._fclassmethod = None + + def classmethod(self, fclassmethod): + # type: (Callable[..., Any]) -> _method + self._fclassmethod = classmethod(fclassmethod) + return self + + def __doc__(self): + # type: () -> str + if getattr(self._finstancemethod, '__doc__'): + return self._finstancemethod.__doc__ + elif getattr(self._fclassmethod, '__doc__'): + return self._fclassmethod.__doc__ + return '' + + def __get__(self, obj, type_): + # type: (Optional[Any], Optional[Type[Any]]) -> Callable[..., Any] + func = self._fclassmethod if obj is None else self._finstancemethod + return func.__get__(obj, type_) # pytype: disable=attribute-error + + +class TestCase(unittest.TestCase): + """Extension of unittest.TestCase providing more power.""" + + # When to cleanup files/directories created by our `create_tempfile()` and + # `create_tempdir()` methods after each test case completes. This does *not* + # affect e.g., files created outside of those methods, e.g., using the stdlib + # tempfile module. This can be overridden at the class level, instance level, + # or with the `cleanup` arg of `create_tempfile()` and `create_tempdir()`. See + # `TempFileCleanup` for details on the different values. + # TODO(b/70517332): Remove the type comment and the disable once pytype has + # better support for enums. + tempfile_cleanup = TempFileCleanup.ALWAYS # type: TempFileCleanup # pytype: disable=annotation-type-mismatch + + maxDiff = 80 * 20 + longMessage = True + + # Exit stacks for per-test and per-class scopes. + _exit_stack = None + _cls_exit_stack = None + + def __init__(self, *args, **kwargs): + super(TestCase, self).__init__(*args, **kwargs) + # This is to work around missing type stubs in unittest.pyi + self._outcome = getattr(self, '_outcome') # type: Optional[_OutcomeType] + + def setUp(self): + super(TestCase, self).setUp() + # NOTE: Only Python 3 contextlib has ExitStack + if hasattr(contextlib, 'ExitStack'): + self._exit_stack = contextlib.ExitStack() + self.addCleanup(self._exit_stack.close) + + @classmethod + def setUpClass(cls): + super(TestCase, cls).setUpClass() + # NOTE: Only Python 3 contextlib has ExitStack and only Python 3.8+ has + # addClassCleanup. + if hasattr(contextlib, 'ExitStack') and hasattr(cls, 'addClassCleanup'): + cls._cls_exit_stack = contextlib.ExitStack() + cls.addClassCleanup(cls._cls_exit_stack.close) + + def create_tempdir(self, name=None, cleanup=None): + # type: (Optional[Text], Optional[TempFileCleanup]) -> _TempDir + """Create a temporary directory specific to the test. + + NOTE: The directory and its contents will be recursively cleared before + creation. This ensures that there is no pre-existing state. + + This creates a named directory on disk that is isolated to this test, and + will be properly cleaned up by the test. This avoids several pitfalls of + creating temporary directories for test purposes, as well as makes it easier + to setup directories and verify their contents. For example: + + def test_foo(self): + out_dir = self.create_tempdir() + out_log = out_dir.create_file('output.log') + expected_outputs = [ + os.path.join(out_dir, 'data-0.txt'), + os.path.join(out_dir, 'data-1.txt'), + ] + code_under_test(out_dir) + self.assertTrue(os.path.exists(expected_paths[0])) + self.assertTrue(os.path.exists(expected_paths[1])) + self.assertEqual('foo', out_log.read_text()) + + See also: `create_tempfile()` for creating temporary files. + + Args: + name: Optional name of the directory. If not given, a unique + name will be generated and used. + cleanup: Optional cleanup policy on when/if to remove the directory (and + all its contents) at the end of the test. If None, then uses + `self.tempfile_cleanup`. + + Returns: + A _TempDir representing the created directory; see _TempDir class docs + for usage. + """ + test_path = self._get_tempdir_path_test() + + if name: + path = os.path.join(test_path, name) + cleanup_path = os.path.join(test_path, _get_first_part(name)) + else: + os.makedirs(test_path, exist_ok=True) + path = tempfile.mkdtemp(dir=test_path) + cleanup_path = path + + _rmtree_ignore_errors(cleanup_path) + os.makedirs(path, exist_ok=True) + + self._maybe_add_temp_path_cleanup(cleanup_path, cleanup) + + return _TempDir(path) + + # pylint: disable=line-too-long + def create_tempfile(self, file_path=None, content=None, mode='w', + encoding='utf8', errors='strict', cleanup=None): + # type: (Optional[Text], Optional[AnyStr], Text, Text, Text, Optional[TempFileCleanup]) -> _TempFile + # pylint: enable=line-too-long + """Create a temporary file specific to the test. + + This creates a named file on disk that is isolated to this test, and will + be properly cleaned up by the test. This avoids several pitfalls of + creating temporary files for test purposes, as well as makes it easier + to setup files, their data, read them back, and inspect them when + a test fails. For example: + + def test_foo(self): + output = self.create_tempfile() + code_under_test(output) + self.assertGreater(os.path.getsize(output), 0) + self.assertEqual('foo', output.read_text()) + + NOTE: This will zero-out the file. This ensures there is no pre-existing + state. + NOTE: If the file already exists, it will be made writable and overwritten. + + See also: `create_tempdir()` for creating temporary directories, and + `_TempDir.create_file` for creating files within a temporary directory. + + Args: + file_path: Optional file path for the temp file. If not given, a unique + file name will be generated and used. Slashes are allowed in the name; + any missing intermediate directories will be created. NOTE: This path is + the path that will be cleaned up, including any directories in the path, + e.g., 'foo/bar/baz.txt' will `rm -r foo`. + content: Optional string or + bytes to initially write to the file. If not + specified, then an empty file is created. + mode: Mode string to use when writing content. Only used if `content` is + non-empty. + encoding: Encoding to use when writing string content. Only used if + `content` is text. + errors: How to handle text to bytes encoding errors. Only used if + `content` is text. + cleanup: Optional cleanup policy on when/if to remove the directory (and + all its contents) at the end of the test. If None, then uses + `self.tempfile_cleanup`. + + Returns: + A _TempFile representing the created file; see _TempFile class docs for + usage. + """ + test_path = self._get_tempdir_path_test() + tf, cleanup_path = _TempFile._create(test_path, file_path, content=content, + mode=mode, encoding=encoding, + errors=errors) + self._maybe_add_temp_path_cleanup(cleanup_path, cleanup) + return tf + + @_method + def enter_context(self, manager): + # type: (ContextManager[_T]) -> _T + """Returns the CM's value after registering it with the exit stack. + + Entering a context pushes it onto a stack of contexts. When `enter_context` + is called on the test instance (e.g. `self.enter_context`), the context is + exited after the test case's tearDown call. When called on the test class + (e.g. `TestCase.enter_context`), the context is exited after the test + class's tearDownClass call. + + Contexts are are exited in the reverse order of entering. They will always + be exited, regardless of test failure/success. + + This is useful to eliminate per-test boilerplate when context managers + are used. For example, instead of decorating every test with `@mock.patch`, + simply do `self.foo = self.enter_context(mock.patch(...))' in `setUp()`. + + NOTE: The context managers will always be exited without any error + information. This is an unfortunate implementation detail due to some + internals of how unittest runs tests. + + Args: + manager: The context manager to enter. + """ + if not self._exit_stack: + raise AssertionError( + 'self._exit_stack is not set: enter_context is Py3-only; also make ' + 'sure that AbslTest.setUp() is called.') + return self._exit_stack.enter_context(manager) + + @enter_context.classmethod + def enter_context(cls, manager): # pylint: disable=no-self-argument + # type: (ContextManager[_T]) -> _T + if not cls._cls_exit_stack: + raise AssertionError( + 'cls._cls_exit_stack is not set: cls.enter_context requires ' + 'Python 3.8+; also make sure that AbslTest.setUpClass() is called.') + return cls._cls_exit_stack.enter_context(manager) + + @classmethod + def _get_tempdir_path_cls(cls): + # type: () -> Text + return os.path.join(TEST_TMPDIR.value, + cls.__qualname__.replace('__main__.', '')) + + def _get_tempdir_path_test(self): + # type: () -> Text + return os.path.join(self._get_tempdir_path_cls(), self._testMethodName) + + def _get_tempfile_cleanup(self, override): + # type: (Optional[TempFileCleanup]) -> TempFileCleanup + if override is not None: + return override + return self.tempfile_cleanup + + def _maybe_add_temp_path_cleanup(self, path, cleanup): + # type: (Text, Optional[TempFileCleanup]) -> None + cleanup = self._get_tempfile_cleanup(cleanup) + if cleanup == TempFileCleanup.OFF: + return + elif cleanup == TempFileCleanup.ALWAYS: + self.addCleanup(_rmtree_ignore_errors, path) + elif cleanup == TempFileCleanup.SUCCESS: + self._internal_cleanup_on_success(_rmtree_ignore_errors, path) + else: + raise AssertionError('Unexpected cleanup value: {}'.format(cleanup)) + + def _internal_cleanup_on_success(self, function, *args, **kwargs): + # type: (Callable[..., object], Any, Any) -> None + def _call_cleaner_on_success(*args, **kwargs): + if not self._ran_and_passed(): + return + function(*args, **kwargs) + self.addCleanup(_call_cleaner_on_success, *args, **kwargs) + + def _ran_and_passed(self): + # type: () -> bool + outcome = self._outcome + result = self.defaultTestResult() + self._feedErrorsToResult(result, outcome.errors) # pytype: disable=attribute-error + return result.wasSuccessful() + + def shortDescription(self): + # type: () -> Text + """Formats both the test method name and the first line of its docstring. + + If no docstring is given, only returns the method name. + + This method overrides unittest.TestCase.shortDescription(), which + only returns the first line of the docstring, obscuring the name + of the test upon failure. + + Returns: + desc: A short description of a test method. + """ + desc = self.id() + + # Omit the main name so that test name can be directly copy/pasted to + # the command line. + if desc.startswith('__main__.'): + desc = desc[len('__main__.'):] + + # NOTE: super() is used here instead of directly invoking + # unittest.TestCase.shortDescription(self), because of the + # following line that occurs later on: + # unittest.TestCase = TestCase + # Because of this, direct invocation of what we think is the + # superclass will actually cause infinite recursion. + doc_first_line = super(TestCase, self).shortDescription() + if doc_first_line is not None: + desc = '\n'.join((desc, doc_first_line)) + return desc + + def assertStartsWith(self, actual, expected_start, msg=None): + """Asserts that actual.startswith(expected_start) is True. + + Args: + actual: str + expected_start: str + msg: Optional message to report on failure. + """ + if not actual.startswith(expected_start): + self.fail('%r does not start with %r' % (actual, expected_start), msg) + + def assertNotStartsWith(self, actual, unexpected_start, msg=None): + """Asserts that actual.startswith(unexpected_start) is False. + + Args: + actual: str + unexpected_start: str + msg: Optional message to report on failure. + """ + if actual.startswith(unexpected_start): + self.fail('%r does start with %r' % (actual, unexpected_start), msg) + + def assertEndsWith(self, actual, expected_end, msg=None): + """Asserts that actual.endswith(expected_end) is True. + + Args: + actual: str + expected_end: str + msg: Optional message to report on failure. + """ + if not actual.endswith(expected_end): + self.fail('%r does not end with %r' % (actual, expected_end), msg) + + def assertNotEndsWith(self, actual, unexpected_end, msg=None): + """Asserts that actual.endswith(unexpected_end) is False. + + Args: + actual: str + unexpected_end: str + msg: Optional message to report on failure. + """ + if actual.endswith(unexpected_end): + self.fail('%r does end with %r' % (actual, unexpected_end), msg) + + def assertSequenceStartsWith(self, prefix, whole, msg=None): + """An equality assertion for the beginning of ordered sequences. + + If prefix is an empty sequence, it will raise an error unless whole is also + an empty sequence. + + If prefix is not a sequence, it will raise an error if the first element of + whole does not match. + + Args: + prefix: A sequence expected at the beginning of the whole parameter. + whole: The sequence in which to look for prefix. + msg: Optional message to report on failure. + """ + try: + prefix_len = len(prefix) + except (TypeError, NotImplementedError): + prefix = [prefix] + prefix_len = 1 + + try: + whole_len = len(whole) + except (TypeError, NotImplementedError): + self.fail('For whole: len(%s) is not supported, it appears to be type: ' + '%s' % (whole, type(whole)), msg) + + assert prefix_len <= whole_len, self._formatMessage( + msg, + 'Prefix length (%d) is longer than whole length (%d).' % + (prefix_len, whole_len) + ) + + if not prefix_len and whole_len: + self.fail('Prefix length is 0 but whole length is %d: %s' % + (len(whole), whole), msg) + + try: + self.assertSequenceEqual(prefix, whole[:prefix_len], msg) + except AssertionError: + self.fail('prefix: %s not found at start of whole: %s.' % + (prefix, whole), msg) + + def assertEmpty(self, container, msg=None): + """Asserts that an object has zero length. + + Args: + container: Anything that implements the collections.abc.Sized interface. + msg: Optional message to report on failure. + """ + if not isinstance(container, abc.Sized): + self.fail('Expected a Sized object, got: ' + '{!r}'.format(type(container).__name__), msg) + + # explicitly check the length since some Sized objects (e.g. numpy.ndarray) + # have strange __nonzero__/__bool__ behavior. + if len(container): # pylint: disable=g-explicit-length-test + self.fail('{!r} has length of {}.'.format(container, len(container)), msg) + + def assertNotEmpty(self, container, msg=None): + """Asserts that an object has non-zero length. + + Args: + container: Anything that implements the collections.abc.Sized interface. + msg: Optional message to report on failure. + """ + if not isinstance(container, abc.Sized): + self.fail('Expected a Sized object, got: ' + '{!r}'.format(type(container).__name__), msg) + + # explicitly check the length since some Sized objects (e.g. numpy.ndarray) + # have strange __nonzero__/__bool__ behavior. + if not len(container): # pylint: disable=g-explicit-length-test + self.fail('{!r} has length of 0.'.format(container), msg) + + def assertLen(self, container, expected_len, msg=None): + """Asserts that an object has the expected length. + + Args: + container: Anything that implements the collections.abc.Sized interface. + expected_len: The expected length of the container. + msg: Optional message to report on failure. + """ + if not isinstance(container, abc.Sized): + self.fail('Expected a Sized object, got: ' + '{!r}'.format(type(container).__name__), msg) + if len(container) != expected_len: + container_repr = unittest.util.safe_repr(container) # pytype: disable=module-attr + self.fail('{} has length of {}, expected {}.'.format( + container_repr, len(container), expected_len), msg) + + def assertSequenceAlmostEqual(self, expected_seq, actual_seq, places=None, + msg=None, delta=None): + """An approximate equality assertion for ordered sequences. + + Fail if the two sequences are unequal as determined by their value + differences rounded to the given number of decimal places (default 7) and + comparing to zero, or by comparing that the difference between each value + in the two sequences is more than the given delta. + + Note that decimal places (from zero) are usually not the same as significant + digits (measured from the most significant digit). + + If the two sequences compare equal then they will automatically compare + almost equal. + + Args: + expected_seq: A sequence containing elements we are expecting. + actual_seq: The sequence that we are testing. + places: The number of decimal places to compare. + msg: The message to be printed if the test fails. + delta: The OK difference between compared values. + """ + if len(expected_seq) != len(actual_seq): + self.fail('Sequence size mismatch: {} vs {}'.format( + len(expected_seq), len(actual_seq)), msg) + + err_list = [] + for idx, (exp_elem, act_elem) in enumerate(zip(expected_seq, actual_seq)): + try: + # assertAlmostEqual should be called with at most one of `places` and + # `delta`. However, it's okay for assertSequenceAlmostEqual to pass + # both because we want the latter to fail if the former does. + # pytype: disable=wrong-keyword-args + self.assertAlmostEqual(exp_elem, act_elem, places=places, msg=msg, + delta=delta) + # pytype: enable=wrong-keyword-args + except self.failureException as err: + err_list.append('At index {}: {}'.format(idx, err)) + + if err_list: + if len(err_list) > 30: + err_list = err_list[:30] + ['...'] + msg = self._formatMessage(msg, '\n'.join(err_list)) + self.fail(msg) + + def assertContainsSubset(self, expected_subset, actual_set, msg=None): + """Checks whether actual iterable is a superset of expected iterable.""" + missing = set(expected_subset) - set(actual_set) + if not missing: + return + + self.fail('Missing elements %s\nExpected: %s\nActual: %s' % ( + missing, expected_subset, actual_set), msg) + + def assertNoCommonElements(self, expected_seq, actual_seq, msg=None): + """Checks whether actual iterable and expected iterable are disjoint.""" + common = set(expected_seq) & set(actual_seq) + if not common: + return + + self.fail('Common elements %s\nExpected: %s\nActual: %s' % ( + common, expected_seq, actual_seq), msg) + + def assertItemsEqual(self, expected_seq, actual_seq, msg=None): + """Deprecated, please use assertCountEqual instead. + + This is equivalent to assertCountEqual. + + Args: + expected_seq: A sequence containing elements we are expecting. + actual_seq: The sequence that we are testing. + msg: The message to be printed if the test fails. + """ + super().assertCountEqual(expected_seq, actual_seq, msg) + + def assertSameElements(self, expected_seq, actual_seq, msg=None): + """Asserts that two sequences have the same elements (in any order). + + This method, unlike assertCountEqual, doesn't care about any + duplicates in the expected and actual sequences. + + >> assertSameElements([1, 1, 1, 0, 0, 0], [0, 1]) + # Doesn't raise an AssertionError + + If possible, you should use assertCountEqual instead of + assertSameElements. + + Args: + expected_seq: A sequence containing elements we are expecting. + actual_seq: The sequence that we are testing. + msg: The message to be printed if the test fails. + """ + # `unittest2.TestCase` used to have assertSameElements, but it was + # removed in favor of assertItemsEqual. As there's a unit test + # that explicitly checks this behavior, I am leaving this method + # alone. + # Fail on strings: empirically, passing strings to this test method + # is almost always a bug. If comparing the character sets of two strings + # is desired, cast the inputs to sets or lists explicitly. + if (isinstance(expected_seq, _TEXT_OR_BINARY_TYPES) or + isinstance(actual_seq, _TEXT_OR_BINARY_TYPES)): + self.fail('Passing string/bytes to assertSameElements is usually a bug. ' + 'Did you mean to use assertEqual?\n' + 'Expected: %s\nActual: %s' % (expected_seq, actual_seq)) + try: + expected = dict([(element, None) for element in expected_seq]) + actual = dict([(element, None) for element in actual_seq]) + missing = [element for element in expected if element not in actual] + unexpected = [element for element in actual if element not in expected] + missing.sort() + unexpected.sort() + except TypeError: + # Fall back to slower list-compare if any of the objects are + # not hashable. + expected = list(expected_seq) + actual = list(actual_seq) + expected.sort() + actual.sort() + missing, unexpected = _sorted_list_difference(expected, actual) + errors = [] + if msg: + errors.extend((msg, ':\n')) + if missing: + errors.append('Expected, but missing:\n %r\n' % missing) + if unexpected: + errors.append('Unexpected, but present:\n %r\n' % unexpected) + if missing or unexpected: + self.fail(''.join(errors)) + + # unittest.TestCase.assertMultiLineEqual works very similarly, but it + # has a different error format. However, I find this slightly more readable. + def assertMultiLineEqual(self, first, second, msg=None, **kwargs): + """Asserts that two multi-line strings are equal.""" + assert isinstance(first, + str), ('First argument is not a string: %r' % (first,)) + assert isinstance(second, + str), ('Second argument is not a string: %r' % (second,)) + line_limit = kwargs.pop('line_limit', 0) + if kwargs: + raise TypeError('Unexpected keyword args {}'.format(tuple(kwargs))) + + if first == second: + return + if msg: + failure_message = [msg + ':\n'] + else: + failure_message = ['\n'] + if line_limit: + line_limit += len(failure_message) + for line in difflib.ndiff(first.splitlines(True), second.splitlines(True)): + failure_message.append(line) + if not line.endswith('\n'): + failure_message.append('\n') + if line_limit and len(failure_message) > line_limit: + n_omitted = len(failure_message) - line_limit + failure_message = failure_message[:line_limit] + failure_message.append( + '(... and {} more delta lines omitted for brevity.)\n'.format( + n_omitted)) + + raise self.failureException(''.join(failure_message)) + + def assertBetween(self, value, minv, maxv, msg=None): + """Asserts that value is between minv and maxv (inclusive).""" + msg = self._formatMessage(msg, + '"%r" unexpectedly not between "%r" and "%r"' % + (value, minv, maxv)) + self.assertTrue(minv <= value, msg) + self.assertTrue(maxv >= value, msg) + + def assertRegexMatch(self, actual_str, regexes, message=None): + r"""Asserts that at least one regex in regexes matches str. + + If possible you should use `assertRegex`, which is a simpler + version of this method. `assertRegex` takes a single regular + expression (a string or re compiled object) instead of a list. + + Notes: + 1. This function uses substring matching, i.e. the matching + succeeds if *any* substring of the error message matches *any* + regex in the list. This is more convenient for the user than + full-string matching. + + 2. If regexes is the empty list, the matching will always fail. + + 3. Use regexes=[''] for a regex that will always pass. + + 4. '.' matches any single character *except* the newline. To + match any character, use '(.|\n)'. + + 5. '^' matches the beginning of each line, not just the beginning + of the string. Similarly, '$' matches the end of each line. + + 6. An exception will be thrown if regexes contains an invalid + regex. + + Args: + actual_str: The string we try to match with the items in regexes. + regexes: The regular expressions we want to match against str. + See "Notes" above for detailed notes on how this is interpreted. + message: The message to be printed if the test fails. + """ + if isinstance(regexes, _TEXT_OR_BINARY_TYPES): + self.fail('regexes is string or bytes; use assertRegex instead.', + message) + if not regexes: + self.fail('No regexes specified.', message) + + regex_type = type(regexes[0]) + for regex in regexes[1:]: + if type(regex) is not regex_type: # pylint: disable=unidiomatic-typecheck + self.fail('regexes list must all be the same type.', message) + + if regex_type is bytes and isinstance(actual_str, str): + regexes = [regex.decode('utf-8') for regex in regexes] + regex_type = str + elif regex_type is str and isinstance(actual_str, bytes): + regexes = [regex.encode('utf-8') for regex in regexes] + regex_type = bytes + + if regex_type is str: + regex = u'(?:%s)' % u')|(?:'.join(regexes) + elif regex_type is bytes: + regex = b'(?:' + (b')|(?:'.join(regexes)) + b')' + else: + self.fail('Only know how to deal with unicode str or bytes regexes.', + message) + + if not re.search(regex, actual_str, re.MULTILINE): + self.fail('"%s" does not contain any of these regexes: %s.' % + (actual_str, regexes), message) + + def assertCommandSucceeds(self, command, regexes=(b'',), env=None, + close_fds=True, msg=None): + """Asserts that a shell command succeeds (i.e. exits with code 0). + + Args: + command: List or string representing the command to run. + regexes: List of regular expression byte strings that match success. + env: Dictionary of environment variable settings. If None, no environment + variables will be set for the child process. This is to make tests + more hermetic. NOTE: this behavior is different than the standard + subprocess module. + close_fds: Whether or not to close all open fd's in the child after + forking. + msg: Optional message to report on failure. + """ + (ret_code, err) = get_command_stderr(command, env, close_fds) + + # We need bytes regexes here because `err` is bytes. + # Accommodate code which listed their output regexes w/o the b'' prefix by + # converting them to bytes for the user. + if isinstance(regexes[0], str): + regexes = [regex.encode('utf-8') for regex in regexes] + + command_string = get_command_string(command) + self.assertEqual( + ret_code, 0, + self._formatMessage(msg, + 'Running command\n' + '%s failed with error code %s and message\n' + '%s' % (_quote_long_string(command_string), + ret_code, + _quote_long_string(err))) + ) + self.assertRegexMatch( + err, + regexes, + message=self._formatMessage( + msg, + 'Running command\n' + '%s failed with error code %s and message\n' + '%s which matches no regex in %s' % ( + _quote_long_string(command_string), + ret_code, + _quote_long_string(err), + regexes))) + + def assertCommandFails(self, command, regexes, env=None, close_fds=True, + msg=None): + """Asserts a shell command fails and the error matches a regex in a list. + + Args: + command: List or string representing the command to run. + regexes: the list of regular expression strings. + env: Dictionary of environment variable settings. If None, no environment + variables will be set for the child process. This is to make tests + more hermetic. NOTE: this behavior is different than the standard + subprocess module. + close_fds: Whether or not to close all open fd's in the child after + forking. + msg: Optional message to report on failure. + """ + (ret_code, err) = get_command_stderr(command, env, close_fds) + + # We need bytes regexes here because `err` is bytes. + # Accommodate code which listed their output regexes w/o the b'' prefix by + # converting them to bytes for the user. + if isinstance(regexes[0], str): + regexes = [regex.encode('utf-8') for regex in regexes] + + command_string = get_command_string(command) + self.assertNotEqual( + ret_code, 0, + self._formatMessage(msg, 'The following command succeeded ' + 'while expected to fail:\n%s' % + _quote_long_string(command_string))) + self.assertRegexMatch( + err, + regexes, + message=self._formatMessage( + msg, + 'Running command\n' + '%s failed with error code %s and message\n' + '%s which matches no regex in %s' % ( + _quote_long_string(command_string), + ret_code, + _quote_long_string(err), + regexes))) + + class _AssertRaisesContext(object): + + def __init__(self, expected_exception, test_case, test_func, msg=None): + self.expected_exception = expected_exception + self.test_case = test_case + self.test_func = test_func + self.msg = msg + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + if exc_type is None: + self.test_case.fail(self.expected_exception.__name__ + ' not raised', + self.msg) + if not issubclass(exc_type, self.expected_exception): + return False + self.test_func(exc_value) + return True + + @typing.overload + def assertRaisesWithPredicateMatch( + self, expected_exception, predicate) -> _AssertRaisesContext: + # The purpose of this return statement is to work around + # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored. + return self._AssertRaisesContext(None, None, None) + + @typing.overload + def assertRaisesWithPredicateMatch( + self, expected_exception, predicate, callable_obj: Callable[..., Any], + *args, **kwargs) -> None: + # The purpose of this return statement is to work around + # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored. + return self._AssertRaisesContext(None, None, None) + + def assertRaisesWithPredicateMatch(self, expected_exception, predicate, + callable_obj=None, *args, **kwargs): + """Asserts that exception is thrown and predicate(exception) is true. + + Args: + expected_exception: Exception class expected to be raised. + predicate: Function of one argument that inspects the passed-in exception + and returns True (success) or False (please fail the test). + callable_obj: Function to be called. + *args: Extra args. + **kwargs: Extra keyword args. + + Returns: + A context manager if callable_obj is None. Otherwise, None. + + Raises: + self.failureException if callable_obj does not raise a matching exception. + """ + def Check(err): + self.assertTrue(predicate(err), + '%r does not match predicate %r' % (err, predicate)) + + context = self._AssertRaisesContext(expected_exception, self, Check) + if callable_obj is None: + return context + with context: + callable_obj(*args, **kwargs) + + @typing.overload + def assertRaisesWithLiteralMatch( + self, expected_exception, expected_exception_message + ) -> _AssertRaisesContext: + # The purpose of this return statement is to work around + # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored. + return self._AssertRaisesContext(None, None, None) + + @typing.overload + def assertRaisesWithLiteralMatch( + self, expected_exception, expected_exception_message, + callable_obj: Callable[..., Any], *args, **kwargs) -> None: + # The purpose of this return statement is to work around + # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored. + return self._AssertRaisesContext(None, None, None) + + def assertRaisesWithLiteralMatch(self, expected_exception, + expected_exception_message, + callable_obj=None, *args, **kwargs): + """Asserts that the message in a raised exception equals the given string. + + Unlike assertRaisesRegex, this method takes a literal string, not + a regular expression. + + with self.assertRaisesWithLiteralMatch(ExType, 'message'): + DoSomething() + + Args: + expected_exception: Exception class expected to be raised. + expected_exception_message: String message expected in the raised + exception. For a raise exception e, expected_exception_message must + equal str(e). + callable_obj: Function to be called, or None to return a context. + *args: Extra args. + **kwargs: Extra kwargs. + + Returns: + A context manager if callable_obj is None. Otherwise, None. + + Raises: + self.failureException if callable_obj does not raise a matching exception. + """ + def Check(err): + actual_exception_message = str(err) + self.assertTrue(expected_exception_message == actual_exception_message, + 'Exception message does not match.\n' + 'Expected: %r\n' + 'Actual: %r' % (expected_exception_message, + actual_exception_message)) + + context = self._AssertRaisesContext(expected_exception, self, Check) + if callable_obj is None: + return context + with context: + callable_obj(*args, **kwargs) + + def assertContainsInOrder(self, strings, target, msg=None): + """Asserts that the strings provided are found in the target in order. + + This may be useful for checking HTML output. + + Args: + strings: A list of strings, such as [ 'fox', 'dog' ] + target: A target string in which to look for the strings, such as + 'The quick brown fox jumped over the lazy dog'. + msg: Optional message to report on failure. + """ + if isinstance(strings, (bytes, unicode if str is bytes else str)): + strings = (strings,) + + current_index = 0 + last_string = None + for string in strings: + index = target.find(str(string), current_index) + if index == -1 and current_index == 0: + self.fail("Did not find '%s' in '%s'" % + (string, target), msg) + elif index == -1: + self.fail("Did not find '%s' after '%s' in '%s'" % + (string, last_string, target), msg) + last_string = string + current_index = index + + def assertContainsSubsequence(self, container, subsequence, msg=None): + """Asserts that "container" contains "subsequence" as a subsequence. + + Asserts that "container" contains all the elements of "subsequence", in + order, but possibly with other elements interspersed. For example, [1, 2, 3] + is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0]. + + Args: + container: the list we're testing for subsequence inclusion. + subsequence: the list we hope will be a subsequence of container. + msg: Optional message to report on failure. + """ + first_nonmatching = None + reversed_container = list(reversed(container)) + subsequence = list(subsequence) + + for e in subsequence: + if e not in reversed_container: + first_nonmatching = e + break + while e != reversed_container.pop(): + pass + + if first_nonmatching is not None: + self.fail('%s not a subsequence of %s. First non-matching element: %s' % + (subsequence, container, first_nonmatching), msg) + + def assertContainsExactSubsequence(self, container, subsequence, msg=None): + """Asserts that "container" contains "subsequence" as an exact subsequence. + + Asserts that "container" contains all the elements of "subsequence", in + order, and without other elements interspersed. For example, [1, 2, 3] is an + exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0]. + + Args: + container: the list we're testing for subsequence inclusion. + subsequence: the list we hope will be an exact subsequence of container. + msg: Optional message to report on failure. + """ + container = list(container) + subsequence = list(subsequence) + longest_match = 0 + + for start in range(1 + len(container) - len(subsequence)): + if longest_match == len(subsequence): + break + index = 0 + while (index < len(subsequence) and + subsequence[index] == container[start + index]): + index += 1 + longest_match = max(longest_match, index) + + if longest_match < len(subsequence): + self.fail('%s not an exact subsequence of %s. ' + 'Longest matching prefix: %s' % + (subsequence, container, subsequence[:longest_match]), msg) + + def assertTotallyOrdered(self, *groups, **kwargs): + """Asserts that total ordering has been implemented correctly. + + For example, say you have a class A that compares only on its attribute x. + Comparators other than __lt__ are omitted for brevity. + + class A(object): + def __init__(self, x, y): + self.x = x + self.y = y + + def __hash__(self): + return hash(self.x) + + def __lt__(self, other): + try: + return self.x < other.x + except AttributeError: + return NotImplemented + + assertTotallyOrdered will check that instances can be ordered correctly. + For example, + + self.assertTotallyOrdered( + [None], # None should come before everything else. + [1], # Integers sort earlier. + [A(1, 'a')], + [A(2, 'b')], # 2 is after 1. + [A(3, 'c'), A(3, 'd')], # The second argument is irrelevant. + [A(4, 'z')], + ['foo']) # Strings sort last. + + Args: + *groups: A list of groups of elements. Each group of elements is a list + of objects that are equal. The elements in each group must be less + than the elements in the group after it. For example, these groups are + totally ordered: [None], [1], [2, 2], [3]. + **kwargs: optional msg keyword argument can be passed. + """ + + def CheckOrder(small, big): + """Ensures small is ordered before big.""" + self.assertFalse(small == big, + self._formatMessage(msg, '%r unexpectedly equals %r' % + (small, big))) + self.assertTrue(small != big, + self._formatMessage(msg, '%r unexpectedly equals %r' % + (small, big))) + self.assertLess(small, big, msg) + self.assertFalse(big < small, + self._formatMessage(msg, + '%r unexpectedly less than %r' % + (big, small))) + self.assertLessEqual(small, big, msg) + self.assertFalse(big <= small, self._formatMessage( + '%r unexpectedly less than or equal to %r' % (big, small), msg + )) + self.assertGreater(big, small, msg) + self.assertFalse(small > big, + self._formatMessage(msg, + '%r unexpectedly greater than %r' % + (small, big))) + self.assertGreaterEqual(big, small) + self.assertFalse(small >= big, self._formatMessage( + msg, + '%r unexpectedly greater than or equal to %r' % (small, big))) + + def CheckEqual(a, b): + """Ensures that a and b are equal.""" + self.assertEqual(a, b, msg) + self.assertFalse(a != b, + self._formatMessage(msg, '%r unexpectedly unequals %r' % + (a, b))) + + # Objects that compare equal must hash to the same value, but this only + # applies if both objects are hashable. + if (isinstance(a, abc.Hashable) and + isinstance(b, abc.Hashable)): + self.assertEqual( + hash(a), hash(b), + self._formatMessage( + msg, 'hash %d of %r unexpectedly not equal to hash %d of %r' % + (hash(a), a, hash(b), b))) + + self.assertFalse(a < b, + self._formatMessage(msg, + '%r unexpectedly less than %r' % + (a, b))) + self.assertFalse(b < a, + self._formatMessage(msg, + '%r unexpectedly less than %r' % + (b, a))) + self.assertLessEqual(a, b, msg) + self.assertLessEqual(b, a, msg) # pylint: disable=arguments-out-of-order + self.assertFalse(a > b, + self._formatMessage(msg, + '%r unexpectedly greater than %r' % + (a, b))) + self.assertFalse(b > a, + self._formatMessage(msg, + '%r unexpectedly greater than %r' % + (b, a))) + self.assertGreaterEqual(a, b, msg) + self.assertGreaterEqual(b, a, msg) # pylint: disable=arguments-out-of-order + + msg = kwargs.get('msg') + + # For every combination of elements, check the order of every pair of + # elements. + for elements in itertools.product(*groups): + elements = list(elements) + for index, small in enumerate(elements[:-1]): + for big in elements[index + 1:]: + CheckOrder(small, big) + + # Check that every element in each group is equal. + for group in groups: + for a in group: + CheckEqual(a, a) + for a, b in itertools.product(group, group): + CheckEqual(a, b) + + def assertDictEqual(self, a, b, msg=None): + """Raises AssertionError if a and b are not equal dictionaries. + + Args: + a: A dict, the expected value. + b: A dict, the actual value. + msg: An optional str, the associated message. + + Raises: + AssertionError: if the dictionaries are not equal. + """ + self.assertIsInstance(a, dict, self._formatMessage( + msg, + 'First argument is not a dictionary' + )) + self.assertIsInstance(b, dict, self._formatMessage( + msg, + 'Second argument is not a dictionary' + )) + + def Sorted(list_of_items): + try: + return sorted(list_of_items) # In 3.3, unordered are possible. + except TypeError: + return list_of_items + + if a == b: + return + a_items = Sorted(list(a.items())) + b_items = Sorted(list(b.items())) + + unexpected = [] + missing = [] + different = [] + + safe_repr = unittest.util.safe_repr # pytype: disable=module-attr + + def Repr(dikt): + """Deterministic repr for dict.""" + # Sort the entries based on their repr, not based on their sort order, + # which will be non-deterministic across executions, for many types. + entries = sorted((safe_repr(k), safe_repr(v)) for k, v in dikt.items()) + return '{%s}' % (', '.join('%s: %s' % pair for pair in entries)) + + message = ['%s != %s%s' % (Repr(a), Repr(b), ' (%s)' % msg if msg else '')] + + # The standard library default output confounds lexical difference with + # value difference; treat them separately. + for a_key, a_value in a_items: + if a_key not in b: + missing.append((a_key, a_value)) + elif a_value != b[a_key]: + different.append((a_key, a_value, b[a_key])) + + for b_key, b_value in b_items: + if b_key not in a: + unexpected.append((b_key, b_value)) + + if unexpected: + message.append( + 'Unexpected, but present entries:\n%s' % ''.join( + '%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in unexpected)) + + if different: + message.append( + 'repr() of differing entries:\n%s' % ''.join( + '%s: %s != %s\n' % (safe_repr(k), safe_repr(a_value), + safe_repr(b_value)) + for k, a_value, b_value in different)) + + if missing: + message.append( + 'Missing entries:\n%s' % ''.join( + ('%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in missing))) + + raise self.failureException('\n'.join(message)) + + def assertUrlEqual(self, a, b, msg=None): + """Asserts that urls are equal, ignoring ordering of query params.""" + parsed_a = parse.urlparse(a) + parsed_b = parse.urlparse(b) + self.assertEqual(parsed_a.scheme, parsed_b.scheme, msg) + self.assertEqual(parsed_a.netloc, parsed_b.netloc, msg) + self.assertEqual(parsed_a.path, parsed_b.path, msg) + self.assertEqual(parsed_a.fragment, parsed_b.fragment, msg) + self.assertEqual(sorted(parsed_a.params.split(';')), + sorted(parsed_b.params.split(';')), msg) + self.assertDictEqual( + parse.parse_qs(parsed_a.query, keep_blank_values=True), + parse.parse_qs(parsed_b.query, keep_blank_values=True), msg) + + def assertSameStructure(self, a, b, aname='a', bname='b', msg=None): + """Asserts that two values contain the same structural content. + + The two arguments should be data trees consisting of trees of dicts and + lists. They will be deeply compared by walking into the contents of dicts + and lists; other items will be compared using the == operator. + If the two structures differ in content, the failure message will indicate + the location within the structures where the first difference is found. + This may be helpful when comparing large structures. + + Mixed Sequence and Set types are supported. Mixed Mapping types are + supported, but the order of the keys will not be considered in the + comparison. + + Args: + a: The first structure to compare. + b: The second structure to compare. + aname: Variable name to use for the first structure in assertion messages. + bname: Variable name to use for the second structure. + msg: Additional text to include in the failure message. + """ + + # Accumulate all the problems found so we can report all of them at once + # rather than just stopping at the first + problems = [] + + _walk_structure_for_problems(a, b, aname, bname, problems) + + # Avoid spamming the user toooo much + if self.maxDiff is not None: + max_problems_to_show = self.maxDiff // 80 + if len(problems) > max_problems_to_show: + problems = problems[0:max_problems_to_show-1] + ['...'] + + if problems: + self.fail('; '.join(problems), msg) + + def assertJsonEqual(self, first, second, msg=None): + """Asserts that the JSON objects defined in two strings are equal. + + A summary of the differences will be included in the failure message + using assertSameStructure. + + Args: + first: A string containing JSON to decode and compare to second. + second: A string containing JSON to decode and compare to first. + msg: Additional text to include in the failure message. + """ + try: + first_structured = json.loads(first) + except ValueError as e: + raise ValueError(self._formatMessage( + msg, + 'could not decode first JSON value %s: %s' % (first, e))) + + try: + second_structured = json.loads(second) + except ValueError as e: + raise ValueError(self._formatMessage( + msg, + 'could not decode second JSON value %s: %s' % (second, e))) + + self.assertSameStructure(first_structured, second_structured, + aname='first', bname='second', msg=msg) + + def _getAssertEqualityFunc(self, first, second): + # type: (Any, Any) -> Callable[..., None] + try: + return super(TestCase, self)._getAssertEqualityFunc(first, second) + except AttributeError: + # This is a workaround if unittest.TestCase.__init__ was never run. + # It usually means that somebody created a subclass just for the + # assertions and has overridden __init__. "assertTrue" is a safe + # value that will not make __init__ raise a ValueError. + test_method = getattr(self, '_testMethodName', 'assertTrue') + super(TestCase, self).__init__(test_method) + + return super(TestCase, self)._getAssertEqualityFunc(first, second) + + def fail(self, msg=None, prefix=None): + """Fail immediately with the given message, optionally prefixed.""" + return super(TestCase, self).fail(self._formatMessage(prefix, msg)) + + +def _sorted_list_difference(expected, actual): + # type: (List[_T], List[_T]) -> Tuple[List[_T], List[_T]] + """Finds elements in only one or the other of two, sorted input lists. + + Returns a two-element tuple of lists. The first list contains those + elements in the "expected" list but not in the "actual" list, and the + second contains those elements in the "actual" list but not in the + "expected" list. Duplicate elements in either input list are ignored. + + Args: + expected: The list we expected. + actual: The list we actually got. + Returns: + (missing, unexpected) + missing: items in expected that are not in actual. + unexpected: items in actual that are not in expected. + """ + i = j = 0 + missing = [] + unexpected = [] + while True: + try: + e = expected[i] + a = actual[j] + if e < a: + missing.append(e) + i += 1 + while expected[i] == e: + i += 1 + elif e > a: + unexpected.append(a) + j += 1 + while actual[j] == a: + j += 1 + else: + i += 1 + try: + while expected[i] == e: + i += 1 + finally: + j += 1 + while actual[j] == a: + j += 1 + except IndexError: + missing.extend(expected[i:]) + unexpected.extend(actual[j:]) + break + return missing, unexpected + + +def _are_both_of_integer_type(a, b): + # type: (object, object) -> bool + return isinstance(a, int) and isinstance(b, int) + + +def _are_both_of_sequence_type(a, b): + # type: (object, object) -> bool + return isinstance(a, abc.Sequence) and isinstance( + b, abc.Sequence) and not isinstance( + a, _TEXT_OR_BINARY_TYPES) and not isinstance(b, _TEXT_OR_BINARY_TYPES) + + +def _are_both_of_set_type(a, b): + # type: (object, object) -> bool + return isinstance(a, abc.Set) and isinstance(b, abc.Set) + + +def _are_both_of_mapping_type(a, b): + # type: (object, object) -> bool + return isinstance(a, abc.Mapping) and isinstance( + b, abc.Mapping) + + +def _walk_structure_for_problems(a, b, aname, bname, problem_list): + """The recursive comparison behind assertSameStructure.""" + if type(a) != type(b) and not ( # pylint: disable=unidiomatic-typecheck + _are_both_of_integer_type(a, b) or _are_both_of_sequence_type(a, b) or + _are_both_of_set_type(a, b) or _are_both_of_mapping_type(a, b)): + # We do not distinguish between int and long types as 99.99% of Python 2 + # code should never care. They collapse into a single type in Python 3. + problem_list.append('%s is a %r but %s is a %r' % + (aname, type(a), bname, type(b))) + # If they have different types there's no point continuing + return + + if isinstance(a, abc.Set): + for k in a: + if k not in b: + problem_list.append( + '%s has %r but %s does not' % (aname, k, bname)) + for k in b: + if k not in a: + problem_list.append('%s lacks %r but %s has it' % (aname, k, bname)) + + # NOTE: a or b could be a defaultdict, so we must take care that the traversal + # doesn't modify the data. + elif isinstance(a, abc.Mapping): + for k in a: + if k in b: + _walk_structure_for_problems( + a[k], b[k], '%s[%r]' % (aname, k), '%s[%r]' % (bname, k), + problem_list) + else: + problem_list.append( + "%s has [%r] with value %r but it's missing in %s" % + (aname, k, a[k], bname)) + for k in b: + if k not in a: + problem_list.append( + '%s lacks [%r] but %s has it with value %r' % + (aname, k, bname, b[k])) + + # Strings/bytes are Sequences but we'll just do those with regular != + elif (isinstance(a, abc.Sequence) and + not isinstance(a, _TEXT_OR_BINARY_TYPES)): + minlen = min(len(a), len(b)) + for i in range(minlen): + _walk_structure_for_problems( + a[i], b[i], '%s[%d]' % (aname, i), '%s[%d]' % (bname, i), + problem_list) + for i in range(minlen, len(a)): + problem_list.append('%s has [%i] with value %r but %s does not' % + (aname, i, a[i], bname)) + for i in range(minlen, len(b)): + problem_list.append('%s lacks [%i] but %s has it with value %r' % + (aname, i, bname, b[i])) + + else: + if a != b: + problem_list.append('%s is %r but %s is %r' % (aname, a, bname, b)) + + +def get_command_string(command): + """Returns an escaped string that can be used as a shell command. + + Args: + command: List or string representing the command to run. + Returns: + A string suitable for use as a shell command. + """ + if isinstance(command, str): + return command + else: + if os.name == 'nt': + return ' '.join(command) + else: + # The following is identical to Python 3's shlex.quote function. + command_string = '' + for word in command: + # Single quote word, and replace each ' in word with '"'"' + command_string += "'" + word.replace("'", "'\"'\"'") + "' " + return command_string[:-1] + + +def get_command_stderr(command, env=None, close_fds=True): + """Runs the given shell command and returns a tuple. + + Args: + command: List or string representing the command to run. + env: Dictionary of environment variable settings. If None, no environment + variables will be set for the child process. This is to make tests + more hermetic. NOTE: this behavior is different than the standard + subprocess module. + close_fds: Whether or not to close all open fd's in the child after forking. + On Windows, this is ignored and close_fds is always False. + + Returns: + Tuple of (exit status, text printed to stdout and stderr by the command). + """ + if env is None: env = {} + if os.name == 'nt': + # Windows does not support setting close_fds to True while also redirecting + # standard handles. + close_fds = False + + use_shell = isinstance(command, str) + process = subprocess.Popen( + command, + close_fds=close_fds, + env=env, + shell=use_shell, + stderr=subprocess.STDOUT, + stdout=subprocess.PIPE) + output = process.communicate()[0] + exit_status = process.wait() + return (exit_status, output) + + +def _quote_long_string(s): + # type: (Union[Text, bytes, bytearray]) -> Text + """Quotes a potentially multi-line string to make the start and end obvious. + + Args: + s: A string. + + Returns: + The quoted string. + """ + if isinstance(s, (bytes, bytearray)): + try: + s = s.decode('utf-8') + except UnicodeDecodeError: + s = str(s) + return ('8<-----------\n' + + s + '\n' + + '----------->8\n') + + +def print_python_version(): + # type: () -> None + # Having this in the test output logs by default helps debugging when all + # you've got is the log and no other idea of which Python was used. + sys.stderr.write('Running tests under Python {0[0]}.{0[1]}.{0[2]}: ' + '{1}\n'.format( + sys.version_info, + sys.executable if sys.executable else 'embedded.')) + + +def main(*args, **kwargs): + # type: (Text, Any) -> None + """Executes a set of Python unit tests. + + Usually this function is called without arguments, so the + unittest.TestProgram instance will get created with the default settings, + so it will run all test methods of all TestCase classes in the __main__ + module. + + Args: + *args: Positional arguments passed through to unittest.TestProgram.__init__. + **kwargs: Keyword arguments passed through to unittest.TestProgram.__init__. + """ + print_python_version() + _run_in_app(run_tests, args, kwargs) + + +def _is_in_app_main(): + # type: () -> bool + """Returns True iff app.run is active.""" + f = sys._getframe().f_back # pylint: disable=protected-access + while f: + if f.f_code == app.run.__code__: + return True + f = f.f_back + return False + + +class _SavedFlag(object): + """Helper class for saving and restoring a flag value.""" + + def __init__(self, flag): + self.flag = flag + self.value = flag.value + self.present = flag.present + + def restore_flag(self): + self.flag.value = self.value + self.flag.present = self.present + + +def _register_sigterm_with_faulthandler(): + # type: () -> None + """Have faulthandler dump stacks on SIGTERM. Useful to diagnose timeouts.""" + if faulthandler and getattr(faulthandler, 'register', None): + # faulthandler.register is not available on Windows. + # faulthandler.enable() is already called by app.run. + try: + faulthandler.register(signal.SIGTERM, chain=True) # pytype: disable=module-attr + except Exception as e: # pylint: disable=broad-except + sys.stderr.write('faulthandler.register(SIGTERM) failed ' + '%r; ignoring.\n' % e) + + +def _run_in_app(function, args, kwargs): + # type: (Callable[..., None], Sequence[Text], Mapping[Text, Any]) -> None + """Executes a set of Python unit tests, ensuring app.run. + + This is a private function, users should call absltest.main(). + + _run_in_app calculates argv to be the command-line arguments of this program + (without the flags), sets the default of FLAGS.alsologtostderr to True, + then it calls function(argv, args, kwargs), making sure that `function' + will get called within app.run(). _run_in_app does this by checking whether + it is called by app.run(), or by calling app.run() explicitly. + + The reason why app.run has to be ensured is to make sure that + flags are parsed and stripped properly, and other initializations done by + the app module are also carried out, no matter if absltest.run() is called + from within or outside app.run(). + + If _run_in_app is called from within app.run(), then it will reparse + sys.argv and pass the result without command-line flags into the argv + argument of `function'. The reason why this parsing is needed is that + __main__.main() calls absltest.main() without passing its argv. So the + only way _run_in_app could get to know the argv without the flags is that + it reparses sys.argv. + + _run_in_app changes the default of FLAGS.alsologtostderr to True so that the + test program's stderr will contain all the log messages unless otherwise + specified on the command-line. This overrides any explicit assignment to + FLAGS.alsologtostderr by the test program prior to the call to _run_in_app() + (e.g. in __main__.main). + + Please note that _run_in_app (and the function it calls) is allowed to make + changes to kwargs. + + Args: + function: absltest.run_tests or a similar function. It will be called as + function(argv, args, kwargs) where argv is a list containing the + elements of sys.argv without the command-line flags. + args: Positional arguments passed through to unittest.TestProgram.__init__. + kwargs: Keyword arguments passed through to unittest.TestProgram.__init__. + """ + if _is_in_app_main(): + _register_sigterm_with_faulthandler() + + # Save command-line flags so the side effects of FLAGS(sys.argv) can be + # undone. + flag_objects = (FLAGS[name] for name in FLAGS) + saved_flags = dict((f.name, _SavedFlag(f)) for f in flag_objects) + + # Change the default of alsologtostderr from False to True, so the test + # programs's stderr will contain all the log messages. + # If --alsologtostderr=false is specified in the command-line, or user + # has called FLAGS.alsologtostderr = False before, then the value is kept + # False. + FLAGS.set_default('alsologtostderr', True) + # Remove it from saved flags so it doesn't get restored later. + del saved_flags['alsologtostderr'] + + # The call FLAGS(sys.argv) parses sys.argv, returns the arguments + # without the flags, and -- as a side effect -- modifies flag values in + # FLAGS. We don't want the side effect, because we don't want to + # override flag changes the program did (e.g. in __main__.main) + # after the command-line has been parsed. So we have the for loop below + # to change back flags to their old values. + argv = FLAGS(sys.argv) + for saved_flag in saved_flags.values(): + saved_flag.restore_flag() + + function(argv, args, kwargs) + else: + # Send logging to stderr. Use --alsologtostderr instead of --logtostderr + # in case tests are reading their own logs. + FLAGS.set_default('alsologtostderr', True) + + def main_function(argv): + _register_sigterm_with_faulthandler() + function(argv, args, kwargs) + + app.run(main=main_function) + + +def _is_suspicious_attribute(testCaseClass, name): + # type: (Type, Text) -> bool + """Returns True if an attribute is a method named like a test method.""" + if name.startswith('Test') and len(name) > 4 and name[4].isupper(): + attr = getattr(testCaseClass, name) + if inspect.isfunction(attr) or inspect.ismethod(attr): + args = inspect.getfullargspec(attr) + return (len(args.args) == 1 and args.args[0] == 'self' and + args.varargs is None and args.varkw is None and + not args.kwonlyargs) + return False + + +def skipThisClass(reason): + # type: (Text) -> Callable[[_T], _T] + """Skip tests in the decorated TestCase, but not any of its subclasses. + + This decorator indicates that this class should skip all its tests, but not + any of its subclasses. Useful for if you want to share testMethod or setUp + implementations between a number of concrete testcase classes. + + Example usage, showing how you can share some common test methods between + subclasses. In this example, only 'BaseTest' will be marked as skipped, and + not RealTest or SecondRealTest: + + @absltest.skipThisClass("Shared functionality") + class BaseTest(absltest.TestCase): + def test_simple_functionality(self): + self.assertEqual(self.system_under_test.method(), 1) + + class RealTest(BaseTest): + def setUp(self): + super().setUp() + self.system_under_test = MakeSystem(argument) + + def test_specific_behavior(self): + ... + + class SecondRealTest(BaseTest): + def setUp(self): + super().setUp() + self.system_under_test = MakeSystem(other_arguments) + + def test_other_behavior(self): + ... + + Args: + reason: The reason we have a skip in place. For instance: 'shared test + methods' or 'shared assertion methods'. + + Returns: + Decorator function that will cause a class to be skipped. + """ + if isinstance(reason, type): + raise TypeError('Got {!r}, expected reason as string'.format(reason)) + + def _skip_class(test_case_class): + if not issubclass(test_case_class, unittest.TestCase): + raise TypeError( + 'Decorating {!r}, expected TestCase subclass'.format(test_case_class)) + + # Only shadow the setUpClass method if it is directly defined. If it is + # in the parent class we invoke it via a super() call instead of holding + # a reference to it. + shadowed_setupclass = test_case_class.__dict__.get('setUpClass', None) + + @classmethod + def replacement_setupclass(cls, *args, **kwargs): + # Skip this class if it is the one that was decorated with @skipThisClass + if cls is test_case_class: + raise SkipTest(reason) + if shadowed_setupclass: + # Pass along `cls` so the MRO chain doesn't break. + # The original method is a `classmethod` descriptor, which can't + # be directly called, but `__func__` has the underlying function. + return shadowed_setupclass.__func__(cls, *args, **kwargs) + else: + # Because there's no setUpClass() defined directly on test_case_class, + # we call super() ourselves to continue execution of the inheritance + # chain. + return super(test_case_class, cls).setUpClass(*args, **kwargs) + + test_case_class.setUpClass = replacement_setupclass + return test_case_class + + return _skip_class + + +class TestLoader(unittest.TestLoader): + """A test loader which supports common test features. + + Supported features include: + * Banning untested methods with test-like names: methods attached to this + testCase with names starting with `Test` are ignored by the test runner, + and often represent mistakenly-omitted test cases. This loader will raise + a TypeError when attempting to load a TestCase with such methods. + * Randomization of test case execution order (optional). + """ + + _ERROR_MSG = textwrap.dedent("""Method '%s' is named like a test case but + is not one. This is often a bug. If you want it to be a test method, + name it with 'test' in lowercase. If not, rename the method to not begin + with 'Test'.""") + + def __init__(self, *args, **kwds): + super(TestLoader, self).__init__(*args, **kwds) + seed = _get_default_randomize_ordering_seed() + if seed: + self._randomize_ordering_seed = seed + self._random = random.Random(self._randomize_ordering_seed) + else: + self._randomize_ordering_seed = None + self._random = None + + def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name + """Validates and returns a (possibly randomized) list of test case names.""" + for name in dir(testCaseClass): + if _is_suspicious_attribute(testCaseClass, name): + raise TypeError(TestLoader._ERROR_MSG % name) + names = super(TestLoader, self).getTestCaseNames(testCaseClass) + if self._randomize_ordering_seed is not None: + logging.info( + 'Randomizing test order with seed: %d', self._randomize_ordering_seed) + logging.info( + 'To reproduce this order, re-run with ' + '--test_randomize_ordering_seed=%d', self._randomize_ordering_seed) + self._random.shuffle(names) + return names + + +def get_default_xml_output_filename(): + # type: () -> Optional[Text] + if os.environ.get('XML_OUTPUT_FILE'): + return os.environ['XML_OUTPUT_FILE'] + elif os.environ.get('RUNNING_UNDER_TEST_DAEMON'): + return os.path.join(os.path.dirname(TEST_TMPDIR.value), 'test_detail.xml') + elif os.environ.get('TEST_XMLOUTPUTDIR'): + return os.path.join( + os.environ['TEST_XMLOUTPUTDIR'], + os.path.splitext(os.path.basename(sys.argv[0]))[0] + '.xml') + + +def _setup_filtering(argv): + # type: (MutableSequence[Text]) -> None + """Implements the bazel test filtering protocol. + + The following environment variable is used in this method: + + TESTBRIDGE_TEST_ONLY: string, if set, is forwarded to the unittest + framework to use as a test filter. Its value is split with shlex, then: + 1. On Python 3.6 and before, split values are passed as positional + arguments on argv. + 2. On Python 3.7+, split values are passed to unittest's `-k` flag. Tests + are matched by glob patterns or substring. See + https://docs.python.org/3/library/unittest.html#cmdoption-unittest-k + + Args: + argv: the argv to mutate in-place. + """ + test_filter = os.environ.get('TESTBRIDGE_TEST_ONLY') + if argv is None or not test_filter: + return + + filters = shlex.split(test_filter) + if sys.version_info[:2] >= (3, 7): + filters = ['-k=' + test_filter for test_filter in filters] + + argv[1:1] = filters + + +def _setup_test_runner_fail_fast(argv): + # type: (MutableSequence[Text]) -> None + """Implements the bazel test fail fast protocol. + + The following environment variable is used in this method: + + TESTBRIDGE_TEST_RUNNER_FAIL_FAST=<1|0> + + If set to 1, --failfast is passed to the unittest framework to return upon + first failure. + + Args: + argv: the argv to mutate in-place. + """ + + if argv is None: + return + + if os.environ.get('TESTBRIDGE_TEST_RUNNER_FAIL_FAST') != '1': + return + + argv[1:1] = ['--failfast'] + + +def _setup_sharding(custom_loader=None): + # type: (Optional[unittest.TestLoader]) -> unittest.TestLoader + """Implements the bazel sharding protocol. + + The following environment variables are used in this method: + + TEST_SHARD_STATUS_FILE: string, if set, points to a file. We write a blank + file to tell the test runner that this test implements the test sharding + protocol. + + TEST_TOTAL_SHARDS: int, if set, sharding is requested. + + TEST_SHARD_INDEX: int, must be set if TEST_TOTAL_SHARDS is set. Specifies + the shard index for this instance of the test process. Must satisfy: + 0 <= TEST_SHARD_INDEX < TEST_TOTAL_SHARDS. + + Args: + custom_loader: A TestLoader to be made sharded. + + Returns: + The test loader for shard-filtering or the standard test loader, depending + on the sharding environment variables. + """ + + # It may be useful to write the shard file even if the other sharding + # environment variables are not set. Test runners may use this functionality + # to query whether a test binary implements the test sharding protocol. + if 'TEST_SHARD_STATUS_FILE' in os.environ: + try: + with open(os.environ['TEST_SHARD_STATUS_FILE'], 'w') as f: + f.write('') + except IOError: + sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.' + % os.environ['TEST_SHARD_STATUS_FILE']) + sys.exit(1) + + base_loader = custom_loader or TestLoader() + if 'TEST_TOTAL_SHARDS' not in os.environ: + # Not using sharding, use the expected test loader. + return base_loader + + total_shards = int(os.environ['TEST_TOTAL_SHARDS']) + shard_index = int(os.environ['TEST_SHARD_INDEX']) + + if shard_index < 0 or shard_index >= total_shards: + sys.stderr.write('ERROR: Bad sharding values. index=%d, total=%d\n' % + (shard_index, total_shards)) + sys.exit(1) + + # Replace the original getTestCaseNames with one that returns + # the test case names for this shard. + delegate_get_names = base_loader.getTestCaseNames + + bucket_iterator = itertools.cycle(range(total_shards)) + + def getShardedTestCaseNames(testCaseClass): + filtered_names = [] + # We need to sort the list of tests in order to determine which tests this + # shard is responsible for; however, it's important to preserve the order + # returned by the base loader, e.g. in the case of randomized test ordering. + ordered_names = delegate_get_names(testCaseClass) + for testcase in sorted(ordered_names): + bucket = next(bucket_iterator) + if bucket == shard_index: + filtered_names.append(testcase) + return [x for x in ordered_names if x in filtered_names] + + base_loader.getTestCaseNames = getShardedTestCaseNames + return base_loader + + +# pylint: disable=line-too-long +def _run_and_get_tests_result(argv, args, kwargs, xml_test_runner_class): + # type: (MutableSequence[Text], Sequence[Any], MutableMapping[Text, Any], Type) -> unittest.TestResult + # pylint: enable=line-too-long + """Same as run_tests, except it returns the result instead of exiting.""" + + # The entry from kwargs overrides argv. + argv = kwargs.pop('argv', argv) + + # Set up test filtering if requested in environment. + _setup_filtering(argv) + # Set up --failfast as requested in environment + _setup_test_runner_fail_fast(argv) + + # Shard the (default or custom) loader if sharding is turned on. + kwargs['testLoader'] = _setup_sharding(kwargs.get('testLoader', None)) + + # XML file name is based upon (sorted by priority): + # --xml_output_file flag, XML_OUTPUT_FILE variable, + # TEST_XMLOUTPUTDIR variable or RUNNING_UNDER_TEST_DAEMON variable. + if not FLAGS.xml_output_file: + FLAGS.xml_output_file = get_default_xml_output_filename() + xml_output_file = FLAGS.xml_output_file + + xml_buffer = None + if xml_output_file: + xml_output_dir = os.path.dirname(xml_output_file) + if xml_output_dir and not os.path.isdir(xml_output_dir): + try: + os.makedirs(xml_output_dir) + except OSError as e: + # File exists error can occur with concurrent tests + if e.errno != errno.EEXIST: + raise + # Fail early if we can't write to the XML output file. This is so that we + # don't waste people's time running tests that will just fail anyways. + with _open(xml_output_file, 'w'): + pass + + # We can reuse testRunner if it supports XML output (e. g. by inheriting + # from xml_reporter.TextAndXMLTestRunner). Otherwise we need to use + # xml_reporter.TextAndXMLTestRunner. + if (kwargs.get('testRunner') is not None + and not hasattr(kwargs['testRunner'], 'set_default_xml_stream')): + sys.stderr.write('WARNING: XML_OUTPUT_FILE or --xml_output_file setting ' + 'overrides testRunner=%r setting (possibly from --pdb)' + % (kwargs['testRunner'])) + # Passing a class object here allows TestProgram to initialize + # instances based on its kwargs and/or parsed command-line args. + kwargs['testRunner'] = xml_test_runner_class + if kwargs.get('testRunner') is None: + kwargs['testRunner'] = xml_test_runner_class + # Use an in-memory buffer (not backed by the actual file) to store the XML + # report, because some tools modify the file (e.g., create a placeholder + # with partial information, in case the test process crashes). + xml_buffer = io.StringIO() + kwargs['testRunner'].set_default_xml_stream(xml_buffer) # pytype: disable=attribute-error + + # If we've used a seed to randomize test case ordering, we want to record it + # as a top-level attribute in the `testsuites` section of the XML output. + randomize_ordering_seed = getattr( + kwargs['testLoader'], '_randomize_ordering_seed', None) + setter = getattr(kwargs['testRunner'], 'set_testsuites_property', None) + if randomize_ordering_seed and setter: + setter('test_randomize_ordering_seed', randomize_ordering_seed) + elif kwargs.get('testRunner') is None: + kwargs['testRunner'] = _pretty_print_reporter.TextTestRunner + + if FLAGS.pdb_post_mortem: + runner = kwargs['testRunner'] + # testRunner can be a class or an instance, which must be tested for + # differently. + # Overriding testRunner isn't uncommon, so only enable the debugging + # integration if the runner claims it does; we don't want to accidentally + # clobber something on the runner. + if ((isinstance(runner, type) and + issubclass(runner, _pretty_print_reporter.TextTestRunner)) or + isinstance(runner, _pretty_print_reporter.TextTestRunner)): + runner.run_for_debugging = True + + # Make sure tmpdir exists. + if not os.path.isdir(TEST_TMPDIR.value): + try: + os.makedirs(TEST_TMPDIR.value) + except OSError as e: + # Concurrent test might have created the directory. + if e.errno != errno.EEXIST: + raise + + # Let unittest.TestProgram.__init__ do its own argv parsing, e.g. for '-v', + # on argv, which is sys.argv without the command-line flags. + kwargs['argv'] = argv + + try: + test_program = unittest.TestProgram(*args, **kwargs) + return test_program.result + finally: + if xml_buffer: + try: + with _open(xml_output_file, 'w') as f: + f.write(xml_buffer.getvalue()) + finally: + xml_buffer.close() + + +def run_tests(argv, args, kwargs): # pylint: disable=line-too-long + # type: (MutableSequence[Text], Sequence[Any], MutableMapping[Text, Any]) -> None + # pylint: enable=line-too-long + """Executes a set of Python unit tests. + + Most users should call absltest.main() instead of run_tests. + + Please note that run_tests should be called from app.run. + Calling absltest.main() would ensure that. + + Please note that run_tests is allowed to make changes to kwargs. + + Args: + argv: sys.argv with the command-line flags removed from the front, i.e. the + argv with which app.run() has called __main__.main. It is passed to + unittest.TestProgram.__init__(argv=), which does its own flag parsing. It + is ignored if kwargs contains an argv entry. + args: Positional arguments passed through to unittest.TestProgram.__init__. + kwargs: Keyword arguments passed through to unittest.TestProgram.__init__. + """ + result = _run_and_get_tests_result( + argv, args, kwargs, xml_reporter.TextAndXMLTestRunner) + sys.exit(not result.wasSuccessful()) + + +def _rmtree_ignore_errors(path): + # type: (Text) -> None + if os.path.isfile(path): + try: + os.unlink(path) + except OSError: + pass + else: + shutil.rmtree(path, ignore_errors=True) + + +def _get_first_part(path): + # type: (Text) -> Text + parts = path.split(os.sep, 1) + return parts[0] diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py new file mode 100644 index 0000000..7fe95fe --- /dev/null +++ b/absl/testing/flagsaver.py @@ -0,0 +1,198 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decorator and context manager for saving and restoring flag values. + +There are many ways to save and restore. Always use the most convenient method +for a given use case. + +Here are examples of each method. They all call do_stuff() while FLAGS.someflag +is temporarily set to 'foo'. + + from absl.testing import flagsaver + + # Use a decorator which can optionally override flags via arguments. + @flagsaver.flagsaver(someflag='foo') + def some_func(): + do_stuff() + + # Use a decorator which can optionally override flags with flagholders. + @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23)) + def some_func(): + do_stuff() + + # Use a decorator which does not override flags itself. + @flagsaver.flagsaver + def some_func(): + FLAGS.someflag = 'foo' + do_stuff() + + # Use a context manager which can optionally override flags via arguments. + with flagsaver.flagsaver(someflag='foo'): + do_stuff() + + # Save and restore the flag values yourself. + saved_flag_values = flagsaver.save_flag_values() + try: + FLAGS.someflag = 'foo' + do_stuff() + finally: + flagsaver.restore_flag_values(saved_flag_values) + +We save and restore a shallow copy of each Flag object's __dict__ attribute. +This preserves all attributes of the flag, such as whether or not it was +overridden from its default value. + +WARNING: Currently a flag that is saved and then deleted cannot be restored. An +exception will be raised. However if you *add* a flag after saving flag values, +and then restore flag values, the added flag will be deleted with no errors. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import inspect + +from absl import flags + +FLAGS = flags.FLAGS + + +def flagsaver(*args, **kwargs): + """The main flagsaver interface. See module doc for usage.""" + if not args: + return _FlagOverrider(**kwargs) + # args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)` + if len(args) == 1 and callable(args[0]): + if kwargs: + raise ValueError( + "It's invalid to specify both positional and keyword parameters.") + func = args[0] + if inspect.isclass(func): + raise TypeError('@flagsaver.flagsaver cannot be applied to a class.') + return _wrap(func, {}) + # args can be a list of (FlagHolder, value) pairs. + # In which case they augment any specified kwargs. + for arg in args: + if not isinstance(arg, tuple) or len(arg) != 2: + raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,)) + holder, value = arg + if not isinstance(holder, flags.FlagHolder): + raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,)) + if holder.name in kwargs: + raise ValueError('Cannot set --%s multiple times' % holder.name) + kwargs[holder.name] = value + return _FlagOverrider(**kwargs) + + +def save_flag_values(flag_values=FLAGS): + """Returns copy of flag values as a dict. + + Args: + flag_values: FlagValues, the FlagValues instance with which the flag will + be saved. This should almost never need to be overridden. + Returns: + Dictionary mapping keys to values. Keys are flag names, values are + corresponding __dict__ members. E.g. {'key': value_dict, ...}. + """ + return {name: _copy_flag_dict(flag_values[name]) for name in flag_values} + + +def restore_flag_values(saved_flag_values, flag_values=FLAGS): + """Restores flag values based on the dictionary of flag values. + + Args: + saved_flag_values: {'flag_name': value_dict, ...} + flag_values: FlagValues, the FlagValues instance from which the flag will + be restored. This should almost never need to be overridden. + """ + new_flag_names = list(flag_values) + for name in new_flag_names: + saved = saved_flag_values.get(name) + if saved is None: + # If __dict__ was not saved delete "new" flag. + delattr(flag_values, name) + else: + if flag_values[name].value != saved['_value']: + flag_values[name].value = saved['_value'] # Ensure C++ value is set. + flag_values[name].__dict__ = saved + + +def _wrap(func, overrides): + """Creates a wrapper function that saves/restores flag values. + + Args: + func: function object - This will be called between saving flags and + restoring flags. + overrides: {str: object} - Flag names mapped to their values. These flags + will be set after saving the original flag state. + + Returns: + return value from func() + """ + @functools.wraps(func) + def _flagsaver_wrapper(*args, **kwargs): + """Wrapper function that saves and restores flags.""" + with _FlagOverrider(**overrides): + return func(*args, **kwargs) + return _flagsaver_wrapper + + +class _FlagOverrider(object): + """Overrides flags for the duration of the decorated function call. + + It also restores all original values of flags after decorated method + completes. + """ + + def __init__(self, **overrides): + self._overrides = overrides + self._saved_flag_values = None + + def __call__(self, func): + if inspect.isclass(func): + raise TypeError('flagsaver cannot be applied to a class.') + return _wrap(func, self._overrides) + + def __enter__(self): + self._saved_flag_values = save_flag_values(FLAGS) + try: + FLAGS._set_attributes(**self._overrides) + except: + # It may fail because of flag validators. + restore_flag_values(self._saved_flag_values, FLAGS) + raise + + def __exit__(self, exc_type, exc_value, traceback): + restore_flag_values(self._saved_flag_values, FLAGS) + + +def _copy_flag_dict(flag): + """Returns a copy of the flag object's __dict__. + + It's mostly a shallow copy of the __dict__, except it also does a shallow + copy of the validator list. + + Args: + flag: flags.Flag, the flag to copy. + + Returns: + A copy of the flag object's __dict__. + """ + copy = flag.__dict__.copy() + copy['_value'] = flag.value # Ensure correct restore for C++ flags. + copy['validators'] = list(flag.validators) + return copy diff --git a/absl/testing/parameterized.py b/absl/testing/parameterized.py new file mode 100644 index 0000000..ec6a529 --- /dev/null +++ b/absl/testing/parameterized.py @@ -0,0 +1,700 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Adds support for parameterized tests to Python's unittest TestCase class. + +A parameterized test is a method in a test case that is invoked with different +argument tuples. + +A simple example: + + class AdditionExample(parameterized.TestCase): + @parameterized.parameters( + (1, 2, 3), + (4, 5, 9), + (1, 1, 3)) + def testAddition(self, op1, op2, result): + self.assertEqual(result, op1 + op2) + + +Each invocation is a separate test case and properly isolated just +like a normal test method, with its own setUp/tearDown cycle. In the +example above, there are three separate testcases, one of which will +fail due to an assertion error (1 + 1 != 3). + +Parameters for individual test cases can be tuples (with positional parameters) +or dictionaries (with named parameters): + + class AdditionExample(parameterized.TestCase): + @parameterized.parameters( + {'op1': 1, 'op2': 2, 'result': 3}, + {'op1': 4, 'op2': 5, 'result': 9}, + ) + def testAddition(self, op1, op2, result): + self.assertEqual(result, op1 + op2) + +If a parameterized test fails, the error message will show the +original test name and the parameters for that test. + +The id method of the test, used internally by the unittest framework, is also +modified to show the arguments (but note that the name reported by `id()` +doesn't match the actual test name, see below). To make sure that test names +stay the same across several invocations, object representations like + + >>> class Foo(object): + ... pass + >>> repr(Foo()) + '<__main__.Foo object at 0x23d8610>' + +are turned into '<__main__.Foo>'. When selecting a subset of test cases to run +on the command-line, the test cases contain an index suffix for each argument +in the order they were passed to `parameters()` (eg. testAddition0, +testAddition1, etc.) This naming scheme is subject to change; for more reliable +and stable names, especially in test logs, use `named_parameters()` instead. + +Tests using `named_parameters()` are similar to `parameters()`, except only +tuples or dicts of args are supported. For tuples, the first parameter arg +has to be a string (or an object that returns an apt name when converted via +str()). For dicts, a value for the key 'testcase_name' must be present and must +be a string (or an object that returns an apt name when converted via str()): + + class NamedExample(parameterized.TestCase): + @parameterized.named_parameters( + ('Normal', 'aa', 'aaa', True), + ('EmptyPrefix', '', 'abc', True), + ('BothEmpty', '', '', True)) + def testStartsWith(self, prefix, string, result): + self.assertEqual(result, string.startswith(prefix)) + + class NamedExample(parameterized.TestCase): + @parameterized.named_parameters( + {'testcase_name': 'Normal', + 'result': True, 'string': 'aaa', 'prefix': 'aa'}, + {'testcase_name': 'EmptyPrefix', + 'result': True, 'string': 'abc', 'prefix': ''}, + {'testcase_name': 'BothEmpty', + 'result': True, 'string': '', 'prefix': ''}) + def testStartsWith(self, prefix, string, result): + self.assertEqual(result, string.startswith(prefix)) + +Named tests also have the benefit that they can be run individually +from the command line: + + $ testmodule.py NamedExample.testStartsWithNormal + . + -------------------------------------------------------------------- + Ran 1 test in 0.000s + + OK + +Parameterized Classes +===================== +If invocation arguments are shared across test methods in a single +TestCase class, instead of decorating all test methods +individually, the class itself can be decorated: + + @parameterized.parameters( + (1, 2, 3), + (4, 5, 9)) + class ArithmeticTest(parameterized.TestCase): + def testAdd(self, arg1, arg2, result): + self.assertEqual(arg1 + arg2, result) + + def testSubtract(self, arg1, arg2, result): + self.assertEqual(result - arg1, arg2) + +Inputs from Iterables +===================== +If parameters should be shared across several test cases, or are dynamically +created from other sources, a single non-tuple iterable can be passed into +the decorator. This iterable will be used to obtain the test cases: + + class AdditionExample(parameterized.TestCase): + @parameterized.parameters( + c.op1, c.op2, c.result for c in testcases + ) + def testAddition(self, op1, op2, result): + self.assertEqual(result, op1 + op2) + + +Single-Argument Test Methods +============================ +If a test method takes only one argument, the single arguments must not be +wrapped into a tuple: + + class NegativeNumberExample(parameterized.TestCase): + @parameterized.parameters( + -1, -3, -4, -5 + ) + def testIsNegative(self, arg): + self.assertTrue(IsNegative(arg)) + + +List/tuple as a Single Argument +=============================== +If a test method takes a single argument of a list/tuple, it must be wrapped +inside a tuple: + + class ZeroSumExample(parameterized.TestCase): + @parameterized.parameters( + ([-1, 0, 1], ), + ([-2, 0, 2], ), + ) + def testSumIsZero(self, arg): + self.assertEqual(0, sum(arg)) + + +Cartesian product of Parameter Values as Parametrized Test Cases +====================================================== +If required to test method over a cartesian product of parameters, +`parameterized.product` may be used to facilitate generation of parameters +test combinations: + + class TestModuloExample(parameterized.TestCase): + @parameterized.product( + num=[0, 20, 80], + modulo=[2, 4], + expected=[0] + ) + def testModuloResult(self, num, modulo, expected): + self.assertEqual(expected, num % modulo) + +This results in 6 test cases being created - one for each combination of the +parameters. It is also possible to supply sequences of keyword argument dicts +as elements of the cartesian product: + + @parameterized.product( + (dict(num=5, modulo=3, expected=2), + dict(num=7, modulo=4, expected=3)), + dtype=(int, float) + ) + def testModuloResult(self, num, modulo, expected, dtype): + self.assertEqual(expected, dtype(num) % modulo) + +This results in 4 test cases being created - for each of the two sets of test +data (supplied as kwarg dicts) and for each of the two data types (supplied as +a named parameter). Multiple keyword argument dicts may be supplied if required. + +Async Support +=============================== +If a test needs to call async functions, it can inherit from both +parameterized.TestCase and another TestCase that supports async calls, such +as [asynctest](https://github.com/Martiusweb/asynctest): + + import asynctest + + class AsyncExample(parameterized.TestCase, asynctest.TestCase): + @parameterized.parameters( + ('a', 1), + ('b', 2), + ) + async def testSomeAsyncFunction(self, arg, expected): + actual = await someAsyncFunction(arg) + self.assertEqual(actual, expected) +""" + +from collections import abc +import functools +import inspect +import itertools +import re +import types +import unittest + +from absl.testing import absltest + + +_ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>') +_NAMED = object() +_ARGUMENT_REPR = object() +_NAMED_DICT_KEY = 'testcase_name' + + +class NoTestsError(Exception): + """Raised when parameterized decorators do not generate any tests.""" + + +class DuplicateTestNameError(Exception): + """Raised when a parameterized test has the same test name multiple times.""" + + def __init__(self, test_class_name, new_test_name, original_test_name): + super(DuplicateTestNameError, self).__init__( + 'Duplicate parameterized test name in {}: generated test name {!r} ' + '(generated from {!r}) already exists. Consider using ' + 'named_parameters() to give your tests unique names and/or renaming ' + 'the conflicting test method.'.format( + test_class_name, new_test_name, original_test_name)) + + +def _clean_repr(obj): + return _ADDR_RE.sub(r'<\1>', repr(obj)) + + +def _non_string_or_bytes_iterable(obj): + return (isinstance(obj, abc.Iterable) and not isinstance(obj, str) and + not isinstance(obj, bytes)) + + +def _format_parameter_list(testcase_params): + if isinstance(testcase_params, abc.Mapping): + return ', '.join('%s=%s' % (argname, _clean_repr(value)) + for argname, value in testcase_params.items()) + elif _non_string_or_bytes_iterable(testcase_params): + return ', '.join(map(_clean_repr, testcase_params)) + else: + return _format_parameter_list((testcase_params,)) + + +def _async_wrapped(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await func(*args, **kwargs) + return wrapper + + +class _ParameterizedTestIter(object): + """Callable and iterable class for producing new test cases.""" + + def __init__(self, test_method, testcases, naming_type, original_name=None): + """Returns concrete test functions for a test and a list of parameters. + + The naming_type is used to determine the name of the concrete + functions as reported by the unittest framework. If naming_type is + _FIRST_ARG, the testcases must be tuples, and the first element must + have a string representation that is a valid Python identifier. + + Args: + test_method: The decorated test method. + testcases: (list of tuple/dict) A list of parameter tuples/dicts for + individual test invocations. + naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR. + original_name: The original test method name. When decorated on a test + method, None is passed to __init__ and test_method.__name__ is used. + Note test_method.__name__ might be different than the original defined + test method because of the use of other decorators. A more accurate + value is set by TestGeneratorMetaclass.__new__ later. + """ + self._test_method = test_method + self.testcases = testcases + self._naming_type = naming_type + if original_name is None: + original_name = test_method.__name__ + self._original_name = original_name + self.__name__ = _ParameterizedTestIter.__name__ + + def __call__(self, *args, **kwargs): + raise RuntimeError('You appear to be running a parameterized test case ' + 'without having inherited from parameterized.' + 'TestCase. This is bad because none of ' + 'your test cases are actually being run. You may also ' + 'be using another decorator before the parameterized ' + 'one, in which case you should reverse the order.') + + def __iter__(self): + test_method = self._test_method + naming_type = self._naming_type + + def make_bound_param_test(testcase_params): + @functools.wraps(test_method) + def bound_param_test(self): + if isinstance(testcase_params, abc.Mapping): + return test_method(self, **testcase_params) + elif _non_string_or_bytes_iterable(testcase_params): + return test_method(self, *testcase_params) + else: + return test_method(self, testcase_params) + + if naming_type is _NAMED: + # Signal the metaclass that the name of the test function is unique + # and descriptive. + bound_param_test.__x_use_name__ = True + + testcase_name = None + if isinstance(testcase_params, abc.Mapping): + if _NAMED_DICT_KEY not in testcase_params: + raise RuntimeError( + 'Dict for named tests must contain key "%s"' % _NAMED_DICT_KEY) + # Create a new dict to avoid modifying the supplied testcase_params. + testcase_name = testcase_params[_NAMED_DICT_KEY] + testcase_params = { + k: v for k, v in testcase_params.items() if k != _NAMED_DICT_KEY + } + elif _non_string_or_bytes_iterable(testcase_params): + if not isinstance(testcase_params[0], str): + raise RuntimeError( + 'The first element of named test parameters is the test name ' + 'suffix and must be a string') + testcase_name = testcase_params[0] + testcase_params = testcase_params[1:] + else: + raise RuntimeError( + 'Named tests must be passed a dict or non-string iterable.') + + test_method_name = self._original_name + # Support PEP-8 underscore style for test naming if used. + if (test_method_name.startswith('test_') + and testcase_name + and not testcase_name.startswith('_')): + test_method_name += '_' + + bound_param_test.__name__ = test_method_name + str(testcase_name) + elif naming_type is _ARGUMENT_REPR: + # If it's a generator, convert it to a tuple and treat them as + # parameters. + if isinstance(testcase_params, types.GeneratorType): + testcase_params = tuple(testcase_params) + # The metaclass creates a unique, but non-descriptive method name for + # _ARGUMENT_REPR tests using an indexed suffix. + # To keep test names descriptive, only the original method name is used. + # To make sure test names are unique, we add a unique descriptive suffix + # __x_params_repr__ for every test. + params_repr = '(%s)' % (_format_parameter_list(testcase_params),) + bound_param_test.__x_params_repr__ = params_repr + else: + raise RuntimeError('%s is not a valid naming type.' % (naming_type,)) + + bound_param_test.__doc__ = '%s(%s)' % ( + bound_param_test.__name__, _format_parameter_list(testcase_params)) + if test_method.__doc__: + bound_param_test.__doc__ += '\n%s' % (test_method.__doc__,) + if inspect.iscoroutinefunction(test_method): + return _async_wrapped(bound_param_test) + return bound_param_test + + return (make_bound_param_test(c) for c in self.testcases) + + +def _modify_class(class_object, testcases, naming_type): + assert not getattr(class_object, '_test_params_reprs', None), ( + 'Cannot add parameters to %s. Either it already has parameterized ' + 'methods, or its super class is also a parameterized class.' % ( + class_object,)) + # NOTE: _test_params_repr is private to parameterized.TestCase and it's + # metaclass; do not use it outside of those classes. + class_object._test_params_reprs = test_params_reprs = {} + for name, obj in class_object.__dict__.copy().items(): + if (name.startswith(unittest.TestLoader.testMethodPrefix) + and isinstance(obj, types.FunctionType)): + delattr(class_object, name) + methods = {} + _update_class_dict_for_param_test_case( + class_object.__name__, methods, test_params_reprs, name, + _ParameterizedTestIter(obj, testcases, naming_type, name)) + for meth_name, meth in methods.items(): + setattr(class_object, meth_name, meth) + + +def _parameter_decorator(naming_type, testcases): + """Implementation of the parameterization decorators. + + Args: + naming_type: The naming type. + testcases: Testcase parameters. + + Raises: + NoTestsError: Raised when the decorator generates no tests. + + Returns: + A function for modifying the decorated object. + """ + def _apply(obj): + if isinstance(obj, type): + _modify_class(obj, testcases, naming_type) + return obj + else: + return _ParameterizedTestIter(obj, testcases, naming_type) + + if (len(testcases) == 1 and + not isinstance(testcases[0], tuple) and + not isinstance(testcases[0], abc.Mapping)): + # Support using a single non-tuple parameter as a list of test cases. + # Note that the single non-tuple parameter can't be Mapping either, which + # means a single dict parameter case. + assert _non_string_or_bytes_iterable(testcases[0]), ( + 'Single parameter argument must be a non-string non-Mapping iterable') + testcases = testcases[0] + + if not isinstance(testcases, abc.Sequence): + testcases = list(testcases) + if not testcases: + raise NoTestsError( + 'parameterized test decorators did not generate any tests. ' + 'Make sure you specify non-empty parameters, ' + 'and do not reuse generators more than once.') + + return _apply + + +def parameters(*testcases): + """A decorator for creating parameterized tests. + + See the module docstring for a usage example. + + Args: + *testcases: Parameters for the decorated method, either a single + iterable, or a list of tuples/dicts/objects (for tests with only one + argument). + + Raises: + NoTestsError: Raised when the decorator generates no tests. + + Returns: + A test generator to be handled by TestGeneratorMetaclass. + """ + return _parameter_decorator(_ARGUMENT_REPR, testcases) + + +def named_parameters(*testcases): + """A decorator for creating parameterized tests. + + See the module docstring for a usage example. For every parameter tuple + passed, the first element of the tuple should be a string and will be appended + to the name of the test method. Each parameter dict passed must have a value + for the key "testcase_name", the string representation of that value will be + appended to the name of the test method. + + Args: + *testcases: Parameters for the decorated method, either a single iterable, + or a list of tuples or dicts. + + Raises: + NoTestsError: Raised when the decorator generates no tests. + + Returns: + A test generator to be handled by TestGeneratorMetaclass. + """ + return _parameter_decorator(_NAMED, testcases) + + +def product(*kwargs_seqs, **testgrid): + """A decorator for running tests over cartesian product of parameters values. + + See the module docstring for a usage example. The test will be run for every + possible combination of the parameters. + + Args: + *kwargs_seqs: Each positional parameter is a sequence of keyword arg dicts; + every test case generated will include exactly one kwargs dict from each + positional parameter; these will then be merged to form an overall list + of arguments for the test case. + **testgrid: A mapping of parameter names and their possible values. Possible + values should given as either a list or a tuple. + + Raises: + NoTestsError: Raised when the decorator generates no tests. + + Returns: + A test generator to be handled by TestGeneratorMetaclass. + """ + + for name, values in testgrid.items(): + assert isinstance(values, (list, tuple)), ( + 'Values of {} must be given as list or tuple, found {}'.format( + name, type(values))) + + prior_arg_names = set() + for kwargs_seq in kwargs_seqs: + assert ((isinstance(kwargs_seq, (list, tuple))) and + all(isinstance(kwargs, dict) for kwargs in kwargs_seq)), ( + 'Positional parameters must be a sequence of keyword arg' + 'dicts, found {}' + .format(kwargs_seq)) + if kwargs_seq: + arg_names = set(kwargs_seq[0]) + assert all(set(kwargs) == arg_names for kwargs in kwargs_seq), ( + 'Keyword argument dicts within a single parameter must all have the ' + 'same keys, found {}'.format(kwargs_seq)) + assert not (arg_names & prior_arg_names), ( + 'Keyword argument dict sequences must all have distinct argument ' + 'names, found duplicate(s) {}' + .format(sorted(arg_names & prior_arg_names))) + prior_arg_names |= arg_names + + assert not (prior_arg_names & set(testgrid)), ( + 'Arguments supplied in kwargs dicts in positional parameters must not ' + 'overlap with arguments supplied as named parameters; found duplicate ' + 'argument(s) {}'.format(sorted(prior_arg_names & set(testgrid)))) + + # Convert testgrid into a sequence of sequences of kwargs dicts and combine + # with the positional parameters. + # So foo=[1,2], bar=[3,4] --> [[{foo: 1}, {foo: 2}], [{bar: 3, bar: 4}]] + testgrid = (tuple({k: v} for v in vs) for k, vs in testgrid.items()) + testgrid = tuple(kwargs_seqs) + tuple(testgrid) + + # Create all possible combinations of parameters as a cartesian product + # of parameter values. + testcases = [ + dict(itertools.chain.from_iterable(case.items() + for case in cases)) + for cases in itertools.product(*testgrid) + ] + return _parameter_decorator(_ARGUMENT_REPR, testcases) + + +class TestGeneratorMetaclass(type): + """Metaclass for adding tests generated by parameterized decorators.""" + + def __new__(cls, class_name, bases, dct): + # NOTE: _test_params_repr is private to parameterized.TestCase and it's + # metaclass; do not use it outside of those classes. + test_params_reprs = dct.setdefault('_test_params_reprs', {}) + for name, obj in dct.copy().items(): + if (name.startswith(unittest.TestLoader.testMethodPrefix) and + _non_string_or_bytes_iterable(obj)): + # NOTE: `obj` might not be a _ParameterizedTestIter in two cases: + # 1. a class-level iterable named test* that isn't a test, such as + # a list of something. Such attributes get deleted from the class. + # + # 2. If a decorator is applied to the parameterized test, e.g. + # @morestuff + # @parameterized.parameters(...) + # def test_foo(...): ... + # + # This is OK so long as the underlying parameterized function state + # is forwarded (e.g. using functool.wraps() and **without** + # accessing explicitly accessing the internal attributes. + if isinstance(obj, _ParameterizedTestIter): + # Update the original test method name so it's more accurate. + # The mismatch might happen when another decorator is used inside + # the parameterized decrators, and the inner decorator doesn't + # preserve its __name__. + obj._original_name = name + iterator = iter(obj) + dct.pop(name) + _update_class_dict_for_param_test_case( + class_name, dct, test_params_reprs, name, iterator) + # If the base class is a subclass of parameterized.TestCase, inherit its + # _test_params_reprs too. + for base in bases: + # Check if the base has _test_params_reprs first, then check if it's a + # subclass of parameterized.TestCase. Otherwise when this is called for + # the parameterized.TestCase definition itself, this raises because + # itself is not defined yet. This works as long as absltest.TestCase does + # not define _test_params_reprs. + base_test_params_reprs = getattr(base, '_test_params_reprs', None) + if base_test_params_reprs and issubclass(base, TestCase): + for test_method, test_method_id in base_test_params_reprs.items(): + # test_method may both exists in base and this class. + # This class's method overrides base class's. + # That's why it should only inherit it if it does not exist. + test_params_reprs.setdefault(test_method, test_method_id) + + return type.__new__(cls, class_name, bases, dct) + + +def _update_class_dict_for_param_test_case( + test_class_name, dct, test_params_reprs, name, iterator): + """Adds individual test cases to a dictionary. + + Args: + test_class_name: The name of the class tests are added to. + dct: The target dictionary. + test_params_reprs: The dictionary for mapping names to test IDs. + name: The original name of the test case. + iterator: The iterator generating the individual test cases. + + Raises: + DuplicateTestNameError: Raised when a test name occurs multiple times. + RuntimeError: If non-parameterized functions are generated. + """ + for idx, func in enumerate(iterator): + assert callable(func), 'Test generators must yield callables, got %r' % ( + func,) + if not (getattr(func, '__x_use_name__', None) or + getattr(func, '__x_params_repr__', None)): + raise RuntimeError( + '{}.{} generated a test function without using the parameterized ' + 'decorators. Only tests generated using the decorators are ' + 'supported.'.format(test_class_name, name)) + + if getattr(func, '__x_use_name__', False): + original_name = func.__name__ + new_name = original_name + else: + original_name = name + new_name = '%s%d' % (original_name, idx) + + if new_name in dct: + raise DuplicateTestNameError(test_class_name, new_name, original_name) + + dct[new_name] = func + test_params_reprs[new_name] = getattr(func, '__x_params_repr__', '') + + +class TestCase(absltest.TestCase, metaclass=TestGeneratorMetaclass): + """Base class for test cases using the parameters decorator.""" + + # visibility: private; do not call outside this class. + def _get_params_repr(self): + return self._test_params_reprs.get(self._testMethodName, '') + + def __str__(self): + params_repr = self._get_params_repr() + if params_repr: + params_repr = ' ' + params_repr + return '{}{} ({})'.format( + self._testMethodName, params_repr, + unittest.util.strclass(self.__class__)) + + def id(self): + """Returns the descriptive ID of the test. + + This is used internally by the unittesting framework to get a name + for the test to be used in reports. + + Returns: + The test id. + """ + base = super(TestCase, self).id() + params_repr = self._get_params_repr() + if params_repr: + # We include the params in the id so that, when reported in the + # test.xml file, the value is more informative than just "test_foo0". + # Use a space to separate them so that it's copy/paste friendly and + # easy to identify the actual test id. + return '{} {}'.format(base, params_repr) + else: + return base + + +# This function is kept CamelCase because it's used as a class's base class. +def CoopTestCase(other_base_class): # pylint: disable=invalid-name + """Returns a new base class with a cooperative metaclass base. + + This enables the TestCase to be used in combination + with other base classes that have custom metaclasses, such as + mox.MoxTestBase. + + Only works with metaclasses that do not override type.__new__. + + Example: + + from absl.testing import parameterized + + class ExampleTest(parameterized.CoopTestCase(OtherTestCase)): + ... + + Args: + other_base_class: (class) A test case base class. + + Returns: + A new class object. + """ + metaclass = type( + 'CoopMetaclass', + (other_base_class.__metaclass__, + TestGeneratorMetaclass), {}) + return metaclass( + 'CoopTestCase', + (other_base_class, TestCase), {}) diff --git a/absl/testing/tests/__init__.py b/absl/testing/tests/__init__.py new file mode 100644 index 0000000..a3bd1cd --- /dev/null +++ b/absl/testing/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/absl/testing/tests/absltest_env.py b/absl/testing/tests/absltest_env.py new file mode 100644 index 0000000..6c06d62 --- /dev/null +++ b/absl/testing/tests/absltest_env.py @@ -0,0 +1,30 @@ +"""Helper library to get environment variables for absltest helper binaries.""" + +import os + + +_INHERITED_ENV_KEYS = frozenset({ + # This is needed to correctly use the Python interpreter determined by + # bazel. + 'PATH', + # This is used by the random module on Windows to locate crypto + # libraries. + 'SYSTEMROOT', +}) + + +def inherited_env(): + """Returns the environment variables that should be inherited from parent. + + Reason why using an explicit list of environment variables instead of + inheriting all from parent: the absltest module itself interprets a list of + environment variables set by bazel, e.g. XML_OUTPUT_FILE, + TESTBRIDGE_TEST_ONLY. While testing absltest's own behavior, we should + remove them when invoking the helper subprocess. Using an explicit list is + safer. + """ + env = {} + for key in _INHERITED_ENV_KEYS: + if key in os.environ: + env[key] = os.environ[key] + return env diff --git a/absl/testing/tests/absltest_fail_fast_test.py b/absl/testing/tests/absltest_fail_fast_test.py new file mode 100644 index 0000000..dc967f9 --- /dev/null +++ b/absl/testing/tests/absltest_fail_fast_test.py @@ -0,0 +1,109 @@ +# Copyright 2020 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for test fail fast protocol.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import subprocess + +from absl import logging +from absl.testing import _bazelize_command +from absl.testing import absltest +from absl.testing import parameterized +from absl.testing.tests import absltest_env + + +@parameterized.named_parameters( + ('use_app_run', True), + ('no_argv', False), +) +class TestFailFastTest(parameterized.TestCase): + """Integration tests: Runs a test binary with fail fast. + + This is done by setting the fail fast environment variable + """ + + def setUp(self): + super().setUp() + self._test_name = 'absl/testing/tests/absltest_fail_fast_test_helper' + + def _run_fail_fast(self, fail_fast, use_app_run): + """Runs the py_test binary in a subprocess. + + Args: + fail_fast: string, the fail fast value. + use_app_run: bool, whether the test helper should call + `absltest.main(argv=)` inside `app.run`. + + Returns: + (stdout, exit_code) tuple of (string, int). + """ + env = absltest_env.inherited_env() + if fail_fast is not None: + env['TESTBRIDGE_TEST_RUNNER_FAIL_FAST'] = fail_fast + env['USE_APP_RUN'] = '1' if use_app_run else '0' + + proc = subprocess.Popen( + args=[_bazelize_command.get_executable_path(self._test_name)], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True) + stdout = proc.communicate()[0] + + logging.info('output: %s', stdout) + return stdout, proc.wait() + + def test_no_fail_fast(self, use_app_run): + out, exit_code = self._run_fail_fast(None, use_app_run) + self.assertEqual(1, exit_code) + self.assertIn('class A test A', out) + self.assertIn('class A test B', out) + self.assertIn('class A test C', out) + self.assertIn('class A test D', out) + self.assertIn('class A test E', out) + + def test_empty_fail_fast(self, use_app_run): + out, exit_code = self._run_fail_fast('', use_app_run) + self.assertEqual(1, exit_code) + self.assertIn('class A test A', out) + self.assertIn('class A test B', out) + self.assertIn('class A test C', out) + self.assertIn('class A test D', out) + self.assertIn('class A test E', out) + + def test_fail_fast_1(self, use_app_run): + out, exit_code = self._run_fail_fast('1', use_app_run) + self.assertEqual(1, exit_code) + self.assertIn('class A test A', out) + self.assertIn('class A test B', out) + self.assertIn('class A test C', out) + self.assertNotIn('class A test D', out) + self.assertNotIn('class A test E', out) + + def test_fail_fast_0(self, use_app_run): + out, exit_code = self._run_fail_fast('0', use_app_run) + self.assertEqual(1, exit_code) + self.assertIn('class A test A', out) + self.assertIn('class A test B', out) + self.assertIn('class A test C', out) + self.assertIn('class A test D', out) + self.assertIn('class A test E', out) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/absltest_fail_fast_test_helper.py b/absl/testing/tests/absltest_fail_fast_test_helper.py new file mode 100644 index 0000000..339a569 --- /dev/null +++ b/absl/testing/tests/absltest_fail_fast_test_helper.py @@ -0,0 +1,56 @@ +# Copyright 2020 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A helper test program for absltest_fail_fast_test.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from absl import app +from absl.testing import absltest + + +class ClassA(absltest.TestCase): + """Helper test case A for absltest_fail_fast_test.""" + + def testA(self): + sys.stderr.write('\nclass A test A\n') + + def testB(self): + sys.stderr.write('\nclass A test B\n') + + def testC(self): + sys.stderr.write('\nclass A test C\n') + self.fail('Force failure') + + def testD(self): + sys.stderr.write('\nclass A test D\n') + + def testE(self): + sys.stderr.write('\nclass A test E\n') + + +def main(argv): + absltest.main(argv=argv) + + +if __name__ == '__main__': + if os.environ['USE_APP_RUN'] == '1': + app.run(main) + else: + absltest.main() diff --git a/absl/testing/tests/absltest_filtering_test.py b/absl/testing/tests/absltest_filtering_test.py new file mode 100644 index 0000000..30a81f6 --- /dev/null +++ b/absl/testing/tests/absltest_filtering_test.py @@ -0,0 +1,192 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for test filtering protocol.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import subprocess +import sys + +from absl import logging +from absl.testing import _bazelize_command +from absl.testing import absltest +from absl.testing import parameterized +from absl.testing.tests import absltest_env + + +@parameterized.named_parameters( + ('as_env_variable_use_app_run', True, True), + ('as_env_variable_no_argv', True, False), + ('as_commandline_args_use_app_run', False, True), + ('as_commandline_args_no_argv', False, False), +) +class TestFilteringTest(absltest.TestCase): + """Integration tests: Runs a test binary with filtering. + + This is done by either setting the filtering environment variable, or passing + the filters as command line arguments. + """ + + def setUp(self): + super().setUp() + self._test_name = 'absl/testing/tests/absltest_filtering_test_helper' + + def _run_filtered(self, test_filter, use_env_variable, use_app_run): + """Runs the py_test binary in a subprocess. + + Args: + test_filter: string, the filter argument to use. + use_env_variable: bool, pass the test filter as environment variable if + True, otherwise pass as command line arguments. + use_app_run: bool, whether the test helper should call + `absltest.main(argv=)` inside `app.run`. + + Returns: + (stdout, exit_code) tuple of (string, int). + """ + env = absltest_env.inherited_env() + env['USE_APP_RUN'] = '1' if use_app_run else '0' + additional_args = [] + if test_filter is not None: + if use_env_variable: + env['TESTBRIDGE_TEST_ONLY'] = test_filter + elif test_filter: + if sys.version_info[:2] >= (3, 7): + # The -k flags are passed as positional arguments to absl.flags. + additional_args.append('--') + additional_args.extend(['-k=' + f for f in test_filter.split(' ')]) + else: + additional_args.extend(test_filter.split(' ')) + + proc = subprocess.Popen( + args=([_bazelize_command.get_executable_path(self._test_name)] + + additional_args), + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True) + stdout = proc.communicate()[0] + + logging.info('output: %s', stdout) + return stdout, proc.wait() + + def test_no_filter(self, use_env_variable, use_app_run): + out, exit_code = self._run_filtered(None, use_env_variable, use_app_run) + self.assertEqual(1, exit_code) + self.assertIn('class B test E', out) + + def test_empty_filter(self, use_env_variable, use_app_run): + out, exit_code = self._run_filtered('', use_env_variable, use_app_run) + self.assertEqual(1, exit_code) + self.assertIn('class B test E', out) + + def test_class_filter(self, use_env_variable, use_app_run): + out, exit_code = self._run_filtered('ClassA', use_env_variable, use_app_run) + self.assertEqual(0, exit_code) + self.assertNotIn('class B', out) + + def test_method_filter(self, use_env_variable, use_app_run): + out, exit_code = self._run_filtered('ClassB.testA', use_env_variable, + use_app_run) + self.assertEqual(0, exit_code) + self.assertNotIn('class A', out) + self.assertNotIn('class B test B', out) + + out, exit_code = self._run_filtered('ClassB.testE', use_env_variable, + use_app_run) + self.assertEqual(1, exit_code) + self.assertNotIn('class A', out) + + def test_multiple_class_and_method_filter(self, use_env_variable, + use_app_run): + out, exit_code = self._run_filtered( + 'ClassA.testA ClassA.testB ClassB.testC', use_env_variable, use_app_run) + self.assertEqual(0, exit_code) + self.assertIn('class A test A', out) + self.assertIn('class A test B', out) + self.assertNotIn('class A test C', out) + self.assertIn('class B test C', out) + self.assertNotIn('class B test A', out) + + @absltest.skipIf( + sys.version_info[:2] < (3, 7), + 'Only Python 3.7+ does glob and substring matching.') + def test_substring(self, use_env_variable, use_app_run): + out, exit_code = self._run_filtered( + 'testA', use_env_variable, use_app_run) + self.assertEqual(0, exit_code) + self.assertIn('Ran 2 tests', out) + self.assertIn('ClassA.testA', out) + self.assertIn('ClassB.testA', out) + + @absltest.skipIf( + sys.version_info[:2] < (3, 7), + 'Only Python 3.7+ does glob and substring matching.') + def test_glob_pattern(self, use_env_variable, use_app_run): + out, exit_code = self._run_filtered( + '__main__.Class*.testA', use_env_variable, use_app_run) + self.assertEqual(0, exit_code) + self.assertIn('Ran 2 tests', out) + self.assertIn('ClassA.testA', out) + self.assertIn('ClassB.testA', out) + + @absltest.skipIf( + sys.version_info[:2] >= (3, 7), + "Python 3.7+ uses unittest's -k flag and doesn't fail if no tests match.") + def test_not_found_filters_py36(self, use_env_variable, use_app_run): + out, exit_code = self._run_filtered('NotExistedClass.not_existed_method', + use_env_variable, use_app_run) + self.assertEqual(1, exit_code) + self.assertIn("has no attribute 'NotExistedClass'", out) + + @absltest.skipIf( + sys.version_info[:2] < (3, 7), + 'Python 3.6 passes the filter as positional arguments and fails if no ' + 'tests match.' + ) + def test_not_found_filters_py37(self, use_env_variable, use_app_run): + out, exit_code = self._run_filtered('NotExistedClass.not_existed_method', + use_env_variable, use_app_run) + self.assertEqual(0, exit_code) + self.assertIn('Ran 0 tests', out) + + @absltest.skipIf( + sys.version_info[:2] < (3, 7), + 'Python 3.6 passes the filter as positional arguments and matches by name' + ) + def test_parameterized_unnamed(self, use_env_variable, use_app_run): + out, exit_code = self._run_filtered('ParameterizedTest.test_unnamed', + use_env_variable, use_app_run) + self.assertEqual(0, exit_code) + self.assertIn('Ran 2 tests', out) + self.assertIn('parameterized unnamed 1', out) + self.assertIn('parameterized unnamed 2', out) + + @absltest.skipIf( + sys.version_info[:2] < (3, 7), + 'Python 3.6 passes the filter as positional arguments and matches by name' + ) + def test_parameterized_named(self, use_env_variable, use_app_run): + out, exit_code = self._run_filtered('ParameterizedTest.test_named', + use_env_variable, use_app_run) + self.assertEqual(0, exit_code) + self.assertIn('Ran 2 tests', out) + self.assertIn('parameterized named 1', out) + self.assertIn('parameterized named 2', out) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/absltest_filtering_test_helper.py b/absl/testing/tests/absltest_filtering_test_helper.py new file mode 100644 index 0000000..2b741ed --- /dev/null +++ b/absl/testing/tests/absltest_filtering_test_helper.py @@ -0,0 +1,85 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A helper test program for absltest_filtering_test.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from absl import app +from absl.testing import absltest +from absl.testing import parameterized + + +class ClassA(absltest.TestCase): + """Helper test case A for absltest_filtering_test.""" + + def testA(self): + sys.stderr.write('\nclass A test A\n') + + def testB(self): + sys.stderr.write('\nclass A test B\n') + + def testC(self): + sys.stderr.write('\nclass A test C\n') + + +class ClassB(absltest.TestCase): + """Helper test case B for absltest_filtering_test.""" + + def testA(self): + sys.stderr.write('\nclass B test A\n') + + def testB(self): + sys.stderr.write('\nclass B test B\n') + + def testC(self): + sys.stderr.write('\nclass B test C\n') + + def testD(self): + sys.stderr.write('\nclass B test D\n') + + def testE(self): + sys.stderr.write('\nclass B test E\n') + self.fail('Force failure') + + +class ParameterizedTest(parameterized.TestCase): + """Helper parameterized test case for absltest_filtering_test.""" + + @parameterized.parameters([1, 2]) + def test_unnamed(self, value): + sys.stderr.write('\nparameterized unnamed %s' % value) + + @parameterized.named_parameters( + ('test1', 1), + ('test2', 2), + ) + def test_named(self, value): + sys.stderr.write('\nparameterized named %s' % value) + + +def main(argv): + absltest.main(argv=argv) + + +if __name__ == '__main__': + if os.environ['USE_APP_RUN'] == '1': + app.run(main) + else: + absltest.main() diff --git a/absl/testing/tests/absltest_py3_test.py b/absl/testing/tests/absltest_py3_test.py new file mode 100644 index 0000000..7c5f500 --- /dev/null +++ b/absl/testing/tests/absltest_py3_test.py @@ -0,0 +1,44 @@ +# Copyright 2020 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python3-only Tests for absltest.""" + +from absl.testing import absltest + + +class GetTestCaseNamesPEP3102Test(absltest.TestCase): + """This test verifies absltest.TestLoader.GetTestCasesNames PEP3102 support. + + The test is Python3 only, as keyword only arguments are considered + syntax error in Python2. + + The rest of getTestCaseNames functionality is covered + by absltest_test.TestLoaderTest. + """ + + class Valid(absltest.TestCase): + + def testKeywordOnly(self, *, arg): + pass + + def setUp(self): + self.loader = absltest.TestLoader() + super(GetTestCaseNamesPEP3102Test, self).setUp() + + def test_PEP3102_get_test_case_names(self): + self.assertCountEqual( + self.loader.getTestCaseNames(GetTestCaseNamesPEP3102Test.Valid), + ["testKeywordOnly"]) + +if __name__ == "__main__": + absltest.main() diff --git a/absl/testing/tests/absltest_randomization_test.py b/absl/testing/tests/absltest_randomization_test.py new file mode 100644 index 0000000..75a3868 --- /dev/null +++ b/absl/testing/tests/absltest_randomization_test.py @@ -0,0 +1,154 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for test randomization.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import subprocess + +from absl import flags +from absl.testing import _bazelize_command +from absl.testing import absltest +from absl.testing import parameterized +from absl.testing.tests import absltest_env + +FLAGS = flags.FLAGS + + +class TestOrderRandomizationTest(parameterized.TestCase): + """Integration tests: Runs a py_test binary with randomization. + + This is done by setting flags and environment variables. + """ + + def setUp(self): + super(TestOrderRandomizationTest, self).setUp() + self._test_name = 'absl/testing/tests/absltest_randomization_testcase' + + def _run_test(self, extra_argv, extra_env): + """Runs the py_test binary in a subprocess, with the given args or env. + + Args: + extra_argv: extra args to pass to the test + extra_env: extra env vars to set when running the test + + Returns: + (stdout, test_cases, exit_code) tuple of (str, list of strs, int). + """ + env = absltest_env.inherited_env() + # If *this* test is being run with this flag, we don't want to + # automatically set it for all tests we run. + env.pop('TEST_RANDOMIZE_ORDERING_SEED', '') + if extra_env is not None: + env.update(extra_env) + + command = ( + [_bazelize_command.get_executable_path(self._test_name)] + extra_argv) + proc = subprocess.Popen( + args=command, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True) + + stdout, _ = proc.communicate() + + test_lines = [l for l in stdout.splitlines() if l.startswith('class ')] + return stdout, test_lines, proc.wait() + + def test_no_args(self): + output, tests, exit_code = self._run_test([], None) + self.assertEqual(0, exit_code, msg='command output: ' + output) + self.assertNotIn('Randomizing test order with seed:', output) + cases = ['class A test ' + t for t in ('A', 'B', 'C')] + self.assertEqual(cases, tests) + + @parameterized.parameters( + { + 'argv': ['--test_randomize_ordering_seed=random'], + 'env': None, + }, + { + 'argv': [], + 'env': { + 'TEST_RANDOMIZE_ORDERING_SEED': 'random', + }, + },) + def test_simple_randomization(self, argv, env): + output, tests, exit_code = self._run_test(argv, env) + self.assertEqual(0, exit_code, msg='command output: ' + output) + self.assertIn('Randomizing test order with seed: ', output) + cases = ['class A test ' + t for t in ('A', 'B', 'C')] + # This may come back in any order; we just know it'll be the same + # set of elements. + self.assertSameElements(cases, tests) + + @parameterized.parameters( + { + 'argv': ['--test_randomize_ordering_seed=1'], + 'env': None, + }, + { + 'argv': [], + 'env': { + 'TEST_RANDOMIZE_ORDERING_SEED': '1' + }, + }, + { + 'argv': [], + 'env': { + 'LATE_SET_TEST_RANDOMIZE_ORDERING_SEED': '1' + }, + }, + ) + def test_fixed_seed(self, argv, env): + output, tests, exit_code = self._run_test(argv, env) + self.assertEqual(0, exit_code, msg='command output: ' + output) + self.assertIn('Randomizing test order with seed: 1', output) + # Even though we know the seed, we need to shuffle the tests here, since + # this behaves differently in Python2 vs Python3. + shuffled_cases = ['A', 'B', 'C'] + random.Random(1).shuffle(shuffled_cases) + cases = ['class A test ' + t for t in shuffled_cases] + # We know what order this will come back for the random seed we've + # specified. + self.assertEqual(cases, tests) + + @parameterized.parameters( + { + 'argv': ['--test_randomize_ordering_seed=0'], + 'env': { + 'TEST_RANDOMIZE_ORDERING_SEED': 'random' + }, + }, + { + 'argv': [], + 'env': { + 'TEST_RANDOMIZE_ORDERING_SEED': '0' + }, + },) + def test_disabling_randomization(self, argv, env): + output, tests, exit_code = self._run_test(argv, env) + self.assertEqual(0, exit_code, msg='command output: ' + output) + self.assertNotIn('Randomizing test order with seed:', output) + cases = ['class A test ' + t for t in ('A', 'B', 'C')] + self.assertEqual(cases, tests) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/absltest_randomization_testcase.py b/absl/testing/tests/absltest_randomization_testcase.py new file mode 100644 index 0000000..18b20ff --- /dev/null +++ b/absl/testing/tests/absltest_randomization_testcase.py @@ -0,0 +1,47 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stub tests, only for use in absltest_randomization_test.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from absl.testing import absltest + + +# This stanza exercises setting $TEST_RANDOMIZE_ORDERING_SEED *after* importing +# the absltest library. +if os.environ.get('LATE_SET_TEST_RANDOMIZE_ORDERING_SEED', ''): + os.environ['TEST_RANDOMIZE_ORDERING_SEED'] = os.environ[ + 'LATE_SET_TEST_RANDOMIZE_ORDERING_SEED'] + + +class ClassA(absltest.TestCase): + + def test_a(self): + sys.stderr.write('\nclass A test A\n') + + def test_b(self): + sys.stderr.write('\nclass A test B\n') + + def test_c(self): + sys.stderr.write('\nclass A test C\n') + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/absltest_sharding_test.py b/absl/testing/tests/absltest_sharding_test.py new file mode 100644 index 0000000..6411971 --- /dev/null +++ b/absl/testing/tests/absltest_sharding_test.py @@ -0,0 +1,165 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for test sharding protocol.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import subprocess + +from absl.testing import _bazelize_command +from absl.testing import absltest +from absl.testing.tests import absltest_env + + +NUM_TEST_METHODS = 8 # Hard-coded, based on absltest_sharding_test_helper.py + + +class TestShardingTest(absltest.TestCase): + """Integration tests: Runs a test binary with sharding. + + This is done by setting the sharding environment variables. + """ + + def setUp(self): + super().setUp() + self._test_name = 'absl/testing/tests/absltest_sharding_test_helper' + self._shard_file = None + + def tearDown(self): + super().tearDown() + if self._shard_file is not None and os.path.exists(self._shard_file): + os.unlink(self._shard_file) + + def _run_sharded(self, + total_shards, + shard_index, + shard_file=None, + additional_env=None): + """Runs the py_test binary in a subprocess. + + Args: + total_shards: int, the total number of shards. + shard_index: int, the shard index. + shard_file: string, if not 'None', the path to the shard file. + This method asserts it is properly created. + additional_env: Additional environment variables to be set for the py_test + binary. + + Returns: + (stdout, exit_code) tuple of (string, int). + """ + env = absltest_env.inherited_env() + if additional_env: + env.update(additional_env) + env.update({ + 'TEST_TOTAL_SHARDS': str(total_shards), + 'TEST_SHARD_INDEX': str(shard_index) + }) + if shard_file: + self._shard_file = shard_file + env['TEST_SHARD_STATUS_FILE'] = shard_file + if os.path.exists(shard_file): + os.unlink(shard_file) + + proc = subprocess.Popen( + args=[_bazelize_command.get_executable_path(self._test_name)], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True) + stdout = proc.communicate()[0] + + if shard_file: + self.assertTrue(os.path.exists(shard_file)) + + return (stdout, proc.wait()) + + def _assert_sharding_correctness(self, total_shards): + """Assert the primary correctness and performance of sharding. + + 1. Completeness (all methods are run) + 2. Partition (each method run at most once) + 3. Balance (for performance) + + Args: + total_shards: int, total number of shards. + """ + + outerr_by_shard = [] # A list of lists of strings + combined_outerr = [] # A list of strings + exit_code_by_shard = [] # A list of ints + + for i in range(total_shards): + (out, exit_code) = self._run_sharded(total_shards, i) + method_list = [x for x in out.split('\n') if x.startswith('class')] + outerr_by_shard.append(method_list) + combined_outerr.extend(method_list) + exit_code_by_shard.append(exit_code) + + self.assertLen([x for x in exit_code_by_shard if x != 0], 1, + 'Expected exactly one failure') + + # Test completeness and partition properties. + self.assertLen(combined_outerr, NUM_TEST_METHODS, + 'Partition requirement not met') + self.assertLen(set(combined_outerr), NUM_TEST_METHODS, + 'Completeness requirement not met') + + # Test balance: + for i in range(len(outerr_by_shard)): + self.assertGreaterEqual(len(outerr_by_shard[i]), + (NUM_TEST_METHODS / total_shards) - 1, + 'Shard %d of %d out of balance' % + (i, len(outerr_by_shard))) + + def test_shard_file(self): + self._run_sharded(3, 1, os.path.join( + absltest.TEST_TMPDIR.value, 'shard_file')) + + def test_zero_shards(self): + out, exit_code = self._run_sharded(0, 0) + self.assertEqual(1, exit_code) + self.assertGreaterEqual(out.find('Bad sharding values. index=0, total=0'), + 0, 'Bad output: %s' % (out)) + + def test_with_four_shards(self): + self._assert_sharding_correctness(4) + + def test_with_one_shard(self): + self._assert_sharding_correctness(1) + + def test_with_ten_shards(self): + self._assert_sharding_correctness(10) + + def test_sharding_with_randomization(self): + # If we're both sharding *and* randomizing, we need to confirm that we + # randomize within the shard; we use two seeds to confirm we're seeing the + # same tests (sharding is consistent) in a different order. + tests_seen = [] + for seed in ('7', '17'): + out, exit_code = self._run_sharded( + 2, 0, additional_env={'TEST_RANDOMIZE_ORDERING_SEED': seed}) + self.assertEqual(0, exit_code) + tests_seen.append([x for x in out.splitlines() if x.startswith('class')]) + first_tests, second_tests = tests_seen # pylint: disable=unbalanced-tuple-unpacking + self.assertEqual(set(first_tests), set(second_tests)) + self.assertNotEqual(first_tests, second_tests) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/absltest_sharding_test_helper.py b/absl/testing/tests/absltest_sharding_test_helper.py new file mode 100644 index 0000000..7b2f20e --- /dev/null +++ b/absl/testing/tests/absltest_sharding_test_helper.py @@ -0,0 +1,60 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A helper test program for absltest_sharding_test.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from absl.testing import absltest + + +class ClassA(absltest.TestCase): + """Helper test case A for absltest_sharding_test.""" + + def testA(self): + sys.stderr.write('\nclass A test A\n') + + def testB(self): + sys.stderr.write('\nclass A test B\n') + + def testC(self): + sys.stderr.write('\nclass A test C\n') + + +class ClassB(absltest.TestCase): + """Helper test case B for absltest_sharding_test.""" + + def testA(self): + sys.stderr.write('\nclass B test A\n') + + def testB(self): + sys.stderr.write('\nclass B test B\n') + + def testC(self): + sys.stderr.write('\nclass B test C\n') + + def testD(self): + sys.stderr.write('\nclass B test D\n') + + def testE(self): + sys.stderr.write('\nclass B test E\n') + self.fail('Force failure') + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py new file mode 100644 index 0000000..48eeca8 --- /dev/null +++ b/absl/testing/tests/absltest_test.py @@ -0,0 +1,2374 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for absltest.""" + +import collections +import contextlib +import io +import os +import pathlib +import re +import stat +import string +import subprocess +import tempfile +import unittest + +from absl.testing import _bazelize_command +from absl.testing import absltest +from absl.testing import parameterized +from absl.testing.tests import absltest_env + + +class HelperMixin(object): + + def _get_helper_exec_path(self): + helper = 'absl/testing/tests/absltest_test_helper' + return _bazelize_command.get_executable_path(helper) + + def run_helper(self, test_id, args, env_overrides, expect_success): + env = absltest_env.inherited_env() + for key, value in env_overrides.items(): + if value is None: + if key in env: + del env[key] + else: + env[key] = value + + command = [self._get_helper_exec_path(), + '--test_id={}'.format(test_id)] + args + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, + universal_newlines=True) + stdout, stderr = process.communicate() + if expect_success: + self.assertEqual( + 0, process.returncode, + 'Expected success, but failed with ' + 'stdout:\n{}\nstderr:\n{}\n'.format(stdout, stderr)) + else: + self.assertEqual( + 1, process.returncode, + 'Expected failure, but succeeded with ' + 'stdout:\n{}\nstderr:\n{}\n'.format(stdout, stderr)) + return stdout, stderr + + +class TestCaseTest(absltest.TestCase, HelperMixin): + longMessage = True + + def run_helper(self, test_id, args, env_overrides, expect_success): + return super(TestCaseTest, self).run_helper(test_id, args + ['HelperTest'], + env_overrides, expect_success) + + def test_flags_no_env_var_no_flags(self): + self.run_helper( + 1, + [], + {'TEST_RANDOM_SEED': None, + 'TEST_SRCDIR': None, + 'TEST_TMPDIR': None, + }, + expect_success=True) + + def test_flags_env_var_no_flags(self): + tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + srcdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + self.run_helper( + 2, + [], + {'TEST_RANDOM_SEED': '321', + 'TEST_SRCDIR': srcdir, + 'TEST_TMPDIR': tmpdir, + 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR': srcdir, + 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR': tmpdir, + }, + expect_success=True) + + def test_flags_no_env_var_flags(self): + tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + srcdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + self.run_helper( + 3, + ['--test_random_seed=123', '--test_srcdir={}'.format(srcdir), + '--test_tmpdir={}'.format(tmpdir)], + {'TEST_RANDOM_SEED': None, + 'TEST_SRCDIR': None, + 'TEST_TMPDIR': None, + 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR': srcdir, + 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR': tmpdir, + }, + expect_success=True) + + def test_flags_env_var_flags(self): + tmpdir_from_flag = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + srcdir_from_flag = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + tmpdir_from_env_var = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + srcdir_from_env_var = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + self.run_helper( + 4, + ['--test_random_seed=221', '--test_srcdir={}'.format(srcdir_from_flag), + '--test_tmpdir={}'.format(tmpdir_from_flag)], + {'TEST_RANDOM_SEED': '123', + 'TEST_SRCDIR': srcdir_from_env_var, + 'TEST_TMPDIR': tmpdir_from_env_var, + 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR': srcdir_from_flag, + 'ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR': tmpdir_from_flag, + }, + expect_success=True) + + def test_xml_output_file_from_xml_output_file_env(self): + xml_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + xml_output_file_env = os.path.join(xml_dir, 'xml_output_file.xml') + random_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + self.run_helper( + 6, + [], + {'XML_OUTPUT_FILE': xml_output_file_env, + 'RUNNING_UNDER_TEST_DAEMON': '1', + 'TEST_XMLOUTPUTDIR': random_dir, + 'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': xml_output_file_env, + }, + expect_success=True) + + def test_xml_output_file_from_daemon(self): + tmpdir = os.path.join(tempfile.mkdtemp( + dir=absltest.TEST_TMPDIR.value), 'sub_dir') + random_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + self.run_helper( + 6, + ['--test_tmpdir', tmpdir], + {'XML_OUTPUT_FILE': None, + 'RUNNING_UNDER_TEST_DAEMON': '1', + 'TEST_XMLOUTPUTDIR': random_dir, + 'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': os.path.join( + os.path.dirname(tmpdir), 'test_detail.xml'), + }, + expect_success=True) + + def test_xml_output_file_from_test_xmloutputdir_env(self): + xml_output_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + expected_xml_file = 'absltest_test_helper.xml' + self.run_helper( + 6, + [], + {'XML_OUTPUT_FILE': None, + 'RUNNING_UNDER_TEST_DAEMON': None, + 'TEST_XMLOUTPUTDIR': xml_output_dir, + 'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': os.path.join( + xml_output_dir, expected_xml_file), + }, + expect_success=True) + + def test_xml_output_file_from_flag(self): + random_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + flag_file = os.path.join( + tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value), 'output.xml') + self.run_helper( + 6, + ['--xml_output_file', flag_file], + {'XML_OUTPUT_FILE': os.path.join(random_dir, 'output.xml'), + 'RUNNING_UNDER_TEST_DAEMON': '1', + 'TEST_XMLOUTPUTDIR': random_dir, + 'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': flag_file, + }, + expect_success=True) + + def test_assert_in(self): + animals = {'monkey': 'banana', 'cow': 'grass', 'seal': 'fish'} + + self.assertIn('a', 'abc') + self.assertIn(2, [1, 2, 3]) + self.assertIn('monkey', animals) + + self.assertNotIn('d', 'abc') + self.assertNotIn(0, [1, 2, 3]) + self.assertNotIn('otter', animals) + + self.assertRaises(AssertionError, self.assertIn, 'x', 'abc') + self.assertRaises(AssertionError, self.assertIn, 4, [1, 2, 3]) + self.assertRaises(AssertionError, self.assertIn, 'elephant', animals) + + self.assertRaises(AssertionError, self.assertNotIn, 'c', 'abc') + self.assertRaises(AssertionError, self.assertNotIn, 1, [1, 2, 3]) + self.assertRaises(AssertionError, self.assertNotIn, 'cow', animals) + + @absltest.expectedFailure + def test_expected_failure(self): + self.assertEqual(1, 2) # the expected failure + + @absltest.expectedFailureIf(True, 'always true') + def test_expected_failure_if(self): + self.assertEqual(1, 2) # the expected failure + + def test_expected_failure_success(self): + _, stderr = self.run_helper(5, ['--', '-v'], {}, expect_success=False) + self.assertRegex(stderr, r'FAILED \(.*unexpected successes=1\)') + + def test_assert_equal(self): + self.assertListEqual([], []) + self.assertTupleEqual((), ()) + self.assertSequenceEqual([], ()) + + a = [0, 'a', []] + b = [] + self.assertRaises(absltest.TestCase.failureException, + self.assertListEqual, a, b) + self.assertRaises(absltest.TestCase.failureException, + self.assertListEqual, tuple(a), tuple(b)) + self.assertRaises(absltest.TestCase.failureException, + self.assertSequenceEqual, a, tuple(b)) + + b.extend(a) + self.assertListEqual(a, b) + self.assertTupleEqual(tuple(a), tuple(b)) + self.assertSequenceEqual(a, tuple(b)) + self.assertSequenceEqual(tuple(a), b) + + self.assertRaises(AssertionError, self.assertListEqual, a, tuple(b)) + self.assertRaises(AssertionError, self.assertTupleEqual, tuple(a), b) + self.assertRaises(AssertionError, self.assertListEqual, None, b) + self.assertRaises(AssertionError, self.assertTupleEqual, None, tuple(b)) + self.assertRaises(AssertionError, self.assertSequenceEqual, None, tuple(b)) + self.assertRaises(AssertionError, self.assertListEqual, 1, 1) + self.assertRaises(AssertionError, self.assertTupleEqual, 1, 1) + self.assertRaises(AssertionError, self.assertSequenceEqual, 1, 1) + + self.assertSameElements([1, 2, 3], [3, 2, 1]) + self.assertSameElements([1, 2] + [3] * 100, [1] * 100 + [2, 3]) + self.assertSameElements(['foo', 'bar', 'baz'], ['bar', 'baz', 'foo']) + self.assertRaises(AssertionError, self.assertSameElements, [10], [10, 11]) + self.assertRaises(AssertionError, self.assertSameElements, [10, 11], [10]) + + # Test that sequences of unhashable objects can be tested for sameness: + self.assertSameElements([[1, 2], [3, 4]], [[3, 4], [1, 2]]) + self.assertRaises(AssertionError, self.assertSameElements, [[1]], [[2]]) + + def test_assert_items_equal_hotfix(self): + """Confirm that http://bugs.python.org/issue14832 - b/10038517 is gone.""" + for assert_items_method in (self.assertItemsEqual, self.assertCountEqual): + with self.assertRaises(self.failureException) as error_context: + assert_items_method([4], [2]) + error_message = str(error_context.exception) + # Confirm that the bug is either no longer present in Python or that our + # assertItemsEqual patching version of the method in absltest.TestCase + # doesn't get used. + self.assertIn('First has 1, Second has 0: 4', error_message) + self.assertIn('First has 0, Second has 1: 2', error_message) + + def test_assert_dict_equal(self): + self.assertDictEqual({}, {}) + + c = {'x': 1} + d = {} + self.assertRaises(absltest.TestCase.failureException, + self.assertDictEqual, c, d) + + d.update(c) + self.assertDictEqual(c, d) + + d['x'] = 0 + self.assertRaises(absltest.TestCase.failureException, + self.assertDictEqual, c, d, 'These are unequal') + + self.assertRaises(AssertionError, self.assertDictEqual, None, d) + self.assertRaises(AssertionError, self.assertDictEqual, [], d) + self.assertRaises(AssertionError, self.assertDictEqual, 1, 1) + + try: + # Ensure we use equality as the sole measure of elements, not type, since + # that is consistent with dict equality. + self.assertDictEqual({1: 1.0, 2: 2}, {1: 1, 2: 3}) + except AssertionError as e: + self.assertMultiLineEqual('{1: 1.0, 2: 2} != {1: 1, 2: 3}\n' + 'repr() of differing entries:\n2: 2 != 3\n', + str(e)) + + try: + self.assertDictEqual({}, {'x': 1}) + except AssertionError as e: + self.assertMultiLineEqual("{} != {'x': 1}\n" + "Unexpected, but present entries:\n'x': 1\n", + str(e)) + else: + self.fail('Expecting AssertionError') + + try: + self.assertDictEqual({}, {'x': 1}, 'a message') + except AssertionError as e: + self.assertIn('a message', str(e)) + else: + self.fail('Expecting AssertionError') + + expected = {'a': 1, 'b': 2, 'c': 3} + seen = {'a': 2, 'c': 3, 'd': 4} + try: + self.assertDictEqual(expected, seen) + except AssertionError as e: + self.assertMultiLineEqual("""\ +{'a': 1, 'b': 2, 'c': 3} != {'a': 2, 'c': 3, 'd': 4} +Unexpected, but present entries: +'d': 4 + +repr() of differing entries: +'a': 1 != 2 + +Missing entries: +'b': 2 +""", str(e)) + else: + self.fail('Expecting AssertionError') + + self.assertRaises(AssertionError, self.assertDictEqual, (1, 2), {}) + self.assertRaises(AssertionError, self.assertDictEqual, {}, (1, 2)) + + # Ensure deterministic output of keys in dictionaries whose sort order + # doesn't match the lexical ordering of repr -- this is most Python objects, + # which are keyed by memory address. + class Obj(object): + + def __init__(self, name): + self.name = name + + def __repr__(self): + return self.name + + try: + self.assertDictEqual( + {'a': Obj('A'), Obj('b'): Obj('B'), Obj('c'): Obj('C')}, + {'a': Obj('A'), Obj('d'): Obj('D'), Obj('e'): Obj('E')}) + except AssertionError as e: + # Do as best we can not to be misleading when objects have the same repr + # but aren't equal. + err_str = str(e) + self.assertStartsWith(err_str, + "{'a': A, b: B, c: C} != {'a': A, d: D, e: E}\n") + self.assertRegex( + err_str, r'(?ms).*^Unexpected, but present entries:\s+' + r'^(d: D$\s+^e: E|e: E$\s+^d: D)$') + self.assertRegex( + err_str, r'(?ms).*^repr\(\) of differing entries:\s+' + r'^.a.: A != A$', err_str) + self.assertRegex( + err_str, r'(?ms).*^Missing entries:\s+' + r'^(b: B$\s+^c: C|c: C$\s+^b: B)$') + else: + self.fail('Expecting AssertionError') + + # Confirm that safe_repr, not repr, is being used. + class RaisesOnRepr(object): + + def __repr__(self): + return 1/0 # Intentionally broken __repr__ implementation. + + try: + self.assertDictEqual( + {RaisesOnRepr(): RaisesOnRepr()}, + {RaisesOnRepr(): RaisesOnRepr()} + ) + self.fail('Expected dicts not to match') + except AssertionError as e: + # Depending on the testing environment, the object may get a __main__ + # prefix or a absltest_test prefix, so strip that for comparison. + error_msg = re.sub( + r'( at 0x[^>]+)|__main__\.|absltest_test\.', '', str(e)) + self.assertRegex(error_msg, """(?m)\ +{<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*>} != \ +{<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*>} +Unexpected, but present entries: +<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*> + +Missing entries: +<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*> +""") + + # Confirm that safe_repr, not repr, is being used. + class RaisesOnLt(object): + + def __lt__(self, unused_other): + raise TypeError('Object is unordered.') + + def __repr__(self): + return '<RaisesOnLt object>' + + try: + self.assertDictEqual( + {RaisesOnLt(): RaisesOnLt()}, + {RaisesOnLt(): RaisesOnLt()}) + except AssertionError as e: + self.assertIn('Unexpected, but present entries:\n<RaisesOnLt', str(e)) + self.assertIn('Missing entries:\n<RaisesOnLt', str(e)) + + def test_assert_set_equal(self): + set1 = set() + set2 = set() + self.assertSetEqual(set1, set2) + + self.assertRaises(AssertionError, self.assertSetEqual, None, set2) + self.assertRaises(AssertionError, self.assertSetEqual, [], set2) + self.assertRaises(AssertionError, self.assertSetEqual, set1, None) + self.assertRaises(AssertionError, self.assertSetEqual, set1, []) + + set1 = set(['a']) + set2 = set() + self.assertRaises(AssertionError, self.assertSetEqual, set1, set2) + + set1 = set(['a']) + set2 = set(['a']) + self.assertSetEqual(set1, set2) + + set1 = set(['a']) + set2 = set(['a', 'b']) + self.assertRaises(AssertionError, self.assertSetEqual, set1, set2) + + set1 = set(['a']) + set2 = frozenset(['a', 'b']) + self.assertRaises(AssertionError, self.assertSetEqual, set1, set2) + + set1 = set(['a', 'b']) + set2 = frozenset(['a', 'b']) + self.assertSetEqual(set1, set2) + + set1 = set() + set2 = 'foo' + self.assertRaises(AssertionError, self.assertSetEqual, set1, set2) + self.assertRaises(AssertionError, self.assertSetEqual, set2, set1) + + # make sure any string formatting is tuple-safe + set1 = set([(0, 1), (2, 3)]) + set2 = set([(4, 5)]) + self.assertRaises(AssertionError, self.assertSetEqual, set1, set2) + + def test_assert_dict_contains_subset(self): + self.assertDictContainsSubset({}, {}) + + self.assertDictContainsSubset({}, {'a': 1}) + + self.assertDictContainsSubset({'a': 1}, {'a': 1}) + + self.assertDictContainsSubset({'a': 1}, {'a': 1, 'b': 2}) + + self.assertDictContainsSubset({'a': 1, 'b': 2}, {'a': 1, 'b': 2}) + + self.assertRaises(absltest.TestCase.failureException, + self.assertDictContainsSubset, {'a': 2}, {'a': 1}, + '.*Mismatched values:.*') + + self.assertRaises(absltest.TestCase.failureException, + self.assertDictContainsSubset, {'c': 1}, {'a': 1}, + '.*Missing:.*') + + self.assertRaises(absltest.TestCase.failureException, + self.assertDictContainsSubset, {'a': 1, 'c': 1}, {'a': 1}, + '.*Missing:.*') + + self.assertRaises(absltest.TestCase.failureException, + self.assertDictContainsSubset, {'a': 1, 'c': 1}, {'a': 1}, + '.*Missing:.*Mismatched values:.*') + + def test_assert_sequence_almost_equal(self): + actual = (1.1, 1.2, 1.4) + + # Test across sequence types. + self.assertSequenceAlmostEqual((1.1, 1.2, 1.4), actual) + self.assertSequenceAlmostEqual([1.1, 1.2, 1.4], actual) + + # Test sequence size mismatch. + with self.assertRaises(AssertionError): + self.assertSequenceAlmostEqual([1.1, 1.2], actual) + with self.assertRaises(AssertionError): + self.assertSequenceAlmostEqual([1.1, 1.2, 1.4, 1.5], actual) + + # Test delta. + with self.assertRaises(AssertionError): + self.assertSequenceAlmostEqual((1.15, 1.15, 1.4), actual) + self.assertSequenceAlmostEqual((1.15, 1.15, 1.4), actual, delta=0.1) + + # Test places. + with self.assertRaises(AssertionError): + self.assertSequenceAlmostEqual((1.1001, 1.2001, 1.3999), actual) + self.assertSequenceAlmostEqual((1.1001, 1.2001, 1.3999), actual, places=3) + + def test_assert_contains_subset(self): + # sets, lists, tuples, dicts all ok. Types of set and subset do not have to + # match. + actual = ('a', 'b', 'c') + self.assertContainsSubset({'a', 'b'}, actual) + self.assertContainsSubset(('b', 'c'), actual) + self.assertContainsSubset({'b': 1, 'c': 2}, list(actual)) + self.assertContainsSubset(['c', 'a'], set(actual)) + self.assertContainsSubset([], set()) + self.assertContainsSubset([], {'a': 1}) + + self.assertRaises(AssertionError, self.assertContainsSubset, ('d',), actual) + self.assertRaises(AssertionError, self.assertContainsSubset, ['d'], + set(actual)) + self.assertRaises(AssertionError, self.assertContainsSubset, {'a': 1}, []) + + with self.assertRaisesRegex(AssertionError, 'Missing elements'): + self.assertContainsSubset({1, 2, 3}, {1, 2}) + + with self.assertRaisesRegex( + AssertionError, + re.compile('Missing elements .* Custom message', re.DOTALL)): + self.assertContainsSubset({1, 2}, {1}, 'Custom message') + + def test_assert_no_common_elements(self): + actual = ('a', 'b', 'c') + self.assertNoCommonElements((), actual) + self.assertNoCommonElements(('d', 'e'), actual) + self.assertNoCommonElements({'d', 'e'}, actual) + + with self.assertRaisesRegex( + AssertionError, + re.compile('Common elements .* Custom message', re.DOTALL)): + self.assertNoCommonElements({1, 2}, {1}, 'Custom message') + + with self.assertRaises(AssertionError): + self.assertNoCommonElements(['a'], actual) + + with self.assertRaises(AssertionError): + self.assertNoCommonElements({'a', 'b', 'c'}, actual) + + with self.assertRaises(AssertionError): + self.assertNoCommonElements({'b', 'c'}, set(actual)) + + def test_assert_almost_equal(self): + self.assertAlmostEqual(1.00000001, 1.0) + self.assertNotAlmostEqual(1.0000001, 1.0) + + def test_assert_almost_equals_with_delta(self): + self.assertAlmostEqual(3.14, 3, delta=0.2) + self.assertAlmostEqual(2.81, 3.14, delta=1) + self.assertAlmostEqual(-1, 1, delta=3) + self.assertRaises(AssertionError, self.assertAlmostEqual, + 3.14, 2.81, delta=0.1) + self.assertRaises(AssertionError, self.assertAlmostEqual, + 1, 2, delta=0.5) + self.assertNotAlmostEqual(3.14, 2.81, delta=0.1) + + def test_assert_starts_with(self): + self.assertStartsWith('foobar', 'foo') + self.assertStartsWith('foobar', 'foobar') + msg = 'This is a useful message' + whole_msg = "'foobar' does not start with 'bar' : This is a useful message" + self.assertRaisesWithLiteralMatch(AssertionError, whole_msg, + self.assertStartsWith, + 'foobar', 'bar', msg) + self.assertRaises(AssertionError, self.assertStartsWith, 'foobar', 'blah') + + def test_assert_not_starts_with(self): + self.assertNotStartsWith('foobar', 'bar') + self.assertNotStartsWith('foobar', 'blah') + msg = 'This is a useful message' + whole_msg = "'foobar' does start with 'foo' : This is a useful message" + self.assertRaisesWithLiteralMatch(AssertionError, whole_msg, + self.assertNotStartsWith, + 'foobar', 'foo', msg) + self.assertRaises(AssertionError, self.assertNotStartsWith, 'foobar', + 'foobar') + + def test_assert_ends_with(self): + self.assertEndsWith('foobar', 'bar') + self.assertEndsWith('foobar', 'foobar') + msg = 'This is a useful message' + whole_msg = "'foobar' does not end with 'foo' : This is a useful message" + self.assertRaisesWithLiteralMatch(AssertionError, whole_msg, + self.assertEndsWith, + 'foobar', 'foo', msg) + self.assertRaises(AssertionError, self.assertEndsWith, 'foobar', 'blah') + + def test_assert_not_ends_with(self): + self.assertNotEndsWith('foobar', 'foo') + self.assertNotEndsWith('foobar', 'blah') + msg = 'This is a useful message' + whole_msg = "'foobar' does end with 'bar' : This is a useful message" + self.assertRaisesWithLiteralMatch(AssertionError, whole_msg, + self.assertNotEndsWith, + 'foobar', 'bar', msg) + self.assertRaises(AssertionError, self.assertNotEndsWith, 'foobar', + 'foobar') + + def test_assert_regex_backports(self): + self.assertRegex('regex', 'regex') + self.assertNotRegex('not-regex', 'no-match') + with self.assertRaisesRegex(ValueError, 'pattern'): + raise ValueError('pattern') + + def test_assert_regex_match_matches(self): + self.assertRegexMatch('str', ['str']) + + def test_assert_regex_match_matches_substring(self): + self.assertRegexMatch('pre-str-post', ['str']) + + def test_assert_regex_match_multiple_regex_matches(self): + self.assertRegexMatch('str', ['rts', 'str']) + + def test_assert_regex_match_empty_list_fails(self): + expected_re = re.compile(r'No regexes specified\.', re.MULTILINE) + + with self.assertRaisesRegex(AssertionError, expected_re): + self.assertRegexMatch('str', regexes=[]) + + def test_assert_regex_match_bad_arguments(self): + with self.assertRaisesRegex(AssertionError, + 'regexes is string or bytes;.*'): + self.assertRegexMatch('1.*2', '1 2') + + def test_assert_regex_match_unicode_vs_bytes(self): + """Ensure proper utf-8 encoding or decoding happens automatically.""" + self.assertRegexMatch(u'str', [b'str']) + self.assertRegexMatch(b'str', [u'str']) + + def test_assert_regex_match_unicode(self): + self.assertRegexMatch(u'foo str', [u'str']) + + def test_assert_regex_match_bytes(self): + self.assertRegexMatch(b'foo str', [b'str']) + + def test_assert_regex_match_all_the_same_type(self): + with self.assertRaisesRegex(AssertionError, 'regexes .* same type'): + self.assertRegexMatch('foo str', [b'str', u'foo']) + + def test_assert_command_fails_stderr(self): + tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + self.assertCommandFails( + ['cat', os.path.join(tmpdir, 'file.txt')], + ['No such file or directory'], + env=_env_for_command_tests()) + + def test_assert_command_fails_with_list_of_string(self): + self.assertCommandFails( + ['false'], [''], env=_env_for_command_tests()) + + def test_assert_command_fails_with_list_of_unicode_string(self): + self.assertCommandFails( + [u'false'], [''], env=_env_for_command_tests()) + + def test_assert_command_fails_with_unicode_string(self): + self.assertCommandFails( + u'false', [u''], env=_env_for_command_tests()) + + def test_assert_command_fails_with_unicode_string_bytes_regex(self): + self.assertCommandFails( + u'false', [b''], env=_env_for_command_tests()) + + def test_assert_command_fails_with_message(self): + msg = 'This is a useful message' + expected_re = re.compile('The following command succeeded while expected to' + ' fail:.* This is a useful message', re.DOTALL) + + with self.assertRaisesRegex(AssertionError, expected_re): + self.assertCommandFails( + [u'true'], [''], msg=msg, env=_env_for_command_tests()) + + def test_assert_command_succeeds_stderr(self): + expected_re = re.compile('No such file or directory') + tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + + with self.assertRaisesRegex(AssertionError, expected_re): + self.assertCommandSucceeds( + ['cat', os.path.join(tmpdir, 'file.txt')], + env=_env_for_command_tests()) + + def test_assert_command_succeeds_with_matching_unicode_regexes(self): + self.assertCommandSucceeds( + ['echo', 'SUCCESS'], regexes=[u'SUCCESS'], + env=_env_for_command_tests()) + + def test_assert_command_succeeds_with_matching_bytes_regexes(self): + self.assertCommandSucceeds( + ['echo', 'SUCCESS'], regexes=[b'SUCCESS'], + env=_env_for_command_tests()) + + def test_assert_command_succeeds_with_non_matching_regexes(self): + expected_re = re.compile('Running command.* This is a useful message', + re.DOTALL) + msg = 'This is a useful message' + + with self.assertRaisesRegex(AssertionError, expected_re): + self.assertCommandSucceeds( + ['echo', 'FAIL'], regexes=['SUCCESS'], msg=msg, + env=_env_for_command_tests()) + + def test_assert_command_succeeds_with_list_of_string(self): + self.assertCommandSucceeds( + ['true'], env=_env_for_command_tests()) + + def test_assert_command_succeeds_with_list_of_unicode_string(self): + self.assertCommandSucceeds( + [u'true'], env=_env_for_command_tests()) + + def test_assert_command_succeeds_with_unicode_string(self): + self.assertCommandSucceeds( + u'true', env=_env_for_command_tests()) + + def test_inequality(self): + # Try ints + self.assertGreater(2, 1) + self.assertGreaterEqual(2, 1) + self.assertGreaterEqual(1, 1) + self.assertLess(1, 2) + self.assertLessEqual(1, 2) + self.assertLessEqual(1, 1) + self.assertRaises(AssertionError, self.assertGreater, 1, 2) + self.assertRaises(AssertionError, self.assertGreater, 1, 1) + self.assertRaises(AssertionError, self.assertGreaterEqual, 1, 2) + self.assertRaises(AssertionError, self.assertLess, 2, 1) + self.assertRaises(AssertionError, self.assertLess, 1, 1) + self.assertRaises(AssertionError, self.assertLessEqual, 2, 1) + + # Try Floats + self.assertGreater(1.1, 1.0) + self.assertGreaterEqual(1.1, 1.0) + self.assertGreaterEqual(1.0, 1.0) + self.assertLess(1.0, 1.1) + self.assertLessEqual(1.0, 1.1) + self.assertLessEqual(1.0, 1.0) + self.assertRaises(AssertionError, self.assertGreater, 1.0, 1.1) + self.assertRaises(AssertionError, self.assertGreater, 1.0, 1.0) + self.assertRaises(AssertionError, self.assertGreaterEqual, 1.0, 1.1) + self.assertRaises(AssertionError, self.assertLess, 1.1, 1.0) + self.assertRaises(AssertionError, self.assertLess, 1.0, 1.0) + self.assertRaises(AssertionError, self.assertLessEqual, 1.1, 1.0) + + # Try Strings + self.assertGreater('bug', 'ant') + self.assertGreaterEqual('bug', 'ant') + self.assertGreaterEqual('ant', 'ant') + self.assertLess('ant', 'bug') + self.assertLessEqual('ant', 'bug') + self.assertLessEqual('ant', 'ant') + self.assertRaises(AssertionError, self.assertGreater, 'ant', 'bug') + self.assertRaises(AssertionError, self.assertGreater, 'ant', 'ant') + self.assertRaises(AssertionError, self.assertGreaterEqual, 'ant', 'bug') + self.assertRaises(AssertionError, self.assertLess, 'bug', 'ant') + self.assertRaises(AssertionError, self.assertLess, 'ant', 'ant') + self.assertRaises(AssertionError, self.assertLessEqual, 'bug', 'ant') + + # Try Unicode + self.assertGreater(u'bug', u'ant') + self.assertGreaterEqual(u'bug', u'ant') + self.assertGreaterEqual(u'ant', u'ant') + self.assertLess(u'ant', u'bug') + self.assertLessEqual(u'ant', u'bug') + self.assertLessEqual(u'ant', u'ant') + self.assertRaises(AssertionError, self.assertGreater, u'ant', u'bug') + self.assertRaises(AssertionError, self.assertGreater, u'ant', u'ant') + self.assertRaises(AssertionError, self.assertGreaterEqual, u'ant', u'bug') + self.assertRaises(AssertionError, self.assertLess, u'bug', u'ant') + self.assertRaises(AssertionError, self.assertLess, u'ant', u'ant') + self.assertRaises(AssertionError, self.assertLessEqual, u'bug', u'ant') + + # Try Mixed String/Unicode + self.assertGreater('bug', u'ant') + self.assertGreater(u'bug', 'ant') + self.assertGreaterEqual('bug', u'ant') + self.assertGreaterEqual(u'bug', 'ant') + self.assertGreaterEqual('ant', u'ant') + self.assertGreaterEqual(u'ant', 'ant') + self.assertLess('ant', u'bug') + self.assertLess(u'ant', 'bug') + self.assertLessEqual('ant', u'bug') + self.assertLessEqual(u'ant', 'bug') + self.assertLessEqual('ant', u'ant') + self.assertLessEqual(u'ant', 'ant') + self.assertRaises(AssertionError, self.assertGreater, 'ant', u'bug') + self.assertRaises(AssertionError, self.assertGreater, u'ant', 'bug') + self.assertRaises(AssertionError, self.assertGreater, 'ant', u'ant') + self.assertRaises(AssertionError, self.assertGreater, u'ant', 'ant') + self.assertRaises(AssertionError, self.assertGreaterEqual, 'ant', u'bug') + self.assertRaises(AssertionError, self.assertGreaterEqual, u'ant', 'bug') + self.assertRaises(AssertionError, self.assertLess, 'bug', u'ant') + self.assertRaises(AssertionError, self.assertLess, u'bug', 'ant') + self.assertRaises(AssertionError, self.assertLess, 'ant', u'ant') + self.assertRaises(AssertionError, self.assertLess, u'ant', 'ant') + self.assertRaises(AssertionError, self.assertLessEqual, 'bug', u'ant') + self.assertRaises(AssertionError, self.assertLessEqual, u'bug', 'ant') + + def test_assert_multi_line_equal(self): + sample_text = """\ +http://www.python.org/doc/2.3/lib/module-unittest.html +test case + A test case is the smallest unit of testing. [...] +""" + revised_sample_text = """\ +http://www.python.org/doc/2.4.1/lib/module-unittest.html +test case + A test case is the smallest unit of testing. [...] You may provide your + own implementation that does not subclass from TestCase, of course. +""" + sample_text_error = """ +- http://www.python.org/doc/2.3/lib/module-unittest.html +? ^ ++ http://www.python.org/doc/2.4.1/lib/module-unittest.html +? ^^^ + test case +- A test case is the smallest unit of testing. [...] ++ A test case is the smallest unit of testing. [...] You may provide your +? +++++++++++++++++++++ ++ own implementation that does not subclass from TestCase, of course. +""" + self.assertRaisesWithLiteralMatch(AssertionError, sample_text_error, + self.assertMultiLineEqual, + sample_text, + revised_sample_text) + + self.assertRaises(AssertionError, self.assertMultiLineEqual, (1, 2), 'str') + self.assertRaises(AssertionError, self.assertMultiLineEqual, 'str', (1, 2)) + + def test_assert_multi_line_equal_adds_newlines_if_needed(self): + self.assertRaisesWithLiteralMatch( + AssertionError, + '\n' + ' line1\n' + '- line2\n' + '? ^\n' + '+ line3\n' + '? ^\n', + self.assertMultiLineEqual, + 'line1\n' + 'line2', + 'line1\n' + 'line3') + + def test_assert_multi_line_equal_shows_missing_newlines(self): + self.assertRaisesWithLiteralMatch( + AssertionError, + '\n' + ' line1\n' + '- line2\n' + '? -\n' + '+ line2\n', + self.assertMultiLineEqual, + 'line1\n' + 'line2\n', + 'line1\n' + 'line2') + + def test_assert_multi_line_equal_shows_extra_newlines(self): + self.assertRaisesWithLiteralMatch( + AssertionError, + '\n' + ' line1\n' + '- line2\n' + '+ line2\n' + '? +\n', + self.assertMultiLineEqual, + 'line1\n' + 'line2', + 'line1\n' + 'line2\n') + + def test_assert_multi_line_equal_line_limit_limits(self): + self.assertRaisesWithLiteralMatch( + AssertionError, + '\n' + ' line1\n' + '(... and 4 more delta lines omitted for brevity.)\n', + self.assertMultiLineEqual, + 'line1\n' + 'line2\n', + 'line1\n' + 'line3\n', + line_limit=1) + + def test_assert_multi_line_equal_line_limit_limits_with_message(self): + self.assertRaisesWithLiteralMatch( + AssertionError, + 'Prefix:\n' + ' line1\n' + '(... and 4 more delta lines omitted for brevity.)\n', + self.assertMultiLineEqual, + 'line1\n' + 'line2\n', + 'line1\n' + 'line3\n', + 'Prefix', + line_limit=1) + + def test_assert_is_none(self): + self.assertIsNone(None) + self.assertRaises(AssertionError, self.assertIsNone, False) + self.assertIsNotNone('Google') + self.assertRaises(AssertionError, self.assertIsNotNone, None) + self.assertRaises(AssertionError, self.assertIsNone, (1, 2)) + + def test_assert_is(self): + self.assertIs(object, object) + self.assertRaises(AssertionError, self.assertIsNot, object, object) + self.assertIsNot(True, False) + self.assertRaises(AssertionError, self.assertIs, True, False) + + def test_assert_between(self): + self.assertBetween(3.14, 3.1, 3.141) + self.assertBetween(4, 4, 1e10000) + self.assertBetween(9.5, 9.4, 9.5) + self.assertBetween(-1e10, -1e10000, 0) + self.assertRaises(AssertionError, self.assertBetween, 9.4, 9.3, 9.3999) + self.assertRaises(AssertionError, self.assertBetween, -1e10000, -1e10, 0) + + def test_assert_raises_with_predicate_match_no_raise(self): + with self.assertRaisesRegex(AssertionError, '^Exception not raised$'): + self.assertRaisesWithPredicateMatch(Exception, + lambda e: True, + lambda: 1) # don't raise + + with self.assertRaisesRegex(AssertionError, '^Exception not raised$'): + with self.assertRaisesWithPredicateMatch(Exception, lambda e: True): + pass # don't raise + + def test_assert_raises_with_predicate_match_raises_wrong_exception(self): + def _raise_value_error(): + raise ValueError + + with self.assertRaises(ValueError): + self.assertRaisesWithPredicateMatch(IOError, + lambda e: True, + _raise_value_error) + + with self.assertRaises(ValueError): + with self.assertRaisesWithPredicateMatch(IOError, lambda e: True): + raise ValueError + + def test_assert_raises_with_predicate_match_predicate_fails(self): + def _raise_value_error(): + raise ValueError + with self.assertRaisesRegex(AssertionError, ' does not match predicate '): + self.assertRaisesWithPredicateMatch(ValueError, + lambda e: False, + _raise_value_error) + + with self.assertRaisesRegex(AssertionError, ' does not match predicate '): + with self.assertRaisesWithPredicateMatch(ValueError, lambda e: False): + raise ValueError + + def test_assert_raises_with_predicate_match_predicate_passes(self): + def _raise_value_error(): + raise ValueError + + self.assertRaisesWithPredicateMatch(ValueError, + lambda e: True, + _raise_value_error) + + with self.assertRaisesWithPredicateMatch(ValueError, lambda e: True): + raise ValueError + + def test_assert_contains_in_order(self): + # Valids + self.assertContainsInOrder( + ['fox', 'dog'], 'The quick brown fox jumped over the lazy dog.') + self.assertContainsInOrder( + ['quick', 'fox', 'dog'], + 'The quick brown fox jumped over the lazy dog.') + self.assertContainsInOrder( + ['The', 'fox', 'dog.'], 'The quick brown fox jumped over the lazy dog.') + self.assertContainsInOrder( + ['fox'], 'The quick brown fox jumped over the lazy dog.') + self.assertContainsInOrder( + 'fox', 'The quick brown fox jumped over the lazy dog.') + self.assertContainsInOrder( + ['fox', 'dog'], 'fox dog fox') + self.assertContainsInOrder( + [], 'The quick brown fox jumped over the lazy dog.') + self.assertContainsInOrder( + [], '') + + # Invalids + msg = 'This is a useful message' + whole_msg = ("Did not find 'fox' after 'dog' in 'The quick brown fox" + " jumped over the lazy dog' : This is a useful message") + self.assertRaisesWithLiteralMatch( + AssertionError, whole_msg, self.assertContainsInOrder, + ['dog', 'fox'], 'The quick brown fox jumped over the lazy dog', msg=msg) + self.assertRaises( + AssertionError, self.assertContainsInOrder, + ['The', 'dog', 'fox'], 'The quick brown fox jumped over the lazy dog') + self.assertRaises( + AssertionError, self.assertContainsInOrder, ['dog'], '') + + def test_assert_contains_subsequence_for_numbers(self): + self.assertContainsSubsequence([1, 2, 3], [1]) + self.assertContainsSubsequence([1, 2, 3], [1, 2]) + self.assertContainsSubsequence([1, 2, 3], [1, 3]) + + with self.assertRaises(AssertionError): + self.assertContainsSubsequence([1, 2, 3], [4]) + msg = 'This is a useful message' + whole_msg = ('[3, 1] not a subsequence of [1, 2, 3]. ' + 'First non-matching element: 1 : This is a useful message') + self.assertRaisesWithLiteralMatch(AssertionError, whole_msg, + self.assertContainsSubsequence, + [1, 2, 3], [3, 1], msg=msg) + + def test_assert_contains_subsequence_for_strings(self): + self.assertContainsSubsequence(['foo', 'bar', 'blorp'], ['foo', 'blorp']) + with self.assertRaises(AssertionError): + self.assertContainsSubsequence( + ['foo', 'bar', 'blorp'], ['blorp', 'foo']) + + def test_assert_contains_subsequence_with_empty_subsequence(self): + self.assertContainsSubsequence([1, 2, 3], []) + self.assertContainsSubsequence(['foo', 'bar', 'blorp'], []) + self.assertContainsSubsequence([], []) + + def test_assert_contains_subsequence_with_empty_container(self): + with self.assertRaises(AssertionError): + self.assertContainsSubsequence([], [1]) + with self.assertRaises(AssertionError): + self.assertContainsSubsequence([], ['foo']) + + def test_assert_contains_exact_subsequence_for_numbers(self): + self.assertContainsExactSubsequence([1, 2, 3], [1]) + self.assertContainsExactSubsequence([1, 2, 3], [1, 2]) + self.assertContainsExactSubsequence([1, 2, 3], [2, 3]) + + with self.assertRaises(AssertionError): + self.assertContainsExactSubsequence([1, 2, 3], [4]) + msg = 'This is a useful message' + whole_msg = ('[1, 2, 4] not an exact subsequence of [1, 2, 3, 4]. ' + 'Longest matching prefix: [1, 2] : This is a useful message') + self.assertRaisesWithLiteralMatch(AssertionError, whole_msg, + self.assertContainsExactSubsequence, + [1, 2, 3, 4], [1, 2, 4], msg=msg) + + def test_assert_contains_exact_subsequence_for_strings(self): + self.assertContainsExactSubsequence( + ['foo', 'bar', 'blorp'], ['foo', 'bar']) + with self.assertRaises(AssertionError): + self.assertContainsExactSubsequence( + ['foo', 'bar', 'blorp'], ['blorp', 'foo']) + + def test_assert_contains_exact_subsequence_with_empty_subsequence(self): + self.assertContainsExactSubsequence([1, 2, 3], []) + self.assertContainsExactSubsequence(['foo', 'bar', 'blorp'], []) + self.assertContainsExactSubsequence([], []) + + def test_assert_contains_exact_subsequence_with_empty_container(self): + with self.assertRaises(AssertionError): + self.assertContainsExactSubsequence([], [3]) + with self.assertRaises(AssertionError): + self.assertContainsExactSubsequence([], ['foo', 'bar']) + self.assertContainsExactSubsequence([], []) + + def test_assert_totally_ordered(self): + # Valid. + self.assertTotallyOrdered() + self.assertTotallyOrdered([1]) + self.assertTotallyOrdered([1], [2]) + self.assertTotallyOrdered([1, 1, 1]) + self.assertTotallyOrdered([(1, 1)], [(1, 2)], [(2, 1)]) + + # From the docstring. + class A(object): + + def __init__(self, x, y): + self.x = x + self.y = y + + def __hash__(self): + return hash(self.x) + + def __repr__(self): + return 'A(%r, %r)' % (self.x, self.y) + + def __eq__(self, other): + try: + return self.x == other.x + except AttributeError: + return NotImplemented + + def __ne__(self, other): + try: + return self.x != other.x + except AttributeError: + return NotImplemented + + def __lt__(self, other): + try: + return self.x < other.x + except AttributeError: + return NotImplemented + + def __le__(self, other): + try: + return self.x <= other.x + except AttributeError: + return NotImplemented + + def __gt__(self, other): + try: + return self.x > other.x + except AttributeError: + return NotImplemented + + def __ge__(self, other): + try: + return self.x >= other.x + except AttributeError: + return NotImplemented + + class B(A): + """Like A, but not hashable.""" + __hash__ = None + + self.assertTotallyOrdered( + [A(1, 'a')], + [A(2, 'b')], # 2 is after 1. + [ + A(3, 'c'), + B(3, 'd'), + B(3, 'e') # The second argument is irrelevant. + ], + [A(4, 'z')]) + + # Invalid. + msg = 'This is a useful message' + whole_msg = '2 not less than 1 : This is a useful message' + self.assertRaisesWithLiteralMatch(AssertionError, whole_msg, + self.assertTotallyOrdered, [2], [1], + msg=msg) + self.assertRaises(AssertionError, self.assertTotallyOrdered, [2], [1]) + self.assertRaises(AssertionError, self.assertTotallyOrdered, [2], [1], [3]) + self.assertRaises(AssertionError, self.assertTotallyOrdered, [1, 2]) + + def test_short_description_without_docstring(self): + self.assertEquals( + self.shortDescription(), + 'TestCaseTest.test_short_description_without_docstring') + + def test_short_description_with_one_line_docstring(self): + """Tests shortDescription() for a method with a docstring.""" + self.assertEquals( + self.shortDescription(), + 'TestCaseTest.test_short_description_with_one_line_docstring\n' + 'Tests shortDescription() for a method with a docstring.') + + def test_short_description_with_multi_line_docstring(self): + """Tests shortDescription() for a method with a longer docstring. + + This method ensures that only the first line of a docstring is + returned used in the short description, no matter how long the + whole thing is. + """ + self.assertEquals( + self.shortDescription(), + 'TestCaseTest.test_short_description_with_multi_line_docstring\n' + 'Tests shortDescription() for a method with a longer docstring.') + + def test_assert_url_equal_same(self): + self.assertUrlEqual('http://a', 'http://a') + self.assertUrlEqual('http://a/path/test', 'http://a/path/test') + self.assertUrlEqual('#fragment', '#fragment') + self.assertUrlEqual('http://a/?q=1', 'http://a/?q=1') + self.assertUrlEqual('http://a/?q=1&v=5', 'http://a/?v=5&q=1') + self.assertUrlEqual('/logs?v=1&a=2&t=labels&f=path%3A%22foo%22', + '/logs?a=2&f=path%3A%22foo%22&v=1&t=labels') + self.assertUrlEqual('http://a/path;p1', 'http://a/path;p1') + self.assertUrlEqual('http://a/path;p2;p3;p1', 'http://a/path;p1;p2;p3') + self.assertUrlEqual('sip:alice@atlanta.com;maddr=239.255.255.1;ttl=15', + 'sip:alice@atlanta.com;ttl=15;maddr=239.255.255.1') + self.assertUrlEqual('http://nyan/cat?p=1&b=', 'http://nyan/cat?b=&p=1') + + def test_assert_url_equal_different(self): + msg = 'This is a useful message' + whole_msg = 'This is a useful message:\n- a\n+ b\n' + self.assertRaisesWithLiteralMatch(AssertionError, whole_msg, + self.assertUrlEqual, + 'http://a', 'http://b', msg=msg) + self.assertRaises(AssertionError, self.assertUrlEqual, + 'http://a/x', 'http://a:8080/x') + self.assertRaises(AssertionError, self.assertUrlEqual, + 'http://a/x', 'http://a/y') + self.assertRaises(AssertionError, self.assertUrlEqual, + 'http://a/?q=2', 'http://a/?q=1') + self.assertRaises(AssertionError, self.assertUrlEqual, + 'http://a/?q=1&v=5', 'http://a/?v=2&q=1') + self.assertRaises(AssertionError, self.assertUrlEqual, + 'http://a', 'sip://b') + self.assertRaises(AssertionError, self.assertUrlEqual, + 'http://a#g', 'sip://a#f') + self.assertRaises(AssertionError, self.assertUrlEqual, + 'http://a/path;p1;p3;p1', 'http://a/path;p1;p2;p3') + self.assertRaises(AssertionError, self.assertUrlEqual, + 'http://nyan/cat?p=1&b=', 'http://nyan/cat?p=1') + + def test_same_structure_same(self): + self.assertSameStructure(0, 0) + self.assertSameStructure(1, 1) + self.assertSameStructure('', '') + self.assertSameStructure('hello', 'hello', msg='This Should not fail') + self.assertSameStructure(set(), set()) + self.assertSameStructure(set([1, 2]), set([1, 2])) + self.assertSameStructure(set(), frozenset()) + self.assertSameStructure(set([1, 2]), frozenset([1, 2])) + self.assertSameStructure([], []) + self.assertSameStructure(['a'], ['a']) + self.assertSameStructure([], ()) + self.assertSameStructure(['a'], ('a',)) + self.assertSameStructure({}, {}) + self.assertSameStructure({'one': 1}, {'one': 1}) + self.assertSameStructure(collections.defaultdict(None, {'one': 1}), + {'one': 1}) + self.assertSameStructure(collections.OrderedDict({'one': 1}), + collections.defaultdict(None, {'one': 1})) + + def test_same_structure_different(self): + # Different type + with self.assertRaisesRegex( + AssertionError, + r"a is a <(type|class) 'int'> but b is a <(type|class) 'str'>"): + self.assertSameStructure(0, 'hello') + with self.assertRaisesRegex( + AssertionError, + r"a is a <(type|class) 'int'> but b is a <(type|class) 'list'>"): + self.assertSameStructure(0, []) + with self.assertRaisesRegex( + AssertionError, + r"a is a <(type|class) 'int'> but b is a <(type|class) 'float'>"): + self.assertSameStructure(2, 2.0) + + with self.assertRaisesRegex( + AssertionError, + r"a is a <(type|class) 'list'> but b is a <(type|class) 'dict'>"): + self.assertSameStructure([], {}) + + with self.assertRaisesRegex( + AssertionError, + r"a is a <(type|class) 'list'> but b is a <(type|class) 'set'>"): + self.assertSameStructure([], set()) + + with self.assertRaisesRegex( + AssertionError, + r"a is a <(type|class) 'dict'> but b is a <(type|class) 'set'>"): + self.assertSameStructure({}, set()) + + # Different scalar values + self.assertRaisesWithLiteralMatch( + AssertionError, 'a is 0 but b is 1', + self.assertSameStructure, 0, 1) + self.assertRaisesWithLiteralMatch( + AssertionError, "a is 'hello' but b is 'goodbye' : This was expected", + self.assertSameStructure, 'hello', 'goodbye', msg='This was expected') + + # Different sets + self.assertRaisesWithLiteralMatch( + AssertionError, + r'AA has 2 but BB does not', + self.assertSameStructure, + set([1, 2]), + set([1]), + aname='AA', + bname='BB') + self.assertRaisesWithLiteralMatch( + AssertionError, + r'AA lacks 2 but BB has it', + self.assertSameStructure, + set([1]), + set([1, 2]), + aname='AA', + bname='BB') + + # Different lists + self.assertRaisesWithLiteralMatch( + AssertionError, "a has [2] with value 'z' but b does not", + self.assertSameStructure, ['x', 'y', 'z'], ['x', 'y']) + self.assertRaisesWithLiteralMatch( + AssertionError, "a lacks [2] but b has it with value 'z'", + self.assertSameStructure, ['x', 'y'], ['x', 'y', 'z']) + self.assertRaisesWithLiteralMatch( + AssertionError, "a[2] is 'z' but b[2] is 'Z'", + self.assertSameStructure, ['x', 'y', 'z'], ['x', 'y', 'Z']) + + # Different dicts + self.assertRaisesWithLiteralMatch( + AssertionError, "a has ['two'] with value 2 but it's missing in b", + self.assertSameStructure, {'one': 1, 'two': 2}, {'one': 1}) + self.assertRaisesWithLiteralMatch( + AssertionError, "a lacks ['two'] but b has it with value 2", + self.assertSameStructure, {'one': 1}, {'one': 1, 'two': 2}) + self.assertRaisesWithLiteralMatch( + AssertionError, "a['two'] is 2 but b['two'] is 3", + self.assertSameStructure, {'one': 1, 'two': 2}, {'one': 1, 'two': 3}) + + # String and byte types should not be considered equivalent to other + # sequences + self.assertRaisesRegex( + AssertionError, + r"a is a <(type|class) 'list'> but b is a <(type|class) 'str'>", + self.assertSameStructure, [], '') + self.assertRaisesRegex( + AssertionError, + r"a is a <(type|class) 'str'> but b is a <(type|class) 'tuple'>", + self.assertSameStructure, '', ()) + self.assertRaisesRegex( + AssertionError, + r"a is a <(type|class) 'list'> but b is a <(type|class) 'str'>", + self.assertSameStructure, ['a', 'b', 'c'], 'abc') + self.assertRaisesRegex( + AssertionError, + r"a is a <(type|class) 'str'> but b is a <(type|class) 'tuple'>", + self.assertSameStructure, 'abc', ('a', 'b', 'c')) + + # Deep key generation + self.assertRaisesWithLiteralMatch( + AssertionError, + "a[0][0]['x']['y']['z'][0] is 1 but b[0][0]['x']['y']['z'][0] is 2", + self.assertSameStructure, + [[{'x': {'y': {'z': [1]}}}]], [[{'x': {'y': {'z': [2]}}}]]) + + # Multiple problems + self.assertRaisesWithLiteralMatch( + AssertionError, + 'a[0] is 1 but b[0] is 3; a[1] is 2 but b[1] is 4', + self.assertSameStructure, [1, 2], [3, 4]) + with self.assertRaisesRegex( + AssertionError, + re.compile(r"^a\[0] is 'a' but b\[0] is 'A'; .*" + r"a\[18] is 's' but b\[18] is 'S'; \.\.\.$")): + self.assertSameStructure( + list(string.ascii_lowercase), list(string.ascii_uppercase)) + + # Verify same behavior with self.maxDiff = None + self.maxDiff = None + self.assertRaisesWithLiteralMatch( + AssertionError, + 'a[0] is 1 but b[0] is 3; a[1] is 2 but b[1] is 4', + self.assertSameStructure, [1, 2], [3, 4]) + + def test_same_structure_mapping_unchanged(self): + default_a = collections.defaultdict(lambda: 'BAD MODIFICATION', {}) + dict_b = {'one': 'z'} + self.assertRaisesWithLiteralMatch( + AssertionError, + r"a lacks ['one'] but b has it with value 'z'", + self.assertSameStructure, default_a, dict_b) + self.assertEmpty(default_a) + + dict_a = {'one': 'z'} + default_b = collections.defaultdict(lambda: 'BAD MODIFICATION', {}) + self.assertRaisesWithLiteralMatch( + AssertionError, + r"a has ['one'] with value 'z' but it's missing in b", + self.assertSameStructure, dict_a, default_b) + self.assertEmpty(default_b) + + def test_assert_json_equal_same(self): + self.assertJsonEqual('{"success": true}', '{"success": true}') + self.assertJsonEqual('{"success": true}', '{"success":true}') + self.assertJsonEqual('true', 'true') + self.assertJsonEqual('null', 'null') + self.assertJsonEqual('false', 'false') + self.assertJsonEqual('34', '34') + self.assertJsonEqual('[1, 2, 3]', '[1,2,3]', msg='please PASS') + self.assertJsonEqual('{"sequence": [1, 2, 3], "float": 23.42}', + '{"float": 23.42, "sequence": [1,2,3]}') + self.assertJsonEqual('{"nest": {"spam": "eggs"}, "float": 23.42}', + '{"float": 23.42, "nest": {"spam":"eggs"}}') + + def test_assert_json_equal_different(self): + with self.assertRaises(AssertionError): + self.assertJsonEqual('{"success": true}', '{"success": false}') + with self.assertRaises(AssertionError): + self.assertJsonEqual('{"success": false}', '{"Success": false}') + with self.assertRaises(AssertionError): + self.assertJsonEqual('false', 'true') + with self.assertRaises(AssertionError) as error_context: + self.assertJsonEqual('null', '0', msg='I demand FAILURE') + self.assertIn('I demand FAILURE', error_context.exception.args[0]) + self.assertIn('None', error_context.exception.args[0]) + with self.assertRaises(AssertionError): + self.assertJsonEqual('[1, 0, 3]', '[1,2,3]') + with self.assertRaises(AssertionError): + self.assertJsonEqual('{"sequence": [1, 2, 3], "float": 23.42}', + '{"float": 23.42, "sequence": [1,0,3]}') + with self.assertRaises(AssertionError): + self.assertJsonEqual('{"nest": {"spam": "eggs"}, "float": 23.42}', + '{"float": 23.42, "nest": {"Spam":"beans"}}') + + def test_assert_json_equal_bad_json(self): + with self.assertRaises(ValueError) as error_context: + self.assertJsonEqual("alhg'2;#", '{"a": true}') + self.assertIn('first', error_context.exception.args[0]) + self.assertIn('alhg', error_context.exception.args[0]) + + with self.assertRaises(ValueError) as error_context: + self.assertJsonEqual('{"a": true}', "alhg'2;#") + self.assertIn('second', error_context.exception.args[0]) + self.assertIn('alhg', error_context.exception.args[0]) + + with self.assertRaises(ValueError) as error_context: + self.assertJsonEqual('', '') + + +class GetCommandStderrTestCase(absltest.TestCase): + + def setUp(self): + super(GetCommandStderrTestCase, self).setUp() + self.original_environ = os.environ.copy() + + def tearDown(self): + super(GetCommandStderrTestCase, self).tearDown() + os.environ = self.original_environ + + def test_return_status(self): + tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + returncode = ( + absltest.get_command_stderr( + ['cat', os.path.join(tmpdir, 'file.txt')], + env=_env_for_command_tests())[0]) + self.assertEqual(1, returncode) + + def test_stderr(self): + tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + stderr = ( + absltest.get_command_stderr( + ['cat', os.path.join(tmpdir, 'file.txt')], + env=_env_for_command_tests())[1]) + stderr = stderr.decode('utf-8') + self.assertRegex(stderr, 'No such file or directory') + + +@contextlib.contextmanager +def cm_for_test(obj): + try: + obj.cm_state = 'yielded' + yield 'value' + finally: + obj.cm_state = 'exited' + + +class EnterContextTest(absltest.TestCase): + + def setUp(self): + self.cm_state = 'unset' + self.cm_value = 'unset' + + def assert_cm_exited(): + self.assertEqual(self.cm_state, 'exited') + + # Because cleanup functions are run in reverse order, we have to add + # our assert-cleanup before the exit stack registers its own cleanup. + # This ensures we see state after the stack cleanup runs. + self.addCleanup(assert_cm_exited) + + super(EnterContextTest, self).setUp() + self.cm_value = self.enter_context(cm_for_test(self)) + + def test_enter_context(self): + self.assertEqual(self.cm_value, 'value') + self.assertEqual(self.cm_state, 'yielded') + + +@absltest.skipIf(not hasattr(absltest.TestCase, 'addClassCleanup'), + 'Python 3.8 required for class-level enter_context') +class EnterContextClassmethodTest(absltest.TestCase): + + cm_state = 'unset' + cm_value = 'unset' + + @classmethod + def setUpClass(cls): + + def assert_cm_exited(): + assert cls.cm_state == 'exited' + + # Because cleanup functions are run in reverse order, we have to add + # our assert-cleanup before the exit stack registers its own cleanup. + # This ensures we see state after the stack cleanup runs. + cls.addClassCleanup(assert_cm_exited) + + super(EnterContextClassmethodTest, cls).setUpClass() + cls.cm_value = cls.enter_context(cm_for_test(cls)) + + def test_enter_context(self): + self.assertEqual(self.cm_value, 'value') + self.assertEqual(self.cm_state, 'yielded') + + +class EqualityAssertionTest(absltest.TestCase): + """This test verifies that absltest.failIfEqual actually tests __ne__. + + If a user class implements __eq__, unittest.failUnlessEqual will call it + via first == second. However, failIfEqual also calls + first == second. This means that while the caller may believe + their __ne__ method is being tested, it is not. + """ + + class NeverEqual(object): + """Objects of this class behave like NaNs.""" + + def __eq__(self, unused_other): + return False + + def __ne__(self, unused_other): + return False + + class AllSame(object): + """All objects of this class compare as equal.""" + + def __eq__(self, unused_other): + return True + + def __ne__(self, unused_other): + return False + + class EqualityTestsWithEq(object): + """Performs all equality and inequality tests with __eq__.""" + + def __init__(self, value): + self._value = value + + def __eq__(self, other): + return self._value == other._value + + def __ne__(self, other): + return not self.__eq__(other) + + class EqualityTestsWithNe(object): + """Performs all equality and inequality tests with __ne__.""" + + def __init__(self, value): + self._value = value + + def __eq__(self, other): + return not self.__ne__(other) + + def __ne__(self, other): + return self._value != other._value + + class EqualityTestsWithCmp(object): + + def __init__(self, value): + self._value = value + + def __cmp__(self, other): + return cmp(self._value, other._value) + + class EqualityTestsWithLtEq(object): + + def __init__(self, value): + self._value = value + + def __eq__(self, other): + return self._value == other._value + + def __lt__(self, other): + return self._value < other._value + + def test_all_comparisons_fail(self): + i1 = self.NeverEqual() + i2 = self.NeverEqual() + self.assertFalse(i1 == i2) + self.assertFalse(i1 != i2) + + # Compare two distinct objects + self.assertFalse(i1 is i2) + self.assertRaises(AssertionError, self.assertEqual, i1, i2) + self.assertRaises(AssertionError, self.assertEquals, i1, i2) + self.assertRaises(AssertionError, self.failUnlessEqual, i1, i2) + self.assertRaises(AssertionError, self.assertNotEqual, i1, i2) + self.assertRaises(AssertionError, self.assertNotEquals, i1, i2) + self.assertRaises(AssertionError, self.failIfEqual, i1, i2) + # A NeverEqual object should not compare equal to itself either. + i2 = i1 + self.assertTrue(i1 is i2) + self.assertFalse(i1 == i2) + self.assertFalse(i1 != i2) + self.assertRaises(AssertionError, self.assertEqual, i1, i2) + self.assertRaises(AssertionError, self.assertEquals, i1, i2) + self.assertRaises(AssertionError, self.failUnlessEqual, i1, i2) + self.assertRaises(AssertionError, self.assertNotEqual, i1, i2) + self.assertRaises(AssertionError, self.assertNotEquals, i1, i2) + self.assertRaises(AssertionError, self.failIfEqual, i1, i2) + + def test_all_comparisons_succeed(self): + a = self.AllSame() + b = self.AllSame() + self.assertFalse(a is b) + self.assertTrue(a == b) + self.assertFalse(a != b) + self.assertEqual(a, b) + self.assertEquals(a, b) + self.failUnlessEqual(a, b) + self.assertRaises(AssertionError, self.assertNotEqual, a, b) + self.assertRaises(AssertionError, self.assertNotEquals, a, b) + self.assertRaises(AssertionError, self.failIfEqual, a, b) + + def _perform_apple_apple_orange_checks(self, same_a, same_b, different): + """Perform consistency checks with two apples and an orange. + + The two apples should always compare as being the same (and inequality + checks should fail). The orange should always compare as being different + to each of the apples. + + Args: + same_a: the first apple + same_b: the second apple + different: the orange + """ + self.assertTrue(same_a == same_b) + self.assertFalse(same_a != same_b) + self.assertEqual(same_a, same_b) + self.assertEquals(same_a, same_b) + self.failUnlessEqual(same_a, same_b) + + self.assertFalse(same_a == different) + self.assertTrue(same_a != different) + self.assertNotEqual(same_a, different) + self.assertNotEquals(same_a, different) + self.failIfEqual(same_a, different) + + self.assertFalse(same_b == different) + self.assertTrue(same_b != different) + self.assertNotEqual(same_b, different) + self.assertNotEquals(same_b, different) + self.failIfEqual(same_b, different) + + def test_comparison_with_eq(self): + same_a = self.EqualityTestsWithEq(42) + same_b = self.EqualityTestsWithEq(42) + different = self.EqualityTestsWithEq(1769) + self._perform_apple_apple_orange_checks(same_a, same_b, different) + + def test_comparison_with_ne(self): + same_a = self.EqualityTestsWithNe(42) + same_b = self.EqualityTestsWithNe(42) + different = self.EqualityTestsWithNe(1769) + self._perform_apple_apple_orange_checks(same_a, same_b, different) + + def test_comparison_with_cmp_or_lt_eq(self): + same_a = self.EqualityTestsWithLtEq(42) + same_b = self.EqualityTestsWithLtEq(42) + different = self.EqualityTestsWithLtEq(1769) + self._perform_apple_apple_orange_checks(same_a, same_b, different) + + +class AssertSequenceStartsWithTest(absltest.TestCase): + + def setUp(self): + self.a = [5, 'foo', {'c': 'd'}, None] + + def test_empty_sequence_starts_with_empty_prefix(self): + self.assertSequenceStartsWith([], ()) + + def test_sequence_prefix_is_an_empty_list(self): + self.assertSequenceStartsWith([[]], ([], 'foo')) + + def test_raise_if_empty_prefix_with_non_empty_whole(self): + with self.assertRaisesRegex( + AssertionError, 'Prefix length is 0 but whole length is %d: %s' % (len( + self.a), r"\[5, 'foo', \{'c': 'd'\}, None\]")): + self.assertSequenceStartsWith([], self.a) + + def test_single_element_prefix(self): + self.assertSequenceStartsWith([5], self.a) + + def test_two_element_prefix(self): + self.assertSequenceStartsWith((5, 'foo'), self.a) + + def test_prefix_is_full_sequence(self): + self.assertSequenceStartsWith([5, 'foo', {'c': 'd'}, None], self.a) + + def test_string_prefix(self): + self.assertSequenceStartsWith('abc', 'abc123') + + def test_convert_non_sequence_prefix_to_sequence_and_try_again(self): + self.assertSequenceStartsWith(5, self.a) + + def test_whole_not_asequence(self): + msg = (r'For whole: len\(5\) is not supported, it appears to be type: ' + '<(type|class) \'int\'>') + with self.assertRaisesRegex(AssertionError, msg): + self.assertSequenceStartsWith(self.a, 5) + + def test_raise_if_sequence_does_not_start_with_prefix(self): + msg = (r"prefix: \['foo', \{'c': 'd'\}\] not found at start of whole: " + r"\[5, 'foo', \{'c': 'd'\}, None\].") + with self.assertRaisesRegex(AssertionError, msg): + self.assertSequenceStartsWith(['foo', {'c': 'd'}], self.a) + + def test_raise_if_types_ar_not_supported(self): + with self.assertRaisesRegex(TypeError, 'unhashable type'): + self.assertSequenceStartsWith({'a': 1, 2: 'b'}, + {'a': 1, 2: 'b', 'c': '3'}) + + +class TestAssertEmpty(absltest.TestCase): + longMessage = True + + def test_raises_if_not_asized_object(self): + msg = "Expected a Sized object, got: 'int'" + with self.assertRaisesRegex(AssertionError, msg): + self.assertEmpty(1) + + def test_calls_len_not_bool(self): + + class BadList(list): + + def __bool__(self): + return False + + __nonzero__ = __bool__ + + bad_list = BadList() + self.assertEmpty(bad_list) + self.assertFalse(bad_list) + + def test_passes_when_empty(self): + empty_containers = [ + list(), + tuple(), + dict(), + set(), + frozenset(), + b'', + u'', + bytearray(), + ] + for container in empty_containers: + self.assertEmpty(container) + + def test_raises_with_not_empty_containers(self): + not_empty_containers = [ + [1], + (1,), + {'foo': 'bar'}, + {1}, + frozenset([1]), + b'a', + u'a', + bytearray(b'a'), + ] + regexp = r'.* has length of 1\.$' + for container in not_empty_containers: + with self.assertRaisesRegex(AssertionError, regexp): + self.assertEmpty(container) + + def test_user_message_added_to_default(self): + msg = 'This is a useful message' + whole_msg = re.escape('[1] has length of 1. : This is a useful message') + with self.assertRaisesRegex(AssertionError, whole_msg): + self.assertEmpty([1], msg=msg) + + +class TestAssertNotEmpty(absltest.TestCase): + longMessage = True + + def test_raises_if_not_asized_object(self): + msg = "Expected a Sized object, got: 'int'" + with self.assertRaisesRegex(AssertionError, msg): + self.assertNotEmpty(1) + + def test_calls_len_not_bool(self): + + class BadList(list): + + def __bool__(self): + return False + + __nonzero__ = __bool__ + + bad_list = BadList([1]) + self.assertNotEmpty(bad_list) + self.assertFalse(bad_list) + + def test_passes_when_not_empty(self): + not_empty_containers = [ + [1], + (1,), + {'foo': 'bar'}, + {1}, + frozenset([1]), + b'a', + u'a', + bytearray(b'a'), + ] + for container in not_empty_containers: + self.assertNotEmpty(container) + + def test_raises_with_empty_containers(self): + empty_containers = [ + list(), + tuple(), + dict(), + set(), + frozenset(), + b'', + u'', + bytearray(), + ] + regexp = r'.* has length of 0\.$' + for container in empty_containers: + with self.assertRaisesRegex(AssertionError, regexp): + self.assertNotEmpty(container) + + def test_user_message_added_to_default(self): + msg = 'This is a useful message' + whole_msg = re.escape('[] has length of 0. : This is a useful message') + with self.assertRaisesRegex(AssertionError, whole_msg): + self.assertNotEmpty([], msg=msg) + + +class TestAssertLen(absltest.TestCase): + longMessage = True + + def test_raises_if_not_asized_object(self): + msg = "Expected a Sized object, got: 'int'" + with self.assertRaisesRegex(AssertionError, msg): + self.assertLen(1, 1) + + def test_passes_when_expected_len(self): + containers = [ + [[1], 1], + [(1, 2), 2], + [{'a': 1, 'b': 2, 'c': 3}, 3], + [{1, 2, 3, 4}, 4], + [frozenset([1]), 1], + [b'abc', 3], + [u'def', 3], + [bytearray(b'ghij'), 4], + ] + for container, expected_len in containers: + self.assertLen(container, expected_len) + + def test_raises_when_unexpected_len(self): + containers = [ + [1], + (1, 2), + {'a': 1, 'b': 2, 'c': 3}, + {1, 2, 3, 4}, + frozenset([1]), + b'abc', + u'def', + bytearray(b'ghij'), + ] + for container in containers: + regexp = r'.* has length of %d, expected 100\.$' % len(container) + with self.assertRaisesRegex(AssertionError, regexp): + self.assertLen(container, 100) + + def test_user_message_added_to_default(self): + msg = 'This is a useful message' + whole_msg = ( + r'\[1\] has length of 1, expected 100. : This is a useful message') + with self.assertRaisesRegex(AssertionError, whole_msg): + self.assertLen([1], 100, msg) + + +class TestLoaderTest(absltest.TestCase): + """Tests that the TestLoader bans methods named TestFoo.""" + + # pylint: disable=invalid-name + class Valid(absltest.TestCase): + """Test case containing a variety of valid names.""" + + test_property = 1 + TestProperty = 2 + + @staticmethod + def TestStaticMethod(): + pass + + @staticmethod + def TestStaticMethodWithArg(foo): + pass + + @classmethod + def TestClassMethod(cls): + pass + + def Test(self): + pass + + def TestingHelper(self): + pass + + def testMethod(self): + pass + + def TestHelperWithParams(self, a, b): + pass + + def TestHelperWithVarargs(self, *args, **kwargs): + pass + + def TestHelperWithDefaults(self, a=5): + pass + + class Invalid(absltest.TestCase): + """Test case containing a suspicious method.""" + + def testMethod(self): + pass + + def TestSuspiciousMethod(self): + pass + # pylint: enable=invalid-name + + def setUp(self): + self.loader = absltest.TestLoader() + + def test_valid(self): + suite = self.loader.loadTestsFromTestCase(TestLoaderTest.Valid) + self.assertEquals(1, suite.countTestCases()) + + def testInvalid(self): + with self.assertRaisesRegex(TypeError, 'TestSuspiciousMethod'): + self.loader.loadTestsFromTestCase(TestLoaderTest.Invalid) + + +class InitNotNecessaryForAssertsTest(absltest.TestCase): + """TestCase assertions should work even if __init__ wasn't correctly called. + + This is a workaround, see comment in + absltest.TestCase._getAssertEqualityFunc. We know that not calling + __init__ of a superclass is a bad thing, but people keep doing them, + and this (even if a little bit dirty) saves them from shooting + themselves in the foot. + """ + + def test_subclass(self): + + class Subclass(absltest.TestCase): + + def __init__(self): # pylint: disable=super-init-not-called + pass + + Subclass().assertEquals({}, {}) + + def test_multiple_inheritance(self): + + class Foo(object): + + def __init__(self, *args, **kwargs): + pass + + class Subclass(Foo, absltest.TestCase): + pass + + Subclass().assertEquals({}, {}) + + +class GetCommandStringTest(parameterized.TestCase): + + @parameterized.parameters( + ([], '', ''), + ([''], "''", ''), + (['command', 'arg-0'], "'command' 'arg-0'", 'command arg-0'), + ([u'command', u'arg-0'], "'command' 'arg-0'", u'command arg-0'), + (["foo'bar"], "'foo'\"'\"'bar'", "foo'bar"), + (['foo"bar'], "'foo\"bar'", 'foo"bar'), + ('command arg-0', 'command arg-0', 'command arg-0'), + (u'command arg-0', 'command arg-0', 'command arg-0')) + def test_get_command_string( + self, command, expected_non_windows, expected_windows): + expected = expected_windows if os.name == 'nt' else expected_non_windows + self.assertEqual(expected, absltest.get_command_string(command)) + + +class TempFileTest(absltest.TestCase, HelperMixin): + + def assert_dir_exists(self, temp_dir): + path = temp_dir.full_path + self.assertTrue(os.path.exists(path), 'Dir {} does not exist'.format(path)) + self.assertTrue(os.path.isdir(path), + 'Path {} exists, but is not a directory'.format(path)) + + def assert_file_exists(self, temp_file, expected_content=b''): + path = temp_file.full_path + self.assertTrue(os.path.exists(path), 'File {} does not exist'.format(path)) + self.assertTrue(os.path.isfile(path), + 'Path {} exists, but is not a file'.format(path)) + + mode = 'rb' if isinstance(expected_content, bytes) else 'rt' + with io.open(path, mode) as fp: + actual = fp.read() + self.assertEqual(expected_content, actual) + + def run_tempfile_helper(self, cleanup, expected_paths): + tmpdir = self.create_tempdir('helper-test-temp-dir') + env = { + 'ABSLTEST_TEST_HELPER_TEMPFILE_CLEANUP': cleanup, + 'TEST_TMPDIR': tmpdir.full_path, + } + stdout, stderr = self.run_helper(0, ['TempFileHelperTest'], env, + expect_success=False) + output = ('\n=== Helper output ===\n' + '----- stdout -----\n{}\n' + '----- end stdout -----\n' + '----- stderr -----\n{}\n' + '----- end stderr -----\n' + '===== end helper output =====').format(stdout, stderr) + self.assertIn('test_failure', stderr, output) + + # Adjust paths to match on Windows + expected_paths = {path.replace('/', os.sep) for path in expected_paths} + + actual = { + os.path.relpath(f, tmpdir.full_path) + for f in _listdir_recursive(tmpdir.full_path) + if f != tmpdir.full_path + } + self.assertEqual(expected_paths, actual, output) + + def test_create_file_pre_existing_readonly(self): + first = self.create_tempfile('foo', content='first') + os.chmod(first.full_path, 0o444) + second = self.create_tempfile('foo', content='second') + self.assertEqual('second', first.read_text()) + self.assertEqual('second', second.read_text()) + + def test_create_file_fails_cleanup(self): + path = self.create_tempfile().full_path + # Removing the write bit from the file makes it undeletable on Windows. + os.chmod(path, 0) + # Removing the write bit from the whole directory makes all contained files + # undeletable on unix. We also need it to be exec so that os.path.isfile + # returns true, and we reach the buggy branch. + os.chmod(os.path.dirname(path), stat.S_IEXEC) + # The test should pass, even though that file cannot be deleted in teardown. + + def test_temp_file_path_like(self): + tempdir = self.create_tempdir('foo') + self.assertIsInstance(tempdir, os.PathLike) + + tempfile_ = tempdir.create_file('bar') + self.assertIsInstance(tempfile_, os.PathLike) + + self.assertEqual(tempfile_.read_text(), pathlib.Path(tempfile_).read_text()) + + def test_unnamed(self): + td = self.create_tempdir() + self.assert_dir_exists(td) + + tdf = td.create_file() + self.assert_file_exists(tdf) + + tdd = td.mkdir() + self.assert_dir_exists(tdd) + + tf = self.create_tempfile() + self.assert_file_exists(tf) + + def test_named(self): + td = self.create_tempdir('d') + self.assert_dir_exists(td) + + tdf = td.create_file('df') + self.assert_file_exists(tdf) + + tdd = td.mkdir('dd') + self.assert_dir_exists(tdd) + + tf = self.create_tempfile('f') + self.assert_file_exists(tf) + + def test_nested_paths(self): + td = self.create_tempdir('d1/d2') + self.assert_dir_exists(td) + + tdf = td.create_file('df1/df2') + self.assert_file_exists(tdf) + + tdd = td.mkdir('dd1/dd2') + self.assert_dir_exists(tdd) + + tf = self.create_tempfile('f1/f2') + self.assert_file_exists(tf) + + def test_tempdir_create_file(self): + td = self.create_tempdir() + td.create_file(content='text') + + def test_tempfile_text(self): + tf = self.create_tempfile(content='text') + self.assert_file_exists(tf, 'text') + self.assertEqual('text', tf.read_text()) + + with tf.open_text() as fp: + self.assertEqual('text', fp.read()) + + with tf.open_text('w') as fp: + fp.write(u'text-from-open-write') + self.assertEqual('text-from-open-write', tf.read_text()) + + tf.write_text('text-from-write-text') + self.assertEqual('text-from-write-text', tf.read_text()) + + def test_tempfile_bytes(self): + tf = self.create_tempfile(content=b'\x00\x01\x02') + self.assert_file_exists(tf, b'\x00\x01\x02') + self.assertEqual(b'\x00\x01\x02', tf.read_bytes()) + + with tf.open_bytes() as fp: + self.assertEqual(b'\x00\x01\x02', fp.read()) + + with tf.open_bytes('wb') as fp: + fp.write(b'\x03') + self.assertEqual(b'\x03', tf.read_bytes()) + + tf.write_bytes(b'\x04') + self.assertEqual(b'\x04', tf.read_bytes()) + + def test_tempdir_same_name(self): + """Make sure the same directory name can be used.""" + td1 = self.create_tempdir('foo') + td2 = self.create_tempdir('foo') + self.assert_dir_exists(td1) + self.assert_dir_exists(td2) + + def test_tempfile_cleanup_success(self): + expected = { + 'TempFileHelperTest', + 'TempFileHelperTest/test_failure', + 'TempFileHelperTest/test_failure/failure', + 'TempFileHelperTest/test_success', + } + self.run_tempfile_helper('SUCCESS', expected) + + def test_tempfile_cleanup_always(self): + expected = { + 'TempFileHelperTest', + 'TempFileHelperTest/test_failure', + 'TempFileHelperTest/test_success', + } + self.run_tempfile_helper('ALWAYS', expected) + + def test_tempfile_cleanup_off(self): + expected = { + 'TempFileHelperTest', + 'TempFileHelperTest/test_failure', + 'TempFileHelperTest/test_failure/failure', + 'TempFileHelperTest/test_success', + 'TempFileHelperTest/test_success/success', + } + self.run_tempfile_helper('OFF', expected) + + +class SkipClassTest(absltest.TestCase): + + def test_incorrect_decorator_call(self): + with self.assertRaises(TypeError): + + @absltest.skipThisClass # pylint: disable=unused-variable + class Test(absltest.TestCase): + pass + + def test_incorrect_decorator_subclass(self): + with self.assertRaises(TypeError): + + @absltest.skipThisClass('reason') + def test_method(): # pylint: disable=unused-variable + pass + + def test_correct_decorator_class(self): + + @absltest.skipThisClass('reason') + class Test(absltest.TestCase): + pass + + with self.assertRaises(absltest.SkipTest): + Test.setUpClass() + + def test_correct_decorator_subclass(self): + + @absltest.skipThisClass('reason') + class Test(absltest.TestCase): + pass + + class Subclass(Test): + pass + + with self.subTest('Base class should be skipped'): + with self.assertRaises(absltest.SkipTest): + Test.setUpClass() + + with self.subTest('Subclass should not be skipped'): + Subclass.setUpClass() # should not raise. + + def test_setup(self): + + @absltest.skipThisClass('reason') + class Test(absltest.TestCase): + + @classmethod + def setUpClass(cls): + super(Test, cls).setUpClass() + cls.foo = 1 + + class Subclass(Test): + pass + + Subclass.setUpClass() + self.assertEqual(Subclass.foo, 1) + + def test_setup_chain(self): + + @absltest.skipThisClass('reason') + class BaseTest(absltest.TestCase): + + @classmethod + def setUpClass(cls): + super(BaseTest, cls).setUpClass() + cls.foo = 1 + + @absltest.skipThisClass('reason') + class SecondBaseTest(BaseTest): + + @classmethod + def setUpClass(cls): + super(SecondBaseTest, cls).setUpClass() + cls.bar = 2 + + class Subclass(SecondBaseTest): + pass + + Subclass.setUpClass() + self.assertEqual(Subclass.foo, 1) + self.assertEqual(Subclass.bar, 2) + + def test_setup_args(self): + + @absltest.skipThisClass('reason') + class Test(absltest.TestCase): + + @classmethod + def setUpClass(cls, foo, bar=None): + super(Test, cls).setUpClass() + cls.foo = foo + cls.bar = bar + + class Subclass(Test): + + @classmethod + def setUpClass(cls): + super(Subclass, cls).setUpClass('foo', bar='baz') + + Subclass.setUpClass() + self.assertEqual(Subclass.foo, 'foo') + self.assertEqual(Subclass.bar, 'baz') + + def test_setup_multiple_inheritance(self): + + # Test that skipping this class doesn't break the MRO chain and stop + # RequiredBase.setUpClass from running. + @absltest.skipThisClass('reason') + class Left(absltest.TestCase): + pass + + class RequiredBase(absltest.TestCase): + + @classmethod + def setUpClass(cls): + super(RequiredBase, cls).setUpClass() + cls.foo = 'foo' + + class Right(RequiredBase): + + @classmethod + def setUpClass(cls): + super(Right, cls).setUpClass() + + # Test will fail unless Left.setUpClass() follows mro properly + # Right.setUpClass() + class Subclass(Left, Right): + + @classmethod + def setUpClass(cls): + super(Subclass, cls).setUpClass() + + class Test(Subclass): + pass + + Test.setUpClass() + self.assertEqual(Test.foo, 'foo') + + def test_skip_class(self): + + @absltest.skipThisClass('reason') + class BaseTest(absltest.TestCase): + + def test_foo(self): + _ = 1 / 0 + + class Test(BaseTest): + + def test_foo(self): + self.assertEqual(1, 1) + + with self.subTest('base class'): + ts = unittest.makeSuite(BaseTest) + self.assertEqual(1, ts.countTestCases()) + + res = unittest.TestResult() + ts.run(res) + self.assertTrue(res.wasSuccessful()) + self.assertLen(res.skipped, 1) + self.assertEqual(0, res.testsRun) + self.assertEmpty(res.failures) + self.assertEmpty(res.errors) + + with self.subTest('real test'): + ts = unittest.makeSuite(Test) + self.assertEqual(1, ts.countTestCases()) + + res = unittest.TestResult() + ts.run(res) + self.assertTrue(res.wasSuccessful()) + self.assertEqual(1, res.testsRun) + self.assertEmpty(res.skipped) + self.assertEmpty(res.failures) + self.assertEmpty(res.errors) + + def test_skip_class_unittest(self): + + @absltest.skipThisClass('reason') + class Test(unittest.TestCase): # note: unittest not absltest + + def test_foo(self): + _ = 1 / 0 + + ts = unittest.makeSuite(Test) + self.assertEqual(1, ts.countTestCases()) + + res = unittest.TestResult() + ts.run(res) + self.assertTrue(res.wasSuccessful()) + self.assertLen(res.skipped, 1) + self.assertEqual(0, res.testsRun) + self.assertEmpty(res.failures) + self.assertEmpty(res.errors) + + +def _listdir_recursive(path): + for dirname, _, filenames in os.walk(path): + yield dirname + for filename in filenames: + yield os.path.join(dirname, filename) + + +def _env_for_command_tests(): + if os.name == 'nt' and 'PATH' in os.environ: + # get_command_stderr and assertCommandXXX don't inherit environment + # variables by default. This makes sure msys commands can be found on + # Windows. + return {'PATH': os.environ['PATH']} + else: + return None + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/absltest_test_helper.py b/absl/testing/tests/absltest_test_helper.py new file mode 100644 index 0000000..c6b2465 --- /dev/null +++ b/absl/testing/tests/absltest_test_helper.py @@ -0,0 +1,106 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper binary for absltest_test.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile +import unittest + +from absl import flags +from absl.testing import absltest + +FLAGS = flags.FLAGS + +flags.DEFINE_integer('test_id', 0, 'Which test to run.') + + +class HelperTest(absltest.TestCase): + + def test_flags(self): + if FLAGS.test_id == 1: + self.assertEqual(FLAGS.test_random_seed, 301) + if os.name == 'nt': + # On Windows, it's always in the temp dir, which doesn't start with '/'. + expected_prefix = tempfile.gettempdir() + else: + expected_prefix = '/' + self.assertTrue( + absltest.TEST_TMPDIR.value.startswith(expected_prefix), + '--test_tmpdir={} does not start with {}'.format( + absltest.TEST_TMPDIR.value, expected_prefix)) + self.assertTrue(os.access(absltest.TEST_TMPDIR.value, os.W_OK)) + elif FLAGS.test_id == 2: + self.assertEqual(FLAGS.test_random_seed, 321) + self.assertEqual( + absltest.TEST_SRCDIR.value, + os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR']) + self.assertEqual( + absltest.TEST_TMPDIR.value, + os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR']) + elif FLAGS.test_id == 3: + self.assertEqual(FLAGS.test_random_seed, 123) + self.assertEqual( + absltest.TEST_SRCDIR.value, + os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR']) + self.assertEqual( + absltest.TEST_TMPDIR.value, + os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR']) + elif FLAGS.test_id == 4: + self.assertEqual(FLAGS.test_random_seed, 221) + self.assertEqual( + absltest.TEST_SRCDIR.value, + os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR']) + self.assertEqual( + absltest.TEST_TMPDIR.value, + os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR']) + else: + raise unittest.SkipTest( + 'Not asked to run: --test_id={}'.format(FLAGS.test_id)) + + @unittest.expectedFailure + def test_expected_failure(self): + if FLAGS.test_id == 5: + self.assertEqual(1, 1) # Expected failure, got success. + else: + self.assertEqual(1, 2) # The expected failure. + + def test_xml_env_vars(self): + if FLAGS.test_id == 6: + self.assertEqual( + FLAGS.xml_output_file, + os.environ['ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE']) + else: + raise unittest.SkipTest( + 'Not asked to run: --test_id={}'.format(FLAGS.test_id)) + + +class TempFileHelperTest(absltest.TestCase): + tempfile_cleanup = absltest.TempFileCleanup[os.environ.get( + 'ABSLTEST_TEST_HELPER_TEMPFILE_CLEANUP', 'SUCCESS')] + + def test_failure(self): + self.create_tempfile('failure') + self.fail('expected failure') + + def test_success(self): + self.create_tempfile('success') + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/flagsaver_test.py b/absl/testing/tests/flagsaver_test.py new file mode 100644 index 0000000..3439a32 --- /dev/null +++ b/absl/testing/tests/flagsaver_test.py @@ -0,0 +1,467 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for flagsaver.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import flags +from absl.testing import absltest +from absl.testing import flagsaver + +flags.DEFINE_string('flagsaver_test_flag0', 'unchanged0', 'flag to test with') +flags.DEFINE_string('flagsaver_test_flag1', 'unchanged1', 'flag to test with') + +flags.DEFINE_string('flagsaver_test_validated_flag', None, 'flag to test with') +flags.register_validator('flagsaver_test_validated_flag', lambda x: not x) + +flags.DEFINE_string('flagsaver_test_validated_flag1', None, 'flag to test with') +flags.DEFINE_string('flagsaver_test_validated_flag2', None, 'flag to test with') + +INT_FLAG = flags.DEFINE_integer( + 'flagsaver_test_int_flag', default=1, help='help') +STR_FLAG = flags.DEFINE_string( + 'flagsaver_test_str_flag', default='str default', help='help') + + +@flags.multi_flags_validator( + ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2')) +def validate_test_flags(flag_dict): + return (flag_dict['flagsaver_test_validated_flag1'] == + flag_dict['flagsaver_test_validated_flag2']) + + +FLAGS = flags.FLAGS + + +@flags.validator('flagsaver_test_flag0') +def check_no_upper_case(value): + return value == value.lower() + + +class _TestError(Exception): + """Exception class for use in these tests.""" + + +class FlagSaverTest(absltest.TestCase): + + def test_context_manager_without_parameters(self): + with flagsaver.flagsaver(): + FLAGS.flagsaver_test_flag0 = 'new value' + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_context_manager_with_overrides(self): + with flagsaver.flagsaver(flagsaver_test_flag0='new value'): + self.assertEqual('new value', FLAGS.flagsaver_test_flag0) + FLAGS.flagsaver_test_flag1 = 'another value' + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) + + def test_context_manager_with_flagholders(self): + with flagsaver.flagsaver((INT_FLAG, 3), (STR_FLAG, 'new value')): + self.assertEqual('new value', STR_FLAG.value) + self.assertEqual(3, INT_FLAG.value) + FLAGS.flagsaver_test_flag1 = 'another value' + self.assertEqual(INT_FLAG.value, INT_FLAG.default) + self.assertEqual(STR_FLAG.value, STR_FLAG.default) + self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) + + def test_context_manager_with_overrides_and_flagholders(self): + with flagsaver.flagsaver((INT_FLAG, 3), flagsaver_test_flag0='new value'): + self.assertEqual(STR_FLAG.default, STR_FLAG.value) + self.assertEqual(3, INT_FLAG.value) + FLAGS.flagsaver_test_flag0 = 'new value' + self.assertEqual(INT_FLAG.value, INT_FLAG.default) + self.assertEqual(STR_FLAG.value, STR_FLAG.default) + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_context_manager_with_cross_validated_overrides_set_together(self): + # When the flags are set in the same flagsaver call their validators will + # be triggered only once the setting is done. + with flagsaver.flagsaver( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='new_value'): + self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1) + self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2) + + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_context_manager_with_cross_validated_overrides_set_badly(self): + + # Different values should violate the validator. + with self.assertRaisesRegex(flags.IllegalFlagValueError, + 'Flag validation failed'): + with flagsaver.flagsaver( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='other_value'): + pass + + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_context_manager_with_cross_validated_overrides_set_separately(self): + + # Setting just one flag will trip the validator as well. + with self.assertRaisesRegex(flags.IllegalFlagValueError, + 'Flag validation failed'): + with flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value'): + pass + + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_context_manager_with_exception(self): + with self.assertRaises(_TestError): + with flagsaver.flagsaver(flagsaver_test_flag0='new value'): + self.assertEqual('new value', FLAGS.flagsaver_test_flag0) + FLAGS.flagsaver_test_flag1 = 'another value' + raise _TestError('oops') + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) + + def test_context_manager_with_validation_exception(self): + with self.assertRaises(flags.IllegalFlagValueError): + with flagsaver.flagsaver( + flagsaver_test_flag0='new value', + flagsaver_test_validated_flag='new value'): + pass + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag) + + def test_decorator_without_call(self): + + @flagsaver.flagsaver + def mutate_flags(value): + """Test function that mutates a flag.""" + # The undecorated method mutates --flagsaver_test_flag0 to the given value + # and then returns the value of that flag. If the @flagsaver.flagsaver + # decorator works as designed, then this mutation will be reverted after + # this method returns. + FLAGS.flagsaver_test_flag0 = value + return FLAGS.flagsaver_test_flag0 + + # mutate_flags returns the flag value before it gets restored by + # the flagsaver decorator. So we check that flag value was + # actually changed in the method's scope. + self.assertEqual('new value', mutate_flags('new value')) + # But... notice that the flag is now unchanged0. + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_decorator_without_parameters(self): + + @flagsaver.flagsaver() + def mutate_flags(value): + FLAGS.flagsaver_test_flag0 = value + return FLAGS.flagsaver_test_flag0 + + self.assertEqual('new value', mutate_flags('new value')) + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_decorator_with_overrides(self): + + @flagsaver.flagsaver(flagsaver_test_flag0='new value') + def mutate_flags(): + """Test function expecting new value.""" + # If the @flagsaver.decorator decorator works as designed, + # then the value of the flag should be changed in the scope of + # the method but the change will be reverted after this method + # returns. + return FLAGS.flagsaver_test_flag0 + + # mutate_flags returns the flag value before it gets restored by + # the flagsaver decorator. So we check that flag value was + # actually changed in the method's scope. + self.assertEqual('new value', mutate_flags()) + # But... notice that the flag is now unchanged0. + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_decorator_with_cross_validated_overrides_set_together(self): + + # When the flags are set in the same flagsaver call their validators will + # be triggered only once the setting is done. + @flagsaver.flagsaver( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='new_value') + def mutate_flags_together(): + return (FLAGS.flagsaver_test_validated_flag1, + FLAGS.flagsaver_test_validated_flag2) + + self.assertEqual(('new_value', 'new_value'), mutate_flags_together()) + + # The flags have not changed outside the context of the function. + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_decorator_with_cross_validated_overrides_set_badly(self): + + # Different values should violate the validator. + @flagsaver.flagsaver( + flagsaver_test_validated_flag1='new_value', + flagsaver_test_validated_flag2='other_value') + def mutate_flags_together_badly(): + return (FLAGS.flagsaver_test_validated_flag1, + FLAGS.flagsaver_test_validated_flag2) + + with self.assertRaisesRegex(flags.IllegalFlagValueError, + 'Flag validation failed'): + mutate_flags_together_badly() + + # The flags have not changed outside the context of the exception. + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_decorator_with_cross_validated_overrides_set_separately(self): + + # Setting the flags sequentially and not together will trip the validator, + # because it will be called at the end of each flagsaver call. + @flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value') + @flagsaver.flagsaver(flagsaver_test_validated_flag2='new_value') + def mutate_flags_separately(): + return (FLAGS.flagsaver_test_validated_flag1, + FLAGS.flagsaver_test_validated_flag2) + + with self.assertRaisesRegex(flags.IllegalFlagValueError, + 'Flag validation failed'): + mutate_flags_separately() + + # The flags have not changed outside the context of the exception. + self.assertIsNone(FLAGS.flagsaver_test_validated_flag1) + self.assertIsNone(FLAGS.flagsaver_test_validated_flag2) + + def test_save_flag_value(self): + # First save the flag values. + saved_flag_values = flagsaver.save_flag_values() + + # Now mutate the flag's value field and check that it changed. + FLAGS.flagsaver_test_flag0 = 'new value' + self.assertEqual('new value', FLAGS.flagsaver_test_flag0) + + # Now restore the flag to its original value. + flagsaver.restore_flag_values(saved_flag_values) + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_save_flag_default(self): + # First save the flag. + saved_flag_values = flagsaver.save_flag_values() + + # Now mutate the flag's default field and check that it changed. + FLAGS.set_default('flagsaver_test_flag0', 'new_default') + self.assertEqual('new_default', FLAGS['flagsaver_test_flag0'].default) + + # Now restore the flag's default field. + flagsaver.restore_flag_values(saved_flag_values) + self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].default) + + def test_restore_after_parse(self): + # First save the flag. + saved_flag_values = flagsaver.save_flag_values() + + # Sanity check (would fail if called with --flagsaver_test_flag0). + self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present) + # Now populate the flag and check that it changed. + FLAGS['flagsaver_test_flag0'].parse('new value') + self.assertEqual('new value', FLAGS['flagsaver_test_flag0'].value) + self.assertEqual(1, FLAGS['flagsaver_test_flag0'].present) + + # Now restore the flag to its original value. + flagsaver.restore_flag_values(saved_flag_values) + self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].value) + self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present) + + def test_decorator_with_exception(self): + + @flagsaver.flagsaver + def raise_exception(): + FLAGS.flagsaver_test_flag0 = 'new value' + # Simulate a failed test. + raise _TestError('something happened') + + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + self.assertRaises(_TestError, raise_exception) + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + + def test_validator_list_is_restored(self): + + self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1) + original_validators = list(FLAGS['flagsaver_test_flag0'].validators) + + @flagsaver.flagsaver + def modify_validators(): + + def no_space(value): + return ' ' not in value + + flags.register_validator('flagsaver_test_flag0', no_space) + self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2) + + modify_validators() + self.assertEqual(original_validators, + FLAGS['flagsaver_test_flag0'].validators) + + +class FlagSaverDecoratorUsageTest(absltest.TestCase): + + @flagsaver.flagsaver + def test_mutate1(self): + # Even though other test cases change the flag, it should be + # restored to 'unchanged0' if the flagsaver is working. + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + FLAGS.flagsaver_test_flag0 = 'changed0' + + @flagsaver.flagsaver + def test_mutate2(self): + # Even though other test cases change the flag, it should be + # restored to 'unchanged0' if the flagsaver is working. + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + FLAGS.flagsaver_test_flag0 = 'changed0' + + @flagsaver.flagsaver + def test_mutate3(self): + # Even though other test cases change the flag, it should be + # restored to 'unchanged0' if the flagsaver is working. + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + FLAGS.flagsaver_test_flag0 = 'changed0' + + @flagsaver.flagsaver + def test_mutate4(self): + # Even though other test cases change the flag, it should be + # restored to 'unchanged0' if the flagsaver is working. + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + FLAGS.flagsaver_test_flag0 = 'changed0' + + +class FlagSaverSetUpTearDownUsageTest(absltest.TestCase): + + def setUp(self): + self.saved_flag_values = flagsaver.save_flag_values() + + def tearDown(self): + flagsaver.restore_flag_values(self.saved_flag_values) + + def test_mutate1(self): + # Even though other test cases change the flag, it should be + # restored to 'unchanged0' if the flagsaver is working. + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + FLAGS.flagsaver_test_flag0 = 'changed0' + + def test_mutate2(self): + # Even though other test cases change the flag, it should be + # restored to 'unchanged0' if the flagsaver is working. + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + FLAGS.flagsaver_test_flag0 = 'changed0' + + def test_mutate3(self): + # Even though other test cases change the flag, it should be + # restored to 'unchanged0' if the flagsaver is working. + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + FLAGS.flagsaver_test_flag0 = 'changed0' + + def test_mutate4(self): + # Even though other test cases change the flag, it should be + # restored to 'unchanged0' if the flagsaver is working. + self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0) + FLAGS.flagsaver_test_flag0 = 'changed0' + + +class FlagSaverBadUsageTest(absltest.TestCase): + """Tests that certain kinds of improper usages raise errors.""" + + def test_flag_saver_on_class(self): + with self.assertRaises(TypeError): + + # WRONG. Don't do this. + # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest. + @flagsaver.flagsaver + class FooTest(absltest.TestCase): + + def test_tautology(self): + pass + + del FooTest + + def test_flag_saver_call_on_class(self): + with self.assertRaises(TypeError): + + # WRONG. Don't do this. + # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest. + @flagsaver.flagsaver() + class FooTest(absltest.TestCase): + + def test_tautology(self): + pass + + del FooTest + + def test_flag_saver_with_overrides_on_class(self): + with self.assertRaises(TypeError): + + # WRONG. Don't do this. + # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest. + @flagsaver.flagsaver(foo='bar') + class FooTest(absltest.TestCase): + + def test_tautology(self): + pass + + del FooTest + + def test_multiple_positional_parameters(self): + with self.assertRaises(ValueError): + func_a = lambda: None + func_b = lambda: None + flagsaver.flagsaver(func_a, func_b) + + def test_both_positional_and_keyword_parameters(self): + with self.assertRaises(ValueError): + func_a = lambda: None + flagsaver.flagsaver(func_a, flagsaver_test_flag0='new value') + + def test_duplicate_holder_parameters(self): + with self.assertRaises(ValueError): + flagsaver.flagsaver((INT_FLAG, 45), (INT_FLAG, 45)) + + def test_duplicate_holder_and_kw_parameter(self): + with self.assertRaises(ValueError): + flagsaver.flagsaver((INT_FLAG, 45), **{INT_FLAG.name: 45}) + + def test_both_positional_and_holder_parameters(self): + with self.assertRaises(ValueError): + func_a = lambda: None + flagsaver.flagsaver(func_a, (INT_FLAG, 45)) + + def test_holder_parameters_wrong_shape(self): + with self.assertRaises(ValueError): + flagsaver.flagsaver(INT_FLAG) + + def test_holder_parameters_tuple_too_long(self): + with self.assertRaises(ValueError): + # Even if it is a bool flag, it should be a tuple + flagsaver.flagsaver((INT_FLAG, 4, 5)) + + def test_holder_parameters_tuple_wrong_type(self): + with self.assertRaises(ValueError): + # Even if it is a bool flag, it should be a tuple + flagsaver.flagsaver((4, INT_FLAG)) + + def test_both_wrong_positional_parameters(self): + with self.assertRaises(ValueError): + func_a = lambda: None + flagsaver.flagsaver(func_a, STR_FLAG, '45') + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/parameterized_test.py b/absl/testing/tests/parameterized_test.py new file mode 100644 index 0000000..8acbd93 --- /dev/null +++ b/absl/testing/tests/parameterized_test.py @@ -0,0 +1,1077 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for absl.testing.parameterized.""" + +from collections import abc +import sys +import unittest + +from absl.testing import absltest +from absl.testing import parameterized + + +class MyOwnClass(object): + pass + + +def dummy_decorator(method): + + def decorated(*args, **kwargs): + return method(*args, **kwargs) + + return decorated + + +def dict_decorator(key, value): + """Sample implementation of a chained decorator. + + Sets a single field in a dict on a test with a dict parameter. + Uses the exposed '_ParameterizedTestIter.testcases' field to + modify arguments from previous decorators to allow decorator chains. + + Args: + key: key to map to + value: value to set + + Returns: + The test decorator + """ + def decorator(test_method): + # If decorating result of another dict_decorator + if isinstance(test_method, abc.Iterable): + actual_tests = [] + for old_test in test_method.testcases: + # each test is a ('test_suffix', dict) tuple + new_dict = old_test[1].copy() + new_dict[key] = value + test_suffix = '%s_%s_%s' % (old_test[0], key, value) + actual_tests.append((test_suffix, new_dict)) + + test_method.testcases = actual_tests + return test_method + else: + test_suffix = ('_%s_%s') % (key, value) + tests_to_make = ((test_suffix, {key: value}),) + # 'test_method' here is the original test method + return parameterized.named_parameters(*tests_to_make)(test_method) + return decorator + + +class ParameterizedTestsTest(absltest.TestCase): + # The test testcases are nested so they're not + # picked up by the normal test case loader code. + + class GoodAdditionParams(parameterized.TestCase): + + @parameterized.parameters( + (1, 2, 3), + (4, 5, 9)) + def test_addition(self, op1, op2, result): + self.arguments = (op1, op2, result) + self.assertEqual(result, op1 + op2) + + # This class does not inherit from TestCase. + class BadAdditionParams(absltest.TestCase): + + @parameterized.parameters( + (1, 2, 3), + (4, 5, 9)) + def test_addition(self, op1, op2, result): + pass # Always passes, but not called w/out TestCase. + + class MixedAdditionParams(parameterized.TestCase): + + @parameterized.parameters( + (1, 2, 1), + (4, 5, 9)) + def test_addition(self, op1, op2, result): + self.arguments = (op1, op2, result) + self.assertEqual(result, op1 + op2) + + class DictionaryArguments(parameterized.TestCase): + + @parameterized.parameters( + {'op1': 1, 'op2': 2, 'result': 3}, + {'op1': 4, 'op2': 5, 'result': 9}) + def test_addition(self, op1, op2, result): + self.assertEqual(result, op1 + op2) + + class NoParameterizedTests(parameterized.TestCase): + # iterable member with non-matching name + a = 'BCD' + # member with matching name, but not a generator + testInstanceMember = None # pylint: disable=invalid-name + test_instance_member = None + + # member with a matching name and iterator, but not a generator + testString = 'foo' # pylint: disable=invalid-name + test_string = 'foo' + + # generator, but no matching name + def someGenerator(self): # pylint: disable=invalid-name + yield + yield + yield + + def some_generator(self): + yield + yield + yield + + # Generator function, but not a generator instance. + def testGenerator(self): + yield + yield + yield + + def test_generator(self): + yield + yield + yield + + def testNormal(self): + self.assertEqual(3, 1 + 2) + + def test_normal(self): + self.assertEqual(3, 1 + 2) + + class ArgumentsWithAddresses(parameterized.TestCase): + + @parameterized.parameters( + (object(),), + (MyOwnClass(),), + ) + def test_something(self, case): + pass + + class CamelCaseNamedTests(parameterized.TestCase): + + @parameterized.named_parameters( + ('Interesting', 0), + ) + def testSingle(self, case): + pass + + @parameterized.named_parameters( + {'testcase_name': 'Interesting', 'case': 0}, + ) + def testDictSingle(self, case): + pass + + @parameterized.named_parameters( + ('Interesting', 0), + ('Boring', 1), + ) + def testSomething(self, case): + pass + + @parameterized.named_parameters( + {'testcase_name': 'Interesting', 'case': 0}, + {'testcase_name': 'Boring', 'case': 1}, + ) + def testDictSomething(self, case): + pass + + @parameterized.named_parameters( + {'testcase_name': 'Interesting', 'case': 0}, + ('Boring', 1), + ) + def testMixedSomething(self, case): + pass + + def testWithoutParameters(self): + pass + + class NamedTests(parameterized.TestCase): + """Example tests using PEP-8 style names instead of camel-case.""" + + @parameterized.named_parameters( + ('interesting', 0), + ) + def test_single(self, case): + pass + + @parameterized.named_parameters( + {'testcase_name': 'interesting', 'case': 0}, + ) + def test_dict_single(self, case): + pass + + @parameterized.named_parameters( + ('interesting', 0), + ('boring', 1), + ) + def test_something(self, case): + pass + + @parameterized.named_parameters( + {'testcase_name': 'interesting', 'case': 0}, + {'testcase_name': 'boring', 'case': 1}, + ) + def test_dict_something(self, case): + pass + + @parameterized.named_parameters( + {'testcase_name': 'interesting', 'case': 0}, + ('boring', 1), + ) + def test_mixed_something(self, case): + pass + + def test_without_parameters(self): + pass + + class ChainedTests(parameterized.TestCase): + + @dict_decorator('cone', 'waffle') + @dict_decorator('flavor', 'strawberry') + def test_chained(self, dictionary): + self.assertDictEqual(dictionary, {'cone': 'waffle', + 'flavor': 'strawberry'}) + + class SingletonListExtraction(parameterized.TestCase): + + @parameterized.parameters( + (i, i * 2) for i in range(10)) + def test_something(self, unused_1, unused_2): + pass + + class SingletonArgumentExtraction(parameterized.TestCase): + + @parameterized.parameters(1, 2, 3, 4, 5, 6) + def test_numbers(self, unused_1): + pass + + @parameterized.parameters('foo', 'bar', 'baz') + def test_strings(self, unused_1): + pass + + class SingletonDictArgument(parameterized.TestCase): + + @parameterized.parameters({'op1': 1, 'op2': 2}) + def test_something(self, op1, op2): + del op1, op2 + + @parameterized.parameters( + (1, 2, 3), + (4, 5, 9)) + class DecoratedClass(parameterized.TestCase): + + def test_add(self, arg1, arg2, arg3): + self.assertEqual(arg1 + arg2, arg3) + + def test_subtract_fail(self, arg1, arg2, arg3): + self.assertEqual(arg3 + arg2, arg1) + + @parameterized.parameters( + (a, b, a+b) for a in range(1, 5) for b in range(1, 5)) + class GeneratorDecoratedClass(parameterized.TestCase): + + def test_add(self, arg1, arg2, arg3): + self.assertEqual(arg1 + arg2, arg3) + + def test_subtract_fail(self, arg1, arg2, arg3): + self.assertEqual(arg3 + arg2, arg1) + + @parameterized.parameters( + (1, 2, 3), + (4, 5, 9), + ) + class DecoratedBareClass(absltest.TestCase): + + def test_add(self, arg1, arg2, arg3): + self.assertEqual(arg1 + arg2, arg3) + + class OtherDecoratorUnnamed(parameterized.TestCase): + + @dummy_decorator + @parameterized.parameters((1), (2)) + def test_other_then_parameterized(self, arg1): + pass + + @parameterized.parameters((1), (2)) + @dummy_decorator + def test_parameterized_then_other(self, arg1): + pass + + class OtherDecoratorNamed(parameterized.TestCase): + + @dummy_decorator + @parameterized.named_parameters(('a', 1), ('b', 2)) + def test_other_then_parameterized(self, arg1): + pass + + @parameterized.named_parameters(('a', 1), ('b', 2)) + @dummy_decorator + def test_parameterized_then_other(self, arg1): + pass + + class OtherDecoratorNamedWithDict(parameterized.TestCase): + + @dummy_decorator + @parameterized.named_parameters( + {'testcase_name': 'a', 'arg1': 1}, + {'testcase_name': 'b', 'arg1': 2}) + def test_other_then_parameterized(self, arg1): + pass + + @parameterized.named_parameters( + {'testcase_name': 'a', 'arg1': 1}, + {'testcase_name': 'b', 'arg1': 2}) + @dummy_decorator + def test_parameterized_then_other(self, arg1): + pass + + class UniqueDescriptiveNamesTest(parameterized.TestCase): + + @parameterized.parameters(13, 13) + def test_normal(self, number): + del number + + class MultiGeneratorsTestCase(parameterized.TestCase): + + @parameterized.parameters((i for i in (1, 2, 3)), (i for i in (3, 2, 1))) + def test_sum(self, a, b, c): + self.assertEqual(6, sum([a, b, c])) + + class NamedParametersReusableTestCase(parameterized.TestCase): + named_params_a = ( + {'testcase_name': 'dict_a', 'unused_obj': 0}, + ('list_a', 1), + ) + named_params_b = ( + {'testcase_name': 'dict_b', 'unused_obj': 2}, + ('list_b', 3), + ) + named_params_c = ( + {'testcase_name': 'dict_c', 'unused_obj': 4}, + ('list_b', 5), + ) + + @parameterized.named_parameters(*(named_params_a + named_params_b)) + def testSomething(self, unused_obj): + pass + + @parameterized.named_parameters(*(named_params_a + named_params_c)) + def testSomethingElse(self, unused_obj): + pass + + class SuperclassTestCase(parameterized.TestCase): + + @parameterized.parameters('foo', 'bar') + def test_name(self, name): + del name + + class SubclassTestCase(SuperclassTestCase): + pass + + @unittest.skipIf( + (sys.version_info[:2] == (3, 7) and sys.version_info[2] in {0, 1, 2}), + 'Python 3.7.0 to 3.7.2 have a bug that breaks this test, see ' + 'https://bugs.python.org/issue35767') + def test_missing_inheritance(self): + ts = unittest.makeSuite(self.BadAdditionParams) + self.assertEqual(1, ts.countTestCases()) + + res = unittest.TestResult() + ts.run(res) + self.assertEqual(1, res.testsRun) + self.assertFalse(res.wasSuccessful()) + self.assertIn('without having inherited', str(res.errors[0])) + + def test_correct_extraction_numbers(self): + ts = unittest.makeSuite(self.GoodAdditionParams) + self.assertEqual(2, ts.countTestCases()) + + def test_successful_execution(self): + ts = unittest.makeSuite(self.GoodAdditionParams) + + res = unittest.TestResult() + ts.run(res) + self.assertEqual(2, res.testsRun) + self.assertTrue(res.wasSuccessful()) + + def test_correct_arguments(self): + ts = unittest.makeSuite(self.GoodAdditionParams) + res = unittest.TestResult() + + params = set([ + (1, 2, 3), + (4, 5, 9)]) + for test in ts: + test(res) + self.assertIn(test.arguments, params) + params.remove(test.arguments) + self.assertEmpty(params) + + def test_recorded_failures(self): + ts = unittest.makeSuite(self.MixedAdditionParams) + self.assertEqual(2, ts.countTestCases()) + + res = unittest.TestResult() + ts.run(res) + self.assertEqual(2, res.testsRun) + self.assertFalse(res.wasSuccessful()) + self.assertLen(res.failures, 1) + self.assertEmpty(res.errors) + + def test_short_description(self): + ts = unittest.makeSuite(self.GoodAdditionParams) + short_desc = list(ts)[0].shortDescription() + + location = unittest.util.strclass(self.GoodAdditionParams).replace( + '__main__.', '') + expected = ('{}.test_addition0 (1, 2, 3)\n'.format(location) + + 'test_addition(1, 2, 3)') + self.assertEqual(expected, short_desc) + + def test_short_description_addresses_removed(self): + ts = unittest.makeSuite(self.ArgumentsWithAddresses) + short_desc = list(ts)[0].shortDescription().split('\n') + self.assertEqual( + 'test_something(<object>)', short_desc[1]) + short_desc = list(ts)[1].shortDescription().split('\n') + self.assertEqual( + 'test_something(<__main__.MyOwnClass>)', short_desc[1]) + + def test_id(self): + ts = unittest.makeSuite(self.ArgumentsWithAddresses) + self.assertEqual( + (unittest.util.strclass(self.ArgumentsWithAddresses) + + '.test_something0 (<object>)'), + list(ts)[0].id()) + ts = unittest.makeSuite(self.GoodAdditionParams) + self.assertEqual( + (unittest.util.strclass(self.GoodAdditionParams) + + '.test_addition0 (1, 2, 3)'), + list(ts)[0].id()) + + def test_str(self): + ts = unittest.makeSuite(self.GoodAdditionParams) + test = list(ts)[0] + + expected = 'test_addition0 (1, 2, 3) ({})'.format( + unittest.util.strclass(self.GoodAdditionParams)) + self.assertEqual(expected, str(test)) + + def test_dict_parameters(self): + ts = unittest.makeSuite(self.DictionaryArguments) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(2, res.testsRun) + self.assertTrue(res.wasSuccessful()) + + def test_no_parameterized_tests(self): + ts = unittest.makeSuite(self.NoParameterizedTests) + self.assertEqual(4, ts.countTestCases()) + short_descs = [x.shortDescription() for x in list(ts)] + full_class_name = unittest.util.strclass(self.NoParameterizedTests) + full_class_name = full_class_name.replace('__main__.', '') + self.assertSameElements( + [ + '{}.testGenerator'.format(full_class_name), + '{}.test_generator'.format(full_class_name), + '{}.testNormal'.format(full_class_name), + '{}.test_normal'.format(full_class_name), + ], + short_descs) + + def test_successful_product_test_testgrid(self): + + class GoodProductTestCase(parameterized.TestCase): + + @parameterized.product( + num=(0, 20, 80), + modulo=(2, 4), + expected=(0,) + ) + def testModuloResult(self, num, modulo, expected): + self.assertEqual(expected, num % modulo) + + ts = unittest.makeSuite(GoodProductTestCase) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(ts.countTestCases(), 6) + self.assertEqual(res.testsRun, 6) + self.assertTrue(res.wasSuccessful()) + + def test_successful_product_test_kwarg_seqs(self): + + class GoodProductTestCase(parameterized.TestCase): + + @parameterized.product((dict(num=0), dict(num=20), dict(num=0)), + (dict(modulo=2), dict(modulo=4)), + (dict(expected=0),)) + def testModuloResult(self, num, modulo, expected): + self.assertEqual(expected, num % modulo) + + ts = unittest.makeSuite(GoodProductTestCase) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(ts.countTestCases(), 6) + self.assertEqual(res.testsRun, 6) + self.assertTrue(res.wasSuccessful()) + + def test_successful_product_test_kwarg_seq_and_testgrid(self): + + class GoodProductTestCase(parameterized.TestCase): + + @parameterized.product((dict( + num=5, modulo=3, expected=2), dict(num=7, modulo=4, expected=3)), + dtype=(int, float)) + def testModuloResult(self, num, dtype, modulo, expected): + self.assertEqual(expected, dtype(num) % modulo) + + ts = unittest.makeSuite(GoodProductTestCase) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(ts.countTestCases(), 4) + self.assertEqual(res.testsRun, 4) + self.assertTrue(res.wasSuccessful()) + + def test_inconsistent_arg_names_in_kwargs_seq(self): + with self.assertRaisesRegex(AssertionError, 'must all have the same keys'): + + class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable + + @parameterized.product((dict(num=5, modulo=3), dict(num=7, modula=2)), + dtype=(int, float)) + def test_something(self): + pass # not called because argnames are not the same + + def test_duplicate_arg_names_in_kwargs_seqs(self): + with self.assertRaisesRegex(AssertionError, 'must all have distinct'): + + class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable + + @parameterized.product((dict(num=5, modulo=3), dict(num=7, modulo=4)), + (dict(foo='bar', num=5), dict(foo='baz', num=7)), + dtype=(int, float)) + def test_something(self): + pass # not called because `num` is specified twice + + def test_duplicate_arg_names_in_kwargs_seq_and_testgrid(self): + with self.assertRaisesRegex(AssertionError, 'duplicate argument'): + + class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable + + @parameterized.product( + (dict(num=5, modulo=3), dict(num=7, modulo=4)), + (dict(foo='bar'), dict(foo='baz')), + dtype=(int, float), + foo=('a', 'b'), + ) + def test_something(self): + pass # not called because `foo` is specified twice + + def test_product_recorded_failures(self): + + class MixedProductTestCase(parameterized.TestCase): + + @parameterized.product( + num=(0, 10, 20), + modulo=(2, 4), + expected=(0,) + ) + def testModuloResult(self, num, modulo, expected): + self.assertEqual(expected, num % modulo) + + ts = unittest.makeSuite(MixedProductTestCase) + self.assertEqual(6, ts.countTestCases()) + + res = unittest.TestResult() + ts.run(res) + self.assertEqual(res.testsRun, 6) + self.assertFalse(res.wasSuccessful()) + self.assertLen(res.failures, 1) + self.assertEmpty(res.errors) + + def test_mismatched_product_parameter(self): + + class MismatchedProductParam(parameterized.TestCase): + + @parameterized.product( + a=(1, 2), + mismatch=(1, 2) + ) + # will fail because of mismatch in parameter names. + def test_something(self, a, b): + pass + + ts = unittest.makeSuite(MismatchedProductParam) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(res.testsRun, 4) + self.assertFalse(res.wasSuccessful()) + self.assertLen(res.errors, 4) + + def test_no_test_error_empty_product_parameter(self): + with self.assertRaises(parameterized.NoTestsError): + + class EmptyProductParam(parameterized.TestCase): # pylint: disable=unused-variable + + @parameterized.product(arg1=[1, 2], arg2=[]) + def test_something(self, arg1, arg2): + pass # not called because arg2 has empty list of values. + + def test_bad_product_parameters(self): + with self.assertRaisesRegex(AssertionError, 'must be given as list or'): + + class BadProductParams(parameterized.TestCase): # pylint: disable=unused-variable + + @parameterized.product(arg1=[1, 2], arg2='abcd') + def test_something(self, arg1, arg2): + pass # not called because arg2 is not list or tuple. + + def test_generator_tests_disallowed(self): + with self.assertRaisesRegex(RuntimeError, 'generated.*without'): + class GeneratorTests(parameterized.TestCase): # pylint: disable=unused-variable + test_generator_method = (lambda self: None for _ in range(10)) + + def test_named_parameters_run(self): + ts = unittest.makeSuite(self.NamedTests) + self.assertEqual(9, ts.countTestCases()) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(9, res.testsRun) + self.assertTrue(res.wasSuccessful()) + + def test_named_parameters_id(self): + ts = sorted(unittest.makeSuite(self.CamelCaseNamedTests), + key=lambda t: t.id()) + self.assertLen(ts, 9) + full_class_name = unittest.util.strclass(self.CamelCaseNamedTests) + self.assertEqual( + full_class_name + '.testDictSingleInteresting', + ts[0].id()) + self.assertEqual( + full_class_name + '.testDictSomethingBoring', + ts[1].id()) + self.assertEqual( + full_class_name + '.testDictSomethingInteresting', + ts[2].id()) + self.assertEqual( + full_class_name + '.testMixedSomethingBoring', + ts[3].id()) + self.assertEqual( + full_class_name + '.testMixedSomethingInteresting', + ts[4].id()) + self.assertEqual( + full_class_name + '.testSingleInteresting', + ts[5].id()) + self.assertEqual( + full_class_name + '.testSomethingBoring', + ts[6].id()) + self.assertEqual( + full_class_name + '.testSomethingInteresting', + ts[7].id()) + self.assertEqual( + full_class_name + '.testWithoutParameters', + ts[8].id()) + + def test_named_parameters_id_with_underscore_case(self): + ts = sorted(unittest.makeSuite(self.NamedTests), + key=lambda t: t.id()) + self.assertLen(ts, 9) + full_class_name = unittest.util.strclass(self.NamedTests) + self.assertEqual( + full_class_name + '.test_dict_single_interesting', + ts[0].id()) + self.assertEqual( + full_class_name + '.test_dict_something_boring', + ts[1].id()) + self.assertEqual( + full_class_name + '.test_dict_something_interesting', + ts[2].id()) + self.assertEqual( + full_class_name + '.test_mixed_something_boring', + ts[3].id()) + self.assertEqual( + full_class_name + '.test_mixed_something_interesting', + ts[4].id()) + self.assertEqual( + full_class_name + '.test_single_interesting', + ts[5].id()) + self.assertEqual( + full_class_name + '.test_something_boring', + ts[6].id()) + self.assertEqual( + full_class_name + '.test_something_interesting', + ts[7].id()) + self.assertEqual( + full_class_name + '.test_without_parameters', + ts[8].id()) + + def test_named_parameters_short_description(self): + ts = sorted(unittest.makeSuite(self.NamedTests), + key=lambda t: t.id()) + actual = {t._testMethodName: t.shortDescription() for t in ts} + expected = { + 'test_dict_single_interesting': 'case=0', + 'test_dict_something_boring': 'case=1', + 'test_dict_something_interesting': 'case=0', + 'test_mixed_something_boring': '1', + 'test_mixed_something_interesting': 'case=0', + 'test_something_boring': '1', + 'test_something_interesting': '0', + } + for test_name, param_repr in expected.items(): + short_desc = actual[test_name].split('\n') + self.assertIn(test_name, short_desc[0]) + self.assertEqual('{}({})'.format(test_name, param_repr), short_desc[1]) + + def test_load_tuple_named_test(self): + loader = unittest.TestLoader() + ts = list(loader.loadTestsFromName('NamedTests.test_something_interesting', + module=self)) + self.assertLen(ts, 1) + self.assertEndsWith(ts[0].id(), '.test_something_interesting') + + def test_load_dict_named_test(self): + loader = unittest.TestLoader() + ts = list( + loader.loadTestsFromName( + 'NamedTests.test_dict_something_interesting', module=self)) + self.assertLen(ts, 1) + self.assertEndsWith(ts[0].id(), '.test_dict_something_interesting') + + def test_load_mixed_named_test(self): + loader = unittest.TestLoader() + ts = list( + loader.loadTestsFromName( + 'NamedTests.test_mixed_something_interesting', module=self)) + self.assertLen(ts, 1) + self.assertEndsWith(ts[0].id(), '.test_mixed_something_interesting') + + def test_duplicate_named_test_fails(self): + with self.assertRaises(parameterized.DuplicateTestNameError): + + class _(parameterized.TestCase): + + @parameterized.named_parameters( + ('Interesting', 0), + ('Interesting', 1), + ) + def test_something(self, unused_obj): + pass + + def test_duplicate_dict_named_test_fails(self): + with self.assertRaises(parameterized.DuplicateTestNameError): + + class _(parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': 'Interesting', 'unused_obj': 0}, + {'testcase_name': 'Interesting', 'unused_obj': 1}, + ) + def test_dict_something(self, unused_obj): + pass + + def test_duplicate_mixed_named_test_fails(self): + with self.assertRaises(parameterized.DuplicateTestNameError): + + class _(parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': 'Interesting', 'unused_obj': 0}, + ('Interesting', 1), + ) + def test_mixed_something(self, unused_obj): + pass + + def test_named_test_with_no_name_fails(self): + with self.assertRaises(RuntimeError): + + class _(parameterized.TestCase): + + @parameterized.named_parameters( + (0,), + ) + def test_something(self, unused_obj): + pass + + def test_named_test_dict_with_no_name_fails(self): + with self.assertRaises(RuntimeError): + + class _(parameterized.TestCase): + + @parameterized.named_parameters( + {'unused_obj': 0}, + ) + def test_something(self, unused_obj): + pass + + def test_parameterized_test_iter_has_testcases_property(self): + @parameterized.parameters(1, 2, 3, 4, 5, 6) + def test_something(unused_self, unused_obj): # pylint: disable=invalid-name + pass + + expected_testcases = [1, 2, 3, 4, 5, 6] + self.assertTrue(hasattr(test_something, 'testcases')) + self.assertCountEqual(expected_testcases, test_something.testcases) + + def test_chained_decorator(self): + ts = unittest.makeSuite(self.ChainedTests) + self.assertEqual(1, ts.countTestCases()) + test = next(t for t in ts) + self.assertTrue(hasattr(test, 'test_chained_flavor_strawberry_cone_waffle')) + res = unittest.TestResult() + + ts.run(res) + self.assertEqual(1, res.testsRun) + self.assertTrue(res.wasSuccessful()) + + def test_singleton_list_extraction(self): + ts = unittest.makeSuite(self.SingletonListExtraction) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(10, res.testsRun) + self.assertTrue(res.wasSuccessful()) + + def test_singleton_argument_extraction(self): + ts = unittest.makeSuite(self.SingletonArgumentExtraction) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(9, res.testsRun) + self.assertTrue(res.wasSuccessful()) + + def test_singleton_dict_argument(self): + ts = unittest.makeSuite(self.SingletonDictArgument) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(1, res.testsRun) + self.assertTrue(res.wasSuccessful()) + + def test_decorated_bare_class(self): + ts = unittest.makeSuite(self.DecoratedBareClass) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(2, res.testsRun) + self.assertTrue(res.wasSuccessful(), msg=str(res.failures)) + + def test_decorated_class(self): + ts = unittest.makeSuite(self.DecoratedClass) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(4, res.testsRun) + self.assertLen(res.failures, 2) + + def test_generator_decorated_class(self): + ts = unittest.makeSuite(self.GeneratorDecoratedClass) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(32, res.testsRun) + self.assertLen(res.failures, 16) + + def test_no_duplicate_decorations(self): + with self.assertRaises(AssertionError): + + @parameterized.parameters(1, 2, 3, 4) + class _(parameterized.TestCase): + + @parameterized.parameters(5, 6, 7, 8) + def test_something(self, unused_obj): + pass + + def test_double_class_decorations_not_supported(self): + + @parameterized.parameters('foo', 'bar') + class SuperclassWithClassDecorator(parameterized.TestCase): + + def test_name(self, name): + del name + + with self.assertRaises(AssertionError): + + @parameterized.parameters('foo', 'bar') + class SubclassWithClassDecorator(SuperclassWithClassDecorator): + pass + + del SubclassWithClassDecorator + + def test_other_decorator_ordering_unnamed(self): + ts = unittest.makeSuite(self.OtherDecoratorUnnamed) + res = unittest.TestResult() + ts.run(res) + # Two for when the parameterized tests call the skip wrapper. + # One for when the skip wrapper is called first and doesn't iterate. + self.assertEqual(3, res.testsRun) + self.assertFalse(res.wasSuccessful()) + self.assertEmpty(res.failures) + # One error from test_other_then_parameterized. + self.assertLen(res.errors, 1) + + def test_other_decorator_ordering_named(self): + ts = unittest.makeSuite(self.OtherDecoratorNamed) + # Verify it generates the test method names from the original test method. + for test in ts: # There is only one test. + ts_attributes = dir(test) + self.assertIn('test_parameterized_then_other_a', ts_attributes) + self.assertIn('test_parameterized_then_other_b', ts_attributes) + + res = unittest.TestResult() + ts.run(res) + # Two for when the parameterized tests call the skip wrapper. + # One for when the skip wrapper is called first and doesn't iterate. + self.assertEqual(3, res.testsRun) + self.assertFalse(res.wasSuccessful()) + self.assertEmpty(res.failures) + # One error from test_other_then_parameterized. + self.assertLen(res.errors, 1) + + def test_other_decorator_ordering_named_with_dict(self): + ts = unittest.makeSuite(self.OtherDecoratorNamedWithDict) + # Verify it generates the test method names from the original test method. + for test in ts: # There is only one test. + ts_attributes = dir(test) + self.assertIn('test_parameterized_then_other_a', ts_attributes) + self.assertIn('test_parameterized_then_other_b', ts_attributes) + + res = unittest.TestResult() + ts.run(res) + # Two for when the parameterized tests call the skip wrapper. + # One for when the skip wrapper is called first and doesn't iterate. + self.assertEqual(3, res.testsRun) + self.assertFalse(res.wasSuccessful()) + self.assertEmpty(res.failures) + # One error from test_other_then_parameterized. + self.assertLen(res.errors, 1) + + def test_no_test_error_empty_parameters(self): + with self.assertRaises(parameterized.NoTestsError): + + @parameterized.parameters() + def test_something(): + pass + + del test_something + + def test_no_test_error_empty_generator(self): + with self.assertRaises(parameterized.NoTestsError): + + @parameterized.parameters((i for i in [])) + def test_something(): + pass + + del test_something + + def test_unique_descriptive_names(self): + + class RecordSuccessTestsResult(unittest.TestResult): + + def __init__(self, *args, **kwargs): + super(RecordSuccessTestsResult, self).__init__(*args, **kwargs) + self.successful_tests = [] + + def addSuccess(self, test): + self.successful_tests.append(test) + + ts = unittest.makeSuite(self.UniqueDescriptiveNamesTest) + res = RecordSuccessTestsResult() + ts.run(res) + self.assertTrue(res.wasSuccessful()) + self.assertEqual(2, res.testsRun) + test_ids = [test.id() for test in res.successful_tests] + full_class_name = unittest.util.strclass(self.UniqueDescriptiveNamesTest) + expected_test_ids = [ + full_class_name + '.test_normal0 (13)', + full_class_name + '.test_normal1 (13)', + ] + self.assertTrue(test_ids) + self.assertCountEqual(expected_test_ids, test_ids) + + def test_multi_generators(self): + ts = unittest.makeSuite(self.MultiGeneratorsTestCase) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(2, res.testsRun) + self.assertTrue(res.wasSuccessful(), msg=str(res.failures)) + + def test_named_parameters_reusable(self): + ts = unittest.makeSuite(self.NamedParametersReusableTestCase) + res = unittest.TestResult() + ts.run(res) + self.assertEqual(8, res.testsRun) + self.assertTrue(res.wasSuccessful(), msg=str(res.failures)) + + def test_subclass_inherits_superclass_test_params_reprs(self): + self.assertEqual( + {'test_name0': "('foo')", 'test_name1': "('bar')"}, + self.SuperclassTestCase._test_params_reprs) + self.assertEqual( + {'test_name0': "('foo')", 'test_name1': "('bar')"}, + self.SubclassTestCase._test_params_reprs) + + +def _decorate_with_side_effects(func, self): + self.sideeffect = True + func(self) + + +class CoopMetaclassCreationTest(absltest.TestCase): + + class TestBase(absltest.TestCase): + + # This test simulates a metaclass that sets some attribute ('sideeffect') + # on each member of the class that starts with 'test'. The test code then + # checks that this attribute exists when the custom metaclass and + # TestGeneratorMetaclass are combined with cooperative inheritance. + + # The attribute has to be set in the __init__ method of the metaclass, + # since the TestGeneratorMetaclass already overrides __new__. Only one + # base metaclass can override __new__, but all can provide custom __init__ + # methods. + + class __metaclass__(type): # pylint: disable=g-bad-name + + def __init__(cls, name, bases, dct): + type.__init__(cls, name, bases, dct) + for member_name, obj in dct.items(): + if member_name.startswith('test'): + setattr(cls, member_name, + lambda self, f=obj: _decorate_with_side_effects(f, self)) + + class MyParams(parameterized.CoopTestCase(TestBase)): + + @parameterized.parameters( + (1, 2, 3), + (4, 5, 9)) + def test_addition(self, op1, op2, result): + self.assertEqual(result, op1 + op2) + + class MySuite(unittest.TestSuite): + # Under Python 3.4 the TestCases in the suite's list of tests to run are + # destroyed and replaced with None after successful execution by default. + # This disables that behavior. + _cleanup = False + + def test_successful_execution(self): + ts = unittest.makeSuite(self.MyParams) + + res = unittest.TestResult() + ts.run(res) + self.assertEqual(2, res.testsRun) + self.assertTrue(res.wasSuccessful()) + + def test_metaclass_side_effects(self): + ts = unittest.makeSuite(self.MyParams, suiteClass=self.MySuite) + + res = unittest.TestResult() + ts.run(res) + self.assertTrue(list(ts)[0].sideeffect) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/xml_reporter_helper_test.py b/absl/testing/tests/xml_reporter_helper_test.py new file mode 100644 index 0000000..661bbdc --- /dev/null +++ b/absl/testing/tests/xml_reporter_helper_test.py @@ -0,0 +1,97 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random + +from absl import flags +from absl.testing import absltest + + +FLAGS = flags.FLAGS +flags.DEFINE_boolean('set_up_module_error', False, + 'Cause setupModule to error.') +flags.DEFINE_boolean('tear_down_module_error', False, + 'Cause tearDownModule to error.') + +flags.DEFINE_boolean('set_up_class_error', False, 'Cause setUpClass to error.') +flags.DEFINE_boolean('tear_down_class_error', False, + 'Cause tearDownClass to error.') + +flags.DEFINE_boolean('set_up_error', False, 'Cause setUp to error.') +flags.DEFINE_boolean('tear_down_error', False, 'Cause tearDown to error.') +flags.DEFINE_boolean('test_error', False, 'Cause the test to error.') + +flags.DEFINE_boolean('set_up_fail', False, 'Cause setUp to fail.') +flags.DEFINE_boolean('tear_down_fail', False, 'Cause tearDown to fail.') +flags.DEFINE_boolean('test_fail', False, 'Cause the test to fail.') + +flags.DEFINE_float('random_error', 0.0, + '0 - 1.0: fraction of a random failure at any step', + lower_bound=0.0, upper_bound=1.0) + + +def _random_error(): + return random.random() < FLAGS.random_error + + +def setUpModule(): + if FLAGS.set_up_module_error or _random_error(): + raise Exception('setUpModule Errored!') + + +def tearDownModule(): + if FLAGS.tear_down_module_error or _random_error(): + raise Exception('tearDownModule Errored!') + + +class FailableTest(absltest.TestCase): + + @classmethod + def setUpClass(cls): + if FLAGS.set_up_class_error or _random_error(): + raise Exception('setUpClass Errored!') + + @classmethod + def tearDownClass(cls): + if FLAGS.tear_down_class_error or _random_error(): + raise Exception('tearDownClass Errored!') + + def setUp(self): + if FLAGS.set_up_error or _random_error(): + raise Exception('setUp Errored!') + + if FLAGS.set_up_fail: + self.fail('setUp Failed!') + + def tearDown(self): + if FLAGS.tear_down_error or _random_error(): + raise Exception('tearDown Errored!') + + if FLAGS.tear_down_fail: + self.fail('tearDown Failed!') + + def test(self): + if FLAGS.test_error or _random_error(): + raise Exception('test Errored!') + + if FLAGS.test_fail: + self.fail('test Failed!') + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/tests/xml_reporter_test.py b/absl/testing/tests/xml_reporter_test.py new file mode 100644 index 0000000..0261f64 --- /dev/null +++ b/absl/testing/tests/xml_reporter_test.py @@ -0,0 +1,1108 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import io +import os +import re +import subprocess +import sys +import tempfile +import threading +import time +import unittest +from unittest import mock +from xml.etree import ElementTree +from xml.parsers import expat + +from absl import logging +from absl.testing import _bazelize_command +from absl.testing import absltest +from absl.testing import parameterized +from absl.testing import xml_reporter + + +class StringIOWriteLn(io.StringIO): + + def writeln(self, line): + self.write(line + '\n') + + +class MockTest(absltest.TestCase): + failureException = AssertionError + + def __init__(self, name): + super(MockTest, self).__init__() + self.name = name + + def id(self): + return self.name + + def runTest(self): + return + + def shortDescription(self): + return "This is this test's description." + + +# str(exception_type) is different between Python 2 and 3. +def xml_escaped_exception_type(exception_type): + return xml_reporter._escape_xml_attr(str(exception_type)) + + +OUTPUT_STRING = '\n'.join([ + r'<\?xml version="1.0"\?>', + ('<testsuites name="" tests="%(tests)d" failures="%(failures)d"' + ' errors="%(errors)d" time="%(run_time).1f" timestamp="%(start_time)s">'), + ('<testsuite name="%(suite_name)s" tests="%(tests)d"' + ' failures="%(failures)d" errors="%(errors)d" time="%(run_time).1f"' + ' timestamp="%(start_time)s">'), + (' <testcase name="%(test_name)s" status="%(status)s" result="%(result)s"' + ' time="%(run_time).1f" classname="%(classname)s"' + ' timestamp="%(start_time)s">%(message)s'), + ' </testcase>', '</testsuite>', + '</testsuites>', +]) + +FAILURE_MESSAGE = r""" + <failure message="e" type="{}"><!\[CDATA\[Traceback \(most recent call last\): + File ".*xml_reporter_test\.py", line \d+, in get_sample_failure + self.fail\(\'e\'\) +AssertionError: e +\]\]></failure>""".format(xml_escaped_exception_type(AssertionError)) + +ERROR_MESSAGE = r""" + <error message="invalid literal for int\(\) with base 10: (')?a(')?" type="{}"><!\[CDATA\[Traceback \(most recent call last\): + File ".*xml_reporter_test\.py", line \d+, in get_sample_error + int\('a'\) +ValueError: invalid literal for int\(\) with base 10: '?a'? +\]\]></error>""".format(xml_escaped_exception_type(ValueError)) + +UNICODE_MESSAGE = r""" + <%s message="{0}" type="{1}"><!\[CDATA\[Traceback \(most recent call last\): + File ".*xml_reporter_test\.py", line \d+, in get_unicode_sample_failure + raise AssertionError\(u'\\xe9'\) +AssertionError: {0} +\]\]></%s>""".format( + r'\xe9', + xml_escaped_exception_type(AssertionError)) + +NEWLINE_MESSAGE = r""" + <%s message="{0}" type="{1}"><!\[CDATA\[Traceback \(most recent call last\): + File ".*xml_reporter_test\.py", line \d+, in get_newline_message_sample_failure + raise AssertionError\(\'{2}'\) +AssertionError: {3} +\]\]></%s>""".format( + 'new
line', + xml_escaped_exception_type(AssertionError), + r'new\\nline', + 'new\nline') + +UNEXPECTED_SUCCESS_MESSAGE = '\n'.join([ + '', + (r' <error message="" type=""><!\[CDATA\[Test case ' + r'__main__.MockTest.unexpectedly_passing_test should have failed, ' + r'but passed.\]\]></error>'), +]) + +UNICODE_ERROR_MESSAGE = UNICODE_MESSAGE % ('error', 'error') +NEWLINE_ERROR_MESSAGE = NEWLINE_MESSAGE % ('error', 'error') + + +class TextAndXMLTestResultTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.stream = StringIOWriteLn() + self.xml_stream = io.StringIO() + + def _make_result(self, times): + timer = mock.Mock() + timer.side_effect = times + return xml_reporter._TextAndXMLTestResult(self.xml_stream, self.stream, + 'foo', 0, timer) + + def _assert_match(self, regex, output): + fail_msg = 'Expected regex:\n{}\nTo match:\n{}'.format(regex, output) + self.assertRegex(output, regex, fail_msg) + + def _assert_valid_xml(self, xml_output): + try: + expat.ParserCreate().Parse(xml_output) + except expat.ExpatError as e: + raise AssertionError('Bad XML output: {}\n{}'.format(e, xml_output)) + + def _simulate_error_test(self, test, result): + result.startTest(test) + result.addError(test, self.get_sample_error()) + result.stopTest(test) + + def _simulate_failing_test(self, test, result): + result.startTest(test) + result.addFailure(test, self.get_sample_failure()) + result.stopTest(test) + + def _simulate_passing_test(self, test, result): + result.startTest(test) + result.addSuccess(test) + result.stopTest(test) + + def _iso_timestamp(self, timestamp): + return datetime.datetime.utcfromtimestamp(timestamp).isoformat() + '+00:00' + + def test_with_passing_test(self): + start_time = 0 + end_time = 2 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.passing_test') + result.startTestRun() + result.startTest(test) + result.addSuccess(test) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 0, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'passing_test', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': '' + } + self._assert_match(expected_re, self.xml_stream.getvalue()) + + def test_with_passing_subtest(self): + start_time = 0 + end_time = 2 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.passing_test') + subtest = unittest.case._SubTest(test, 'msg', None) + result.startTestRun() + result.startTest(test) + result.addSubTest(test, subtest, None) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 0, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': r'passing_test \[msg\]', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': '' + } + self._assert_match(expected_re, self.xml_stream.getvalue()) + + def test_with_passing_subtest_with_dots_in_parameter_name(self): + start_time = 0 + end_time = 2 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.passing_test') + subtest = unittest.case._SubTest(test, 'msg', {'case': 'a.b.c'}) + result.startTestRun() + result.startTest(test) + result.addSubTest(test, subtest, None) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': + 'MockTest', + 'tests': + 1, + 'failures': + 0, + 'errors': + 0, + 'run_time': + run_time, + 'start_time': + re.escape(self._iso_timestamp(start_time),), + 'test_name': + r'passing_test \[msg\] \(case='a.b.c'\)', + 'classname': + '__main__.MockTest', + 'status': + 'run', + 'result': + 'completed', + 'attributes': + '', + 'message': + '' + } + self._assert_match(expected_re, self.xml_stream.getvalue()) + + def get_sample_error(self): + try: + int('a') + except ValueError: + error_values = sys.exc_info() + return error_values + + def get_sample_failure(self): + try: + self.fail('e') + except AssertionError: + error_values = sys.exc_info() + return error_values + + def get_newline_message_sample_failure(self): + try: + raise AssertionError('new\nline') + except AssertionError: + error_values = sys.exc_info() + return error_values + + def get_unicode_sample_failure(self): + try: + raise AssertionError(u'\xe9') + except AssertionError: + error_values = sys.exc_info() + return error_values + + def get_terminal_escape_sample_failure(self): + try: + raise AssertionError('\x1b') + except AssertionError: + error_values = sys.exc_info() + return error_values + + def test_with_failing_test(self): + start_time = 10 + end_time = 20 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.failing_test') + result.startTestRun() + result.startTest(test) + result.addFailure(test, self.get_sample_failure()) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 1, + 'errors': 0, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'failing_test', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': FAILURE_MESSAGE + } + self._assert_match(expected_re, self.xml_stream.getvalue()) + + def test_with_failing_subtest(self): + start_time = 10 + end_time = 20 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.failing_test') + subtest = unittest.case._SubTest(test, 'msg', None) + result.startTestRun() + result.startTest(test) + result.addSubTest(test, subtest, self.get_sample_failure()) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 1, + 'errors': 0, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': r'failing_test \[msg\]', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': FAILURE_MESSAGE + } + self._assert_match(expected_re, self.xml_stream.getvalue()) + + def test_with_error_test(self): + start_time = 100 + end_time = 200 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.failing_test') + result.startTestRun() + result.startTest(test) + result.addError(test, self.get_sample_error()) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + xml = self.xml_stream.getvalue() + + self._assert_valid_xml(xml) + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 1, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'failing_test', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': ERROR_MESSAGE + } + self._assert_match(expected_re, xml) + + def test_with_error_subtest(self): + start_time = 10 + end_time = 20 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.error_test') + subtest = unittest.case._SubTest(test, 'msg', None) + result.startTestRun() + result.startTest(test) + result.addSubTest(test, subtest, self.get_sample_error()) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 1, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': r'error_test \[msg\]', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': ERROR_MESSAGE + } + self._assert_match(expected_re, self.xml_stream.getvalue()) + + def test_with_fail_and_error_test(self): + """Tests a failure and subsequent error within a single result.""" + start_time = 123 + end_time = 456 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.failing_test') + result.startTestRun() + result.startTest(test) + result.addFailure(test, self.get_sample_failure()) + # This could happen in tearDown + result.addError(test, self.get_sample_error()) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + xml = self.xml_stream.getvalue() + + self._assert_valid_xml(xml) + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 1, # Only the failure is tallied (because it was first). + 'errors': 0, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'failing_test', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + # Messages from failure and error should be concatenated in order. + 'message': FAILURE_MESSAGE + ERROR_MESSAGE + } + self._assert_match(expected_re, xml) + + def test_with_error_and_fail_test(self): + """Tests an error and subsequent failure within a single result.""" + start_time = 123 + end_time = 456 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.failing_test') + result.startTestRun() + result.startTest(test) + result.addError(test, self.get_sample_error()) + result.addFailure(test, self.get_sample_failure()) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + xml = self.xml_stream.getvalue() + + self._assert_valid_xml(xml) + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 1, # Only the error is tallied (because it was first). + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'failing_test', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + # Messages from error and failure should be concatenated in order. + 'message': ERROR_MESSAGE + FAILURE_MESSAGE + } + self._assert_match(expected_re, xml) + + def test_with_newline_error_test(self): + start_time = 100 + end_time = 200 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.failing_test') + result.startTestRun() + result.startTest(test) + result.addError(test, self.get_newline_message_sample_failure()) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + xml = self.xml_stream.getvalue() + + self._assert_valid_xml(xml) + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 1, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'failing_test', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': NEWLINE_ERROR_MESSAGE + } + '\n' + self._assert_match(expected_re, xml) + + def test_with_unicode_error_test(self): + start_time = 100 + end_time = 200 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.failing_test') + result.startTestRun() + result.startTest(test) + result.addError(test, self.get_unicode_sample_failure()) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + xml = self.xml_stream.getvalue() + + self._assert_valid_xml(xml) + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 1, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'failing_test', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': UNICODE_ERROR_MESSAGE + } + self._assert_match(expected_re, xml) + + def test_with_terminal_escape_error(self): + start_time = 100 + end_time = 200 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.failing_test') + result.startTestRun() + result.startTest(test) + result.addError(test, self.get_terminal_escape_sample_failure()) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + + self._assert_valid_xml(self.xml_stream.getvalue()) + + def test_with_expected_failure_test(self): + start_time = 100 + end_time = 200 + result = self._make_result((start_time, start_time, end_time, end_time)) + error_values = '' + + try: + raise RuntimeError('Test expectedFailure') + except RuntimeError: + error_values = sys.exc_info() + + test = MockTest('__main__.MockTest.expected_failing_test') + result.startTestRun() + result.startTest(test) + result.addExpectedFailure(test, error_values) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 0, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'expected_failing_test', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': '' + } + self._assert_match(re.compile(expected_re, re.DOTALL), + self.xml_stream.getvalue()) + + def test_with_unexpected_success_error_test(self): + start_time = 100 + end_time = 200 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.unexpectedly_passing_test') + result.startTestRun() + result.startTest(test) + result.addUnexpectedSuccess(test) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 1, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'unexpectedly_passing_test', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': UNEXPECTED_SUCCESS_MESSAGE + } + self._assert_match(expected_re, self.xml_stream.getvalue()) + + def test_with_skipped_test(self): + start_time = 100 + end_time = 100 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.skipped_test_with_reason') + result.startTestRun() + result.startTest(test) + result.addSkip(test, 'b"r') + result.stopTest(test) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 0, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'skipped_test_with_reason', + 'classname': '__main__.MockTest', + 'status': 'notrun', + 'result': 'suppressed', + 'message': '' + } + self._assert_match(expected_re, self.xml_stream.getvalue()) + + def test_suite_time(self): + start_time1 = 100 + end_time1 = 200 + start_time2 = 400 + end_time2 = 700 + name = '__main__.MockTest.failing_test' + result = self._make_result((start_time1, start_time1, end_time1, + start_time2, end_time2, end_time2)) + + test = MockTest('%s1' % name) + result.startTestRun() + result.startTest(test) + result.addSuccess(test) + result.stopTest(test) + + test = MockTest('%s2' % name) + result.startTest(test) + result.addSuccess(test) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + + run_time = max(end_time1, end_time2) - min(start_time1, start_time2) + timestamp = self._iso_timestamp(start_time1) + expected_prefix = """<?xml version="1.0"?> +<testsuites name="" tests="2" failures="0" errors="0" time="%.1f" timestamp="%s"> +<testsuite name="MockTest" tests="2" failures="0" errors="0" time="%.1f" timestamp="%s"> +""" % (run_time, timestamp, run_time, timestamp) + xml_output = self.xml_stream.getvalue() + self.assertTrue( + xml_output.startswith(expected_prefix), + '%s not found in %s' % (expected_prefix, xml_output)) + + def test_with_no_suite_name(self): + start_time = 1000 + end_time = 1200 + result = self._make_result((start_time, start_time, end_time, end_time)) + + test = MockTest('__main__.MockTest.bad_name') + result.startTestRun() + result.startTest(test) + result.addSuccess(test) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + expected_re = OUTPUT_STRING % { + 'suite_name': 'MockTest', + 'tests': 1, + 'failures': 0, + 'errors': 0, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': 'bad_name', + 'classname': '__main__.MockTest', + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': '' + } + self._assert_match(expected_re, self.xml_stream.getvalue()) + + def test_unnamed_parameterized_testcase(self): + """Test unnamed parameterized test cases. + + Unnamed parameterized test cases might have non-alphanumeric characters in + their test method names. This test ensures xml_reporter handles them + correctly. + """ + + class ParameterizedTest(parameterized.TestCase): + + @parameterized.parameters(('a (b.c)',)) + def test_prefix(self, case): + self.assertTrue(case.startswith('a')) + + start_time = 1000 + end_time = 1200 + result = self._make_result((start_time, start_time, end_time, end_time)) + test = ParameterizedTest(methodName='test_prefix0') + result.startTestRun() + result.startTest(test) + result.addSuccess(test) + result.stopTest(test) + result.stopTestRun() + result.printErrors() + + run_time = end_time - start_time + classname = xml_reporter._escape_xml_attr( + unittest.util.strclass(test.__class__)) + expected_re = OUTPUT_STRING % { + 'suite_name': 'ParameterizedTest', + 'tests': 1, + 'failures': 0, + 'errors': 0, + 'run_time': run_time, + 'start_time': re.escape(self._iso_timestamp(start_time),), + 'test_name': re.escape('test_prefix0 ('a (b.c)')'), + 'classname': classname, + 'status': 'run', + 'result': 'completed', + 'attributes': '', + 'message': '' + } + self._assert_match(expected_re, self.xml_stream.getvalue()) + + def teststop_test_without_pending_test(self): + end_time = 1200 + result = self._make_result((end_time,)) + + test = MockTest('__main__.MockTest.bad_name') + result.stopTest(test) + result.stopTestRun() + # Just verify that this doesn't crash + + def test_text_and_xmltest_runner(self): + runner = xml_reporter.TextAndXMLTestRunner(self.xml_stream, self.stream, + 'foo', 1) + result1 = runner._makeResult() + result2 = xml_reporter._TextAndXMLTestResult(None, None, None, 0, None) + self.failUnless(type(result1) is type(result2)) + + def test_timing_with_time_stub(self): + """Make sure that timing is correct even if time.time is stubbed out.""" + try: + saved_time = time.time + time.time = lambda: -1 + reporter = xml_reporter._TextAndXMLTestResult(self.xml_stream, + self.stream, + 'foo', 0) + test = MockTest('bar') + reporter.startTest(test) + self.failIf(reporter.start_time == -1) + finally: + time.time = saved_time + + def test_concurrent_add_and_delete_pending_test_case_result(self): + """Make sure adding/deleting pending test case results are thread safe.""" + result = xml_reporter._TextAndXMLTestResult(None, self.stream, None, 0, + None) + def add_and_delete_pending_test_case_result(test_name): + test = MockTest(test_name) + result.addSuccess(test) + result.delete_pending_test_case_result(test) + + for i in range(50): + add_and_delete_pending_test_case_result('add_and_delete_test%s' % i) + self.assertEqual(result.pending_test_case_results, {}) + + def test_concurrent_test_runs(self): + """Make sure concurrent test runs do not race each other.""" + num_passing_tests = 20 + num_failing_tests = 20 + num_error_tests = 20 + total_num_tests = num_passing_tests + num_failing_tests + num_error_tests + + times = [0] + [i for i in range(2 * total_num_tests) + ] + [2 * total_num_tests - 1] + result = self._make_result(times) + threads = [] + names = [] + result.startTestRun() + for i in range(num_passing_tests): + name = 'passing_concurrent_test_%s' % i + names.append(name) + test_name = '__main__.MockTest.%s' % name + # xml_reporter uses id(test) as the test identifier. + # In a real testing scenario, all the test instances are created before + # running them. So all ids will be unique. + # We must do the same here: create test instance beforehand. + test = MockTest(test_name) + threads.append(threading.Thread( + target=self._simulate_passing_test, args=(test, result))) + for i in range(num_failing_tests): + name = 'failing_concurrent_test_%s' % i + names.append(name) + test_name = '__main__.MockTest.%s' % name + test = MockTest(test_name) + threads.append(threading.Thread( + target=self._simulate_failing_test, args=(test, result))) + for i in range(num_error_tests): + name = 'error_concurrent_test_%s' % i + names.append(name) + test_name = '__main__.MockTest.%s' % name + test = MockTest(test_name) + threads.append(threading.Thread( + target=self._simulate_error_test, args=(test, result))) + for t in threads: + t.start() + for t in threads: + t.join() + + result.stopTestRun() + result.printErrors() + tests_not_in_xml = [] + for tn in names: + if tn not in self.xml_stream.getvalue(): + tests_not_in_xml.append(tn) + msg = ('Expected xml_stream to contain all test %s results, but %s tests ' + 'are missing. List of missing tests: %s' % ( + total_num_tests, len(tests_not_in_xml), tests_not_in_xml)) + self.assertEqual([], tests_not_in_xml, msg) + + def test_add_failure_during_stop_test(self): + """Tests an addFailure() call from within a stopTest() call stack.""" + result = self._make_result((0, 2)) + test = MockTest('__main__.MockTest.failing_test') + result.startTestRun() + result.startTest(test) + + # Replace parent stopTest method from unittest.TextTestResult with + # a version that calls self.addFailure(). + with mock.patch.object( + unittest.TextTestResult, + 'stopTest', + side_effect=lambda t: result.addFailure(t, self.get_sample_failure())): + # Run stopTest in a separate thread since we are looking to verify that + # it does not deadlock, and would otherwise prevent the test from + # completing. + stop_test_thread = threading.Thread(target=result.stopTest, args=(test,)) + stop_test_thread.daemon = True + stop_test_thread.start() + + stop_test_thread.join(10.0) + self.assertFalse(stop_test_thread.is_alive(), + 'result.stopTest(test) call failed to complete') + + +class XMLTest(absltest.TestCase): + + def test_escape_xml(self): + self.assertEqual(xml_reporter._escape_xml_attr('"Hi" <\'>\t\r\n'), + '"Hi" <'>	
') + + +class XmlReporterFixtureTest(absltest.TestCase): + + def _get_helper(self): + binary_name = 'absl/testing/tests/xml_reporter_helper_test' + return _bazelize_command.get_executable_path(binary_name) + + def _run_test_and_get_xml(self, flag): + """Runs xml_reporter_helper_test and returns an Element instance. + + Runs xml_reporter_helper_test in a new process so that it can + exercise the entire test infrastructure, and easily test issues in + the test fixture. + + Args: + flag: flag to pass to xml_reporter_helper_test + + Returns: + The Element instance of the XML output. + """ + + xml_fhandle, xml_fname = tempfile.mkstemp() + os.close(xml_fhandle) + + try: + binary = self._get_helper() + args = [binary, flag, '--xml_output_file=%s' % xml_fname] + ret = subprocess.call(args) + self.assertEqual(ret, 0) + + xml = ElementTree.parse(xml_fname).getroot() + finally: + os.remove(xml_fname) + + return xml + + def _run_test(self, flag, num_errors, num_failures, suites): + xml_fhandle, xml_fname = tempfile.mkstemp() + os.close(xml_fhandle) + + try: + binary = self._get_helper() + args = [binary, flag, '--xml_output_file=%s' % xml_fname] + ret = subprocess.call(args) + self.assertNotEqual(ret, 0) + + xml = ElementTree.parse(xml_fname).getroot() + logging.info('xml output is:\n%s', ElementTree.tostring(xml)) + finally: + os.remove(xml_fname) + + self.assertEqual(int(xml.attrib['errors']), num_errors) + self.assertEqual(int(xml.attrib['failures']), num_failures) + self.assertLen(xml, len(suites)) + actual_suites = sorted( + xml.findall('testsuite'), key=lambda x: x.attrib['name']) + suites = sorted(suites, key=lambda x: x['name']) + for actual_suite, expected_suite in zip(actual_suites, suites): + self.assertEqual(actual_suite.attrib['name'], expected_suite['name']) + self.assertLen(actual_suite, len(expected_suite['cases'])) + actual_cases = sorted(actual_suite.findall('testcase'), + key=lambda x: x.attrib['name']) + expected_cases = sorted(expected_suite['cases'], key=lambda x: x['name']) + for actual_case, expected_case in zip(actual_cases, expected_cases): + self.assertEqual(actual_case.attrib['name'], expected_case['name']) + self.assertEqual(actual_case.attrib['classname'], + expected_case['classname']) + if 'error' in expected_case: + actual_error = actual_case.find('error') + self.assertEqual(actual_error.attrib['message'], + expected_case['error']) + if 'failure' in expected_case: + actual_failure = actual_case.find('failure') + self.assertEqual(actual_failure.attrib['message'], + expected_case['failure']) + + return xml + + def test_set_up_module_error(self): + self._run_test( + flag='--set_up_module_error', + num_errors=1, + num_failures=0, + suites=[{'name': '__main__', + 'cases': [{'name': 'setUpModule', + 'classname': '__main__', + 'error': 'setUpModule Errored!'}]}]) + + def test_tear_down_module_error(self): + self._run_test( + flag='--tear_down_module_error', + num_errors=1, + num_failures=0, + suites=[{'name': 'FailableTest', + 'cases': [{'name': 'test', + 'classname': '__main__.FailableTest'}]}, + {'name': '__main__', + 'cases': [{'name': 'tearDownModule', + 'classname': '__main__', + 'error': 'tearDownModule Errored!'}]}]) + + def test_set_up_class_error(self): + self._run_test( + flag='--set_up_class_error', + num_errors=1, + num_failures=0, + suites=[{'name': 'FailableTest', + 'cases': [{'name': 'setUpClass', + 'classname': '__main__.FailableTest', + 'error': 'setUpClass Errored!'}]}]) + + def test_tear_down_class_error(self): + self._run_test( + flag='--tear_down_class_error', + num_errors=1, + num_failures=0, + suites=[{'name': 'FailableTest', + 'cases': [{'name': 'test', + 'classname': '__main__.FailableTest'}, + {'name': 'tearDownClass', + 'classname': '__main__.FailableTest', + 'error': 'tearDownClass Errored!'}]}]) + + def test_set_up_error(self): + self._run_test( + flag='--set_up_error', + num_errors=1, + num_failures=0, + suites=[{'name': 'FailableTest', + 'cases': [{'name': 'test', + 'classname': '__main__.FailableTest', + 'error': 'setUp Errored!'}]}]) + + def test_tear_down_error(self): + self._run_test( + flag='--tear_down_error', + num_errors=1, + num_failures=0, + suites=[{'name': 'FailableTest', + 'cases': [{'name': 'test', + 'classname': '__main__.FailableTest', + 'error': 'tearDown Errored!'}]}]) + + def test_test_error(self): + self._run_test( + flag='--test_error', + num_errors=1, + num_failures=0, + suites=[{'name': 'FailableTest', + 'cases': [{'name': 'test', + 'classname': '__main__.FailableTest', + 'error': 'test Errored!'}]}]) + + def test_set_up_failure(self): + self._run_test( + flag='--set_up_fail', + num_errors=0, + num_failures=1, + suites=[{'name': 'FailableTest', + 'cases': [{'name': 'test', + 'classname': '__main__.FailableTest', + 'failure': 'setUp Failed!'}]}]) + + def test_tear_down_failure(self): + self._run_test( + flag='--tear_down_fail', + num_errors=0, + num_failures=1, + suites=[{'name': 'FailableTest', + 'cases': [{'name': 'test', + 'classname': '__main__.FailableTest', + 'failure': 'tearDown Failed!'}]}]) + + def test_test_fail(self): + self._run_test( + flag='--test_fail', + num_errors=0, + num_failures=1, + suites=[{'name': 'FailableTest', + 'cases': [{'name': 'test', + 'classname': '__main__.FailableTest', + 'failure': 'test Failed!'}]}]) + + def test_test_randomization_seed_logging(self): + # We expect the resulting XML to start as follows: + # <testsuites ...> + # <properties> + # <property name="test_randomize_ordering_seed" value="17" /> + # ... + # + # which we validate here. + out = self._run_test_and_get_xml('--test_randomize_ordering_seed=17') + expected_attrib = {'name': 'test_randomize_ordering_seed', 'value': '17'} + property_attributes = [ + prop.attrib for prop in out.findall('./properties/property')] + self.assertIn(expected_attrib, property_attributes) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/testing/xml_reporter.py b/absl/testing/xml_reporter.py new file mode 100644 index 0000000..da56e39 --- /dev/null +++ b/absl/testing/xml_reporter.py @@ -0,0 +1,562 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Python test reporter that generates test reports in JUnit XML format.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +import re +import sys +import threading +import time +import traceback +import unittest +from xml.sax import saxutils +from absl.testing import _pretty_print_reporter + + +# See http://www.w3.org/TR/REC-xml/#NT-Char +_bad_control_character_codes = set(range(0, 0x20)) - {0x9, 0xA, 0xD} + + +_control_character_conversions = { + chr(i): '\\x{:02x}'.format(i) for i in _bad_control_character_codes} + + +_escape_xml_attr_conversions = { + '"': '"', + "'": ''', + '\n': '
', + '\t': '	', + '\r': '
', + ' ': ' '} +_escape_xml_attr_conversions.update(_control_character_conversions) + + +# When class or module level function fails, unittest/suite.py adds a +# _ErrorHolder instance instead of a real TestCase, and it has a description +# like "setUpClass (__main__.MyTestCase)". +_CLASS_OR_MODULE_LEVEL_TEST_DESC_REGEX = re.compile(r'^(\w+) \((\S+)\)$') + + +# NOTE: while saxutils.quoteattr() theoretically does the same thing; it +# seems to often end up being too smart for it's own good not escaping properly. +# This function is much more reliable. +def _escape_xml_attr(content): + """Escapes xml attributes.""" + # Note: saxutils doesn't escape the quotes. + return saxutils.escape(content, _escape_xml_attr_conversions) + + +def _escape_cdata(s): + """Escapes a string to be used as XML CDATA. + + CDATA characters are treated strictly as character data, not as XML markup, + but there are still certain restrictions on them. + + Args: + s: the string to be escaped. + Returns: + An escaped version of the input string. + """ + for char, escaped in _control_character_conversions.items(): + s = s.replace(char, escaped) + return s.replace(']]>', ']] >') + + +def _iso8601_timestamp(timestamp): + """Produces an ISO8601 datetime. + + Args: + timestamp: an Epoch based timestamp in seconds. + + Returns: + A iso8601 format timestamp if the input is a valid timestamp, None otherwise + """ + if timestamp is None or timestamp < 0: + return None + return datetime.datetime.fromtimestamp( + timestamp, tz=datetime.timezone.utc).isoformat() + + +def _print_xml_element_header(element, attributes, stream, indentation=''): + """Prints an XML header of an arbitrary element. + + Args: + element: element name (testsuites, testsuite, testcase) + attributes: 2-tuple list with (attributes, values) already escaped + stream: output stream to write test report XML to + indentation: indentation added to the element header + """ + stream.write('%s<%s' % (indentation, element)) + for attribute in attributes: + if (len(attribute) == 2 and attribute[0] is not None and + attribute[1] is not None): + stream.write(' %s="%s"' % (attribute[0], attribute[1])) + stream.write('>\n') + +# Copy time.time which ensures the real time is used internally. +# This prevents bad interactions with tests that stub out time. +_time_copy = time.time + +if hasattr(traceback, '_some_str'): + # Use the traceback module str function to format safely. + _safe_str = traceback._some_str +else: + _safe_str = str # pylint: disable=invalid-name + + +class _TestCaseResult(object): + """Private helper for _TextAndXMLTestResult that represents a test result. + + Attributes: + test: A TestCase instance of an individual test method. + name: The name of the individual test method. + full_class_name: The full name of the test class. + run_time: The duration (in seconds) it took to run the test. + start_time: Epoch relative timestamp of when test started (in seconds) + errors: A list of error 4-tuples. Error tuple entries are + 1) a string identifier of either "failure" or "error" + 2) an exception_type + 3) an exception_message + 4) a string version of a sys.exc_info()-style tuple of values + ('error', err[0], err[1], self._exc_info_to_string(err)) + If the length of errors is 0, then the test is either passed or + skipped. + skip_reason: A string explaining why the test was skipped. + """ + + def __init__(self, test): + self.run_time = -1 + self.start_time = -1 + self.skip_reason = None + self.errors = [] + self.test = test + + # Parse the test id to get its test name and full class path. + # Unfortunately there is no better way of knowning the test and class. + # Worse, unittest uses _ErrorHandler instances to represent class / module + # level failures. + test_desc = test.id() or str(test) + # Check if it's something like "setUpClass (__main__.TestCase)". + match = _CLASS_OR_MODULE_LEVEL_TEST_DESC_REGEX.match(test_desc) + if match: + name = match.group(1) + full_class_name = match.group(2) + else: + class_name = unittest.util.strclass(test.__class__) + if isinstance(test, unittest.case._SubTest): + # If the test case is a _SubTest, the real TestCase instance is + # available as _SubTest.test_case. + class_name = unittest.util.strclass(test.test_case.__class__) + if test_desc.startswith(class_name + '.'): + # In a typical unittest.TestCase scenario, test.id() returns with + # a class name formatted using unittest.util.strclass. + name = test_desc[len(class_name)+1:] + full_class_name = class_name + else: + # Otherwise make a best effort to guess the test name and full class + # path. + parts = test_desc.rsplit('.', 1) + name = parts[-1] + full_class_name = parts[0] if len(parts) == 2 else '' + self.name = _escape_xml_attr(name) + self.full_class_name = _escape_xml_attr(full_class_name) + + def set_run_time(self, time_in_secs): + self.run_time = time_in_secs + + def set_start_time(self, time_in_secs): + self.start_time = time_in_secs + + def print_xml_summary(self, stream): + """Prints an XML Summary of a TestCase. + + Status and result are populated as per JUnit XML test result reporter. + A test that has been skipped will always have a skip reason, + as every skip method in Python's unittest requires the reason arg to be + passed. + + Args: + stream: output stream to write test report XML to + """ + + if self.skip_reason is None: + status = 'run' + result = 'completed' + else: + status = 'notrun' + result = 'suppressed' + + test_case_attributes = [ + ('name', '%s' % self.name), + ('status', '%s' % status), + ('result', '%s' % result), + ('time', '%.1f' % self.run_time), + ('classname', self.full_class_name), + ('timestamp', _iso8601_timestamp(self.start_time)), + ] + _print_xml_element_header('testcase', test_case_attributes, stream, ' ') + self._print_testcase_details(stream) + stream.write(' </testcase>\n') + + def _print_testcase_details(self, stream): + for error in self.errors: + outcome, exception_type, message, error_msg = error # pylint: disable=unpacking-non-sequence + message = _escape_xml_attr(_safe_str(message)) + exception_type = _escape_xml_attr(str(exception_type)) + error_msg = _escape_cdata(error_msg) + stream.write(' <%s message="%s" type="%s"><![CDATA[%s]]></%s>\n' + % (outcome, message, exception_type, error_msg, outcome)) + + +class _TestSuiteResult(object): + """Private helper for _TextAndXMLTestResult.""" + + def __init__(self): + self.suites = {} + self.failure_counts = {} + self.error_counts = {} + self.overall_start_time = -1 + self.overall_end_time = -1 + self._testsuites_properties = {} + + def add_test_case_result(self, test_case_result): + suite_name = type(test_case_result.test).__name__ + if suite_name == '_ErrorHolder': + # _ErrorHolder is a special case created by unittest for class / module + # level functions. + suite_name = test_case_result.full_class_name.rsplit('.')[-1] + if isinstance(test_case_result.test, unittest.case._SubTest): + # If the test case is a _SubTest, the real TestCase instance is + # available as _SubTest.test_case. + suite_name = type(test_case_result.test.test_case).__name__ + + self._setup_test_suite(suite_name) + self.suites[suite_name].append(test_case_result) + for error in test_case_result.errors: + # Only count the first failure or error so that the sum is equal to the + # total number of *testcases* that have failures or errors. + if error[0] == 'failure': + self.failure_counts[suite_name] += 1 + break + elif error[0] == 'error': + self.error_counts[suite_name] += 1 + break + + def print_xml_summary(self, stream): + overall_test_count = sum(len(x) for x in self.suites.values()) + overall_failures = sum(self.failure_counts.values()) + overall_errors = sum(self.error_counts.values()) + overall_attributes = [ + ('name', ''), + ('tests', '%d' % overall_test_count), + ('failures', '%d' % overall_failures), + ('errors', '%d' % overall_errors), + ('time', '%.1f' % (self.overall_end_time - self.overall_start_time)), + ('timestamp', _iso8601_timestamp(self.overall_start_time)), + ] + _print_xml_element_header('testsuites', overall_attributes, stream) + if self._testsuites_properties: + stream.write(' <properties>\n') + for name, value in sorted(self._testsuites_properties.items()): + stream.write(' <property name="%s" value="%s"></property>\n' % + (_escape_xml_attr(name), _escape_xml_attr(str(value)))) + stream.write(' </properties>\n') + + for suite_name in self.suites: + suite = self.suites[suite_name] + suite_end_time = max(x.start_time + x.run_time for x in suite) + suite_start_time = min(x.start_time for x in suite) + failures = self.failure_counts[suite_name] + errors = self.error_counts[suite_name] + suite_attributes = [ + ('name', '%s' % suite_name), + ('tests', '%d' % len(suite)), + ('failures', '%d' % failures), + ('errors', '%d' % errors), + ('time', '%.1f' % (suite_end_time - suite_start_time)), + ('timestamp', _iso8601_timestamp(suite_start_time)), + ] + _print_xml_element_header('testsuite', suite_attributes, stream) + + for test_case_result in suite: + test_case_result.print_xml_summary(stream) + stream.write('</testsuite>\n') + stream.write('</testsuites>\n') + + def _setup_test_suite(self, suite_name): + """Adds a test suite to the set of suites tracked by this test run. + + Args: + suite_name: string, The name of the test suite being initialized. + """ + if suite_name in self.suites: + return + self.suites[suite_name] = [] + self.failure_counts[suite_name] = 0 + self.error_counts[suite_name] = 0 + + def set_end_time(self, timestamp_in_secs): + """Sets the start timestamp of this test suite. + + Args: + timestamp_in_secs: timestamp in seconds since epoch + """ + self.overall_end_time = timestamp_in_secs + + def set_start_time(self, timestamp_in_secs): + """Sets the end timestamp of this test suite. + + Args: + timestamp_in_secs: timestamp in seconds since epoch + """ + self.overall_start_time = timestamp_in_secs + + +class _TextAndXMLTestResult(_pretty_print_reporter.TextTestResult): + """Private TestResult class that produces both formatted text results and XML. + + Used by TextAndXMLTestRunner. + """ + + _TEST_SUITE_RESULT_CLASS = _TestSuiteResult + _TEST_CASE_RESULT_CLASS = _TestCaseResult + + def __init__(self, xml_stream, stream, descriptions, verbosity, + time_getter=_time_copy, testsuites_properties=None): + super(_TextAndXMLTestResult, self).__init__(stream, descriptions, verbosity) + self.xml_stream = xml_stream + self.pending_test_case_results = {} + self.suite = self._TEST_SUITE_RESULT_CLASS() + if testsuites_properties: + self.suite._testsuites_properties = testsuites_properties + self.time_getter = time_getter + + # This lock guards any mutations on pending_test_case_results. + self._pending_test_case_results_lock = threading.RLock() + + def startTest(self, test): + self.start_time = self.time_getter() + super(_TextAndXMLTestResult, self).startTest(test) + + def stopTest(self, test): + # Grabbing the write lock to avoid conflicting with stopTestRun. + with self._pending_test_case_results_lock: + super(_TextAndXMLTestResult, self).stopTest(test) + result = self.get_pending_test_case_result(test) + if not result: + test_name = test.id() or str(test) + sys.stderr.write('No pending test case: %s\n' % test_name) + return + test_id = id(test) + run_time = self.time_getter() - self.start_time + result.set_run_time(run_time) + result.set_start_time(self.start_time) + self.suite.add_test_case_result(result) + del self.pending_test_case_results[test_id] + + def startTestRun(self): + self.suite.set_start_time(self.time_getter()) + super(_TextAndXMLTestResult, self).startTestRun() + + def stopTestRun(self): + self.suite.set_end_time(self.time_getter()) + # All pending_test_case_results will be added to the suite and removed from + # the pending_test_case_results dictionary. Grabbing the write lock to avoid + # results from being added during this process to avoid duplicating adds or + # accidentally erasing newly appended pending results. + with self._pending_test_case_results_lock: + # Errors in the test fixture (setUpModule, tearDownModule, + # setUpClass, tearDownClass) can leave a pending result which + # never gets added to the suite. The runner calls stopTestRun + # which gives us an opportunity to add these errors for + # reporting here. + for test_id in self.pending_test_case_results: + result = self.pending_test_case_results[test_id] + if hasattr(self, 'start_time'): + run_time = self.suite.overall_end_time - self.start_time + result.set_run_time(run_time) + result.set_start_time(self.start_time) + self.suite.add_test_case_result(result) + self.pending_test_case_results.clear() + + def _exc_info_to_string(self, err, test=None): + """Converts a sys.exc_info()-style tuple of values into a string. + + This method must be overridden because the method signature in + unittest.TestResult changed between Python 2.2 and 2.4. + + Args: + err: A sys.exc_info() tuple of values for an error. + test: The test method. + + Returns: + A formatted exception string. + """ + if test: + return super(_TextAndXMLTestResult, self)._exc_info_to_string(err, test) + return ''.join(traceback.format_exception(*err)) + + def add_pending_test_case_result(self, test, error_summary=None, + skip_reason=None): + """Adds result information to a test case result which may still be running. + + If a result entry for the test already exists, add_pending_test_case_result + will add error summary tuples and/or overwrite skip_reason for the result. + If it does not yet exist, a result entry will be created. + Note that a test result is considered to have been run and passed + only if there are no errors or skip_reason. + + Args: + test: A test method as defined by unittest + error_summary: A 4-tuple with the following entries: + 1) a string identifier of either "failure" or "error" + 2) an exception_type + 3) an exception_message + 4) a string version of a sys.exc_info()-style tuple of values + ('error', err[0], err[1], self._exc_info_to_string(err)) + If the length of errors is 0, then the test is either passed or + skipped. + skip_reason: a string explaining why the test was skipped + """ + with self._pending_test_case_results_lock: + test_id = id(test) + if test_id not in self.pending_test_case_results: + self.pending_test_case_results[test_id] = self._TEST_CASE_RESULT_CLASS( + test) + if error_summary: + self.pending_test_case_results[test_id].errors.append(error_summary) + if skip_reason: + self.pending_test_case_results[test_id].skip_reason = skip_reason + + def delete_pending_test_case_result(self, test): + with self._pending_test_case_results_lock: + test_id = id(test) + del self.pending_test_case_results[test_id] + + def get_pending_test_case_result(self, test): + test_id = id(test) + return self.pending_test_case_results.get(test_id, None) + + def addSuccess(self, test): + super(_TextAndXMLTestResult, self).addSuccess(test) + self.add_pending_test_case_result(test) + + def addError(self, test, err): + super(_TextAndXMLTestResult, self).addError(test, err) + error_summary = ('error', err[0], err[1], + self._exc_info_to_string(err, test=test)) + self.add_pending_test_case_result(test, error_summary=error_summary) + + def addFailure(self, test, err): + super(_TextAndXMLTestResult, self).addFailure(test, err) + error_summary = ('failure', err[0], err[1], + self._exc_info_to_string(err, test=test)) + self.add_pending_test_case_result(test, error_summary=error_summary) + + def addSkip(self, test, reason): + super(_TextAndXMLTestResult, self).addSkip(test, reason) + self.add_pending_test_case_result(test, skip_reason=reason) + + def addExpectedFailure(self, test, err): + super(_TextAndXMLTestResult, self).addExpectedFailure(test, err) + if callable(getattr(test, 'recordProperty', None)): + test.recordProperty('EXPECTED_FAILURE', + self._exc_info_to_string(err, test=test)) + self.add_pending_test_case_result(test) + + def addUnexpectedSuccess(self, test): + super(_TextAndXMLTestResult, self).addUnexpectedSuccess(test) + test_name = test.id() or str(test) + error_summary = ('error', '', '', + 'Test case %s should have failed, but passed.' + % (test_name)) + self.add_pending_test_case_result(test, error_summary=error_summary) + + def addSubTest(self, test, subtest, err): # pylint: disable=invalid-name + super(_TextAndXMLTestResult, self).addSubTest(test, subtest, err) + if err is not None: + if issubclass(err[0], test.failureException): + error_summary = ('failure', err[0], err[1], + self._exc_info_to_string(err, test=test)) + else: + error_summary = ('error', err[0], err[1], + self._exc_info_to_string(err, test=test)) + else: + error_summary = None + self.add_pending_test_case_result(subtest, error_summary=error_summary) + + def printErrors(self): + super(_TextAndXMLTestResult, self).printErrors() + self.xml_stream.write('<?xml version="1.0"?>\n') + self.suite.print_xml_summary(self.xml_stream) + + +class TextAndXMLTestRunner(unittest.TextTestRunner): + """A test runner that produces both formatted text results and XML. + + It prints out the names of tests as they are run, errors as they + occur, and a summary of the results at the end of the test run. + """ + + _TEST_RESULT_CLASS = _TextAndXMLTestResult + + _xml_stream = None + _testsuites_properties = {} + + def __init__(self, xml_stream=None, *args, **kwargs): + """Initialize a TextAndXMLTestRunner. + + Args: + xml_stream: file-like or None; XML-formatted test results are output + via this object's write() method. If None (the default), the + new instance behaves as described in the set_default_xml_stream method + documentation below. + *args: passed unmodified to unittest.TextTestRunner.__init__. + **kwargs: passed unmodified to unittest.TextTestRunner.__init__. + """ + super(TextAndXMLTestRunner, self).__init__(*args, **kwargs) + if xml_stream is not None: + self._xml_stream = xml_stream + # else, do not set self._xml_stream to None -- this allows implicit fallback + # to the class attribute's value. + + @classmethod + def set_default_xml_stream(cls, xml_stream): + """Sets the default XML stream for the class. + + Args: + xml_stream: file-like or None; used for instances when xml_stream is None + or not passed to their constructors. If None is passed, instances + created with xml_stream=None will act as ordinary TextTestRunner + instances; this is the default state before any calls to this method + have been made. + """ + cls._xml_stream = xml_stream + + def _makeResult(self): + if self._xml_stream is None: + return super(TextAndXMLTestRunner, self)._makeResult() + else: + return self._TEST_RESULT_CLASS( + self._xml_stream, self.stream, self.descriptions, self.verbosity, + testsuites_properties=self._testsuites_properties) + + @classmethod + def set_testsuites_property(cls, key, value): + cls._testsuites_properties[key] = value diff --git a/absl/tests/__init__.py b/absl/tests/__init__.py new file mode 100644 index 0000000..a3bd1cd --- /dev/null +++ b/absl/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/absl/tests/app_test.py b/absl/tests/app_test.py new file mode 100644 index 0000000..1d8b764 --- /dev/null +++ b/absl/tests/app_test.py @@ -0,0 +1,359 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for app.py.""" + +import contextlib +import copy +import enum +import io +import os +import re +import subprocess +import sys +import tempfile +from unittest import mock + +from absl import app +from absl import flags +from absl.testing import _bazelize_command +from absl.testing import absltest +from absl.testing import flagsaver +from absl.tests import app_test_helper + + +FLAGS = flags.FLAGS + + +_newline_regex = re.compile('(\r\n)|\r') + + +@contextlib.contextmanager +def patch_main_module_docstring(docstring): + old_doc = sys.modules['__main__'].__doc__ + sys.modules['__main__'].__doc__ = docstring + yield + sys.modules['__main__'].__doc__ = old_doc + + +def _normalize_newlines(s): + return re.sub('(\r\n)|\r', '\n', s) + + +class UnitTests(absltest.TestCase): + + def test_install_exception_handler(self): + with self.assertRaises(TypeError): + app.install_exception_handler(1) + + def test_usage(self): + with mock.patch.object( + sys, 'stderr', new=io.StringIO()) as mock_stderr: + app.usage() + self.assertIn(__doc__, mock_stderr.getvalue()) + # Assert that flags are written to stderr. + self.assertIn('\n --[no]helpfull:', mock_stderr.getvalue()) + + def test_usage_shorthelp(self): + with mock.patch.object( + sys, 'stderr', new=io.StringIO()) as mock_stderr: + app.usage(shorthelp=True) + # Assert that flags are NOT written to stderr. + self.assertNotIn(' --', mock_stderr.getvalue()) + + def test_usage_writeto_stderr(self): + with mock.patch.object( + sys, 'stdout', new=io.StringIO()) as mock_stdout: + app.usage(writeto_stdout=True) + self.assertIn(__doc__, mock_stdout.getvalue()) + + def test_usage_detailed_error(self): + with mock.patch.object( + sys, 'stderr', new=io.StringIO()) as mock_stderr: + app.usage(detailed_error='BAZBAZ') + self.assertIn('BAZBAZ', mock_stderr.getvalue()) + + def test_usage_exitcode(self): + with mock.patch.object(sys, 'stderr', new=sys.stderr): + try: + app.usage(exitcode=2) + self.fail('app.usage(exitcode=1) should raise SystemExit') + except SystemExit as e: + self.assertEqual(2, e.code) + + def test_usage_expands_docstring(self): + with patch_main_module_docstring('Name: %s, %%s'): + with mock.patch.object( + sys, 'stderr', new=io.StringIO()) as mock_stderr: + app.usage() + self.assertIn('Name: {}, %s'.format(sys.argv[0]), + mock_stderr.getvalue()) + + def test_usage_does_not_expand_bad_docstring(self): + with patch_main_module_docstring('Name: %s, %%s, %@'): + with mock.patch.object( + sys, 'stderr', new=io.StringIO()) as mock_stderr: + app.usage() + self.assertIn('Name: %s, %%s, %@', mock_stderr.getvalue()) + + @flagsaver.flagsaver + def test_register_and_parse_flags_with_usage_exits_on_only_check_args(self): + done = app._register_and_parse_flags_with_usage.done + try: + app._register_and_parse_flags_with_usage.done = False + with self.assertRaises(SystemExit): + app._register_and_parse_flags_with_usage( + argv=['./program', '--only_check_args']) + finally: + app._register_and_parse_flags_with_usage.done = done + + def test_register_and_parse_flags_with_usage_exits_on_second_run(self): + with self.assertRaises(SystemError): + app._register_and_parse_flags_with_usage() + + +class FunctionalTests(absltest.TestCase): + """Functional tests that use runs app_test_helper.""" + + helper_type = 'pure_python' + + def run_helper(self, expect_success, + expected_stdout_substring=None, expected_stderr_substring=None, + arguments=(), + env_overrides=None): + env = os.environ.copy() + env['APP_TEST_HELPER_TYPE'] = self.helper_type + env['PYTHONIOENCODING'] = 'utf8' + if env_overrides: + env.update(env_overrides) + + helper = 'absl/tests/app_test_helper_{}'.format(self.helper_type) + process = subprocess.Popen( + [_bazelize_command.get_executable_path(helper)] + list(arguments), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, env=env, universal_newlines=False) + stdout, stderr = process.communicate() + # In Python 2, we can't control the encoding used by universal_newline + # mode, which can cause UnicodeDecodeErrors when subprocess tries to + # convert the bytes to unicode, so we have to decode it manually. + stdout = _normalize_newlines(stdout.decode('utf8')) + stderr = _normalize_newlines(stderr.decode('utf8')) + + message = (u'Command: {command}\n' + 'Exit Code: {exitcode}\n' + '===== stdout =====\n{stdout}' + '===== stderr =====\n{stderr}' + '=================='.format( + command=' '.join([helper] + list(arguments)), + exitcode=process.returncode, + stdout=stdout or '<no output>\n', + stderr=stderr or '<no output>\n')) + if expect_success: + self.assertEqual(0, process.returncode, msg=message) + else: + self.assertNotEqual(0, process.returncode, msg=message) + + if expected_stdout_substring: + self.assertIn(expected_stdout_substring, stdout, message) + if expected_stderr_substring: + self.assertIn(expected_stderr_substring, stderr, message) + + return process.returncode, stdout, stderr + + def test_help(self): + _, _, stderr = self.run_helper( + False, + arguments=['--help'], + expected_stdout_substring=app_test_helper.__doc__) + self.assertNotIn('--', stderr) + + def test_helpfull_basic(self): + self.run_helper( + False, + arguments=['--helpfull'], + # --logtostderr is from absl.logging module. + expected_stdout_substring='--[no]logtostderr') + + def test_helpfull_unicode_flag_help(self): + _, stdout, _ = self.run_helper( + False, + arguments=['--helpfull'], + expected_stdout_substring='str_flag_with_unicode_args') + + self.assertIn(u'smile:\U0001F604', stdout) + + self.assertIn(u'thumb:\U0001F44D', stdout) + + def test_helpshort(self): + _, _, stderr = self.run_helper( + False, + arguments=['--helpshort'], + expected_stdout_substring=app_test_helper.__doc__) + self.assertNotIn('--', stderr) + + def test_custom_main(self): + self.run_helper( + True, + env_overrides={'APP_TEST_CUSTOM_MAIN_FUNC': 'custom_main'}, + expected_stdout_substring='Function called: custom_main.') + + def test_custom_argv(self): + self.run_helper( + True, + expected_stdout_substring='argv: ./program pos_arg1', + env_overrides={ + 'APP_TEST_CUSTOM_ARGV': './program --noraise_exception pos_arg1', + 'APP_TEST_PRINT_ARGV': '1', + }) + + def test_gwq_status_file_on_exception(self): + if self.helper_type == 'pure_python': + # Pure python binary does not write to GWQ Status. + return + + tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) + self.run_helper( + False, + arguments=['--raise_exception'], + env_overrides={'GOOGLE_STATUS_DIR': tmpdir}) + with open(os.path.join(tmpdir, 'STATUS')) as status_file: + self.assertIn('MyException:', status_file.read()) + + def test_faulthandler_dumps_stack_on_sigsegv(self): + return_code, _, _ = self.run_helper( + False, + expected_stderr_substring='app_test_helper.py", line', + arguments=['--faulthandler_sigsegv']) + # sigsegv returns 3 on Windows, and -11 on LINUX/macOS. + expected_return_code = 3 if os.name == 'nt' else -11 + self.assertEqual(expected_return_code, return_code) + + def test_top_level_exception(self): + self.run_helper( + False, + arguments=['--raise_exception'], + expected_stderr_substring='MyException') + + def test_only_check_args(self): + self.run_helper( + True, + arguments=['--only_check_args', '--raise_exception']) + + def test_only_check_args_failure(self): + self.run_helper( + False, + arguments=['--only_check_args', '--banana'], + expected_stderr_substring='FATAL Flags parsing error') + + def test_usage_error(self): + exitcode, _, _ = self.run_helper( + False, + arguments=['--raise_usage_error'], + expected_stderr_substring=app_test_helper.__doc__) + self.assertEqual(1, exitcode) + + def test_usage_error_exitcode(self): + exitcode, _, _ = self.run_helper( + False, + arguments=['--raise_usage_error', '--usage_error_exitcode=88'], + expected_stderr_substring=app_test_helper.__doc__) + self.assertEqual(88, exitcode) + + def test_exception_handler(self): + exception_handler_messages = ( + 'MyExceptionHandler: first\nMyExceptionHandler: second\n') + self.run_helper( + False, + arguments=['--raise_exception'], + expected_stdout_substring=exception_handler_messages) + + def test_exception_handler_not_called(self): + _, _, stdout = self.run_helper(True) + self.assertNotIn('MyExceptionHandler', stdout) + + def test_print_init_callbacks(self): + _, stdout, _ = self.run_helper( + expect_success=True, arguments=['--print_init_callbacks']) + self.assertIn('before app.run', stdout) + self.assertIn('during real_main', stdout) + + +class FlagDeepCopyTest(absltest.TestCase): + """Make sure absl flags are copy.deepcopy() compatible.""" + + def test_deepcopyable(self): + copy.deepcopy(FLAGS) + # Nothing to assert + + +class FlagValuesExternalizationTest(absltest.TestCase): + """Test to make sure FLAGS can be serialized out and parsed back in.""" + + @flagsaver.flagsaver + def test_nohelp_doesnt_show_help(self): + with self.assertRaisesWithPredicateMatch(SystemExit, + lambda e: e.code == 1): + app.run( + len, + argv=[ + './program', '--nohelp', '--helpshort=false', '--helpfull=0', + '--helpxml=f' + ]) + + @flagsaver.flagsaver + def test_serialize_roundtrip(self): + # Use the global 'FLAGS' as the source, to ensure all the framework defined + # flags will go through the round trip process. + flags.DEFINE_string('testflag', 'testval', 'help', flag_values=FLAGS) + + flags.DEFINE_multi_enum('test_multi_enum_flag', + ['x', 'y'], ['x', 'y', 'z'], + 'Multi enum help.', + flag_values=FLAGS) + + class Fruit(enum.Enum): + APPLE = 1 + ORANGE = 2 + TOMATO = 3 + flags.DEFINE_multi_enum_class('test_multi_enum_class_flag', + ['APPLE', 'TOMATO'], Fruit, + 'Fruit help.', + flag_values=FLAGS) + + new_flag_values = flags.FlagValues() + new_flag_values.append_flag_values(FLAGS) + + FLAGS.testflag = 'roundtrip_me' + FLAGS.test_multi_enum_flag = ['y', 'z'] + FLAGS.test_multi_enum_class_flag = [Fruit.ORANGE, Fruit.APPLE] + argv = ['binary_name'] + FLAGS.flags_into_string().splitlines() + + self.assertNotEqual(new_flag_values['testflag'], FLAGS.testflag) + self.assertNotEqual(new_flag_values['test_multi_enum_flag'], + FLAGS.test_multi_enum_flag) + self.assertNotEqual(new_flag_values['test_multi_enum_class_flag'], + FLAGS.test_multi_enum_class_flag) + new_flag_values(argv) + self.assertEqual(new_flag_values.testflag, FLAGS.testflag) + self.assertEqual(new_flag_values.test_multi_enum_flag, + FLAGS.test_multi_enum_flag) + self.assertEqual(new_flag_values.test_multi_enum_class_flag, + FLAGS.test_multi_enum_class_flag) + del FLAGS.testflag + del FLAGS.test_multi_enum_flag + del FLAGS.test_multi_enum_class_flag + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/tests/app_test_helper.py b/absl/tests/app_test_helper.py new file mode 100644 index 0000000..6bd1a89 --- /dev/null +++ b/absl/tests/app_test_helper.py @@ -0,0 +1,151 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper script used by app_test.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +try: + import faulthandler +except ImportError: + faulthandler = None + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS +flags.DEFINE_boolean('faulthandler_sigsegv', False, 'raise SIGSEGV') +flags.DEFINE_boolean('raise_exception', False, 'Raise MyException from main.') +flags.DEFINE_boolean( + 'raise_usage_error', False, 'Raise app.UsageError from main.') +flags.DEFINE_integer( + 'usage_error_exitcode', None, 'The exitcode if app.UsageError if raised.') +flags.DEFINE_string( + 'str_flag_with_unicode_args', u'thumb:\U0001F44D', u'smile:\U0001F604') +flags.DEFINE_boolean('print_init_callbacks', False, + 'print init callbacks and exit') + + +class MyException(Exception): + pass + + +class MyExceptionHandler(app.ExceptionHandler): + + def __init__(self, message): + self.message = message + + def handle(self, exc): + sys.stdout.write('MyExceptionHandler: {}\n'.format(self.message)) + + +def real_main(argv): + """The main function.""" + if os.environ.get('APP_TEST_PRINT_ARGV', False): + sys.stdout.write('argv: {}\n'.format(' '.join(argv))) + + if FLAGS.raise_exception: + raise MyException + + if FLAGS.raise_usage_error: + if FLAGS.usage_error_exitcode is not None: + raise app.UsageError('Error!', FLAGS.usage_error_exitcode) + else: + raise app.UsageError('Error!') + + if FLAGS.faulthandler_sigsegv: + faulthandler._sigsegv() # pylint: disable=protected-access + sys.exit(1) # Should not reach here. + + if FLAGS.print_init_callbacks: + app.call_after_init(lambda: _callback_results.append('during real_main')) + for value in _callback_results: + print('callback: {}'.format(value)) + sys.exit(0) + + # Ensure that we have a random C++ flag in flags.FLAGS; this shows + # us that app.run() did the right thing in conjunction with C++ flags. + helper_type = os.environ['APP_TEST_HELPER_TYPE'] + if helper_type == 'clif': + if 'heap_check_before_constructors' in flags.FLAGS: + print('PASS: C++ flag present and helper_type is {}'.format(helper_type)) + sys.exit(0) + else: + print('FAILED: C++ flag absent but helper_type is {}'.format(helper_type)) + sys.exit(1) + elif helper_type == 'pure_python': + if 'heap_check_before_constructors' in flags.FLAGS: + print('FAILED: C++ flag present but helper_type is pure_python') + sys.exit(1) + else: + print('PASS: C++ flag absent and helper_type is pure_python') + sys.exit(0) + else: + print('Unexpected helper_type "{}"'.format(helper_type)) + sys.exit(1) + + +def custom_main(argv): + print('Function called: custom_main.') + real_main(argv) + + +def main(argv): + print('Function called: main.') + real_main(argv) + + +flags_parser_argv_sentinel = object() + + +def flags_parser_main(argv): + print('Function called: main_with_flags_parser.') + if argv is not flags_parser_argv_sentinel: + sys.exit( + 'FAILED: main function should be called with the return value of ' + 'flags_parser, but found: {}'.format(argv)) + + +def flags_parser(argv): + print('Function called: flags_parser.') + if os.environ.get('APP_TEST_FLAGS_PARSER_PARSE_FLAGS', None): + FLAGS(argv) + return flags_parser_argv_sentinel + + +# Holds results from callbacks triggered by `app.run_after_init`. +_callback_results = [] + +if __name__ == '__main__': + kwargs = {'main': main} + main_function_name = os.environ.get('APP_TEST_CUSTOM_MAIN_FUNC', None) + if main_function_name: + kwargs['main'] = globals()[main_function_name] + custom_argv = os.environ.get('APP_TEST_CUSTOM_ARGV', None) + if custom_argv: + kwargs['argv'] = custom_argv.split(' ') + if os.environ.get('APP_TEST_USE_CUSTOM_PARSER', None): + kwargs['flags_parser'] = flags_parser + + app.call_after_init(lambda: _callback_results.append('before app.run')) + app.install_exception_handler(MyExceptionHandler('first')) + app.install_exception_handler(MyExceptionHandler('second')) + app.run(**kwargs) + + sys.exit('This is not reachable.') diff --git a/absl/tests/command_name_test.py b/absl/tests/command_name_test.py new file mode 100644 index 0000000..2679521 --- /dev/null +++ b/absl/tests/command_name_test.py @@ -0,0 +1,108 @@ +# -*- coding=utf-8 -*- +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for absl.command_name.""" + +import ctypes +import errno +import os +import unittest +from unittest import mock + +from absl import command_name +from absl.testing import absltest + + +def _get_kernel_process_name(): + """Returns the Kernel's name for our process or an empty string.""" + try: + with open('/proc/self/status', 'rt') as status_file: + for line in status_file: + if line.startswith('Name:'): + return line.split(':', 2)[1].strip().encode('ascii', 'replace') + return b'' + except IOError: + return b'' + + +def _is_prctl_syscall_available(): + try: + libc = ctypes.CDLL('libc.so.6', use_errno=True) + except OSError: + return False + zero = ctypes.c_ulong(0) + try: + status = libc.prctl(zero, zero, zero, zero, zero) + except AttributeError: + return False + if status < 0 and errno.ENOSYS == ctypes.get_errno(): + return False + return True + + +@unittest.skipIf(not _get_kernel_process_name(), + '_get_kernel_process_name() fails.') +class CommandNameTest(absltest.TestCase): + + def assertProcessNameSimilarTo(self, new_name): + if not isinstance(new_name, bytes): + new_name = new_name.encode('ascii', 'replace') + actual_name = _get_kernel_process_name() + self.assertTrue(actual_name) + self.assertTrue(new_name.startswith(actual_name), + msg='set {!r} vs found {!r}'.format(new_name, actual_name)) + + @unittest.skipIf(not os.access('/proc/self/comm', os.W_OK), + '/proc/self/comm is not writeable.') + def test_set_kernel_process_name(self): + new_name = u'ProcessNam0123456789abcdefghijklmnöp' + command_name.set_kernel_process_name(new_name) + self.assertProcessNameSimilarTo(new_name) + + @unittest.skipIf(not _is_prctl_syscall_available(), + 'prctl() system call missing from libc.so.6.') + def test_set_kernel_process_name_no_proc_file(self): + new_name = b'NoProcFile0123456789abcdefghijklmnop' + mock_open = mock.mock_open() + with mock.patch.object(command_name, 'open', mock_open, create=True): + mock_open.side_effect = IOError('mock open that raises.') + command_name.set_kernel_process_name(new_name) + mock_open.assert_called_with('/proc/self/comm', mock.ANY) + self.assertProcessNameSimilarTo(new_name) + + def test_set_kernel_process_name_failure(self): + starting_name = _get_kernel_process_name() + new_name = b'NameTest' + mock_open = mock.mock_open() + mock_ctypes_cdll = mock.patch('ctypes.CDLL') + with mock.patch.object(command_name, 'open', mock_open, create=True): + with mock.patch('ctypes.CDLL') as mock_ctypes_cdll: + mock_open.side_effect = IOError('mock open that raises.') + mock_libc = mock.Mock(['prctl']) + mock_ctypes_cdll.return_value = mock_libc + command_name.set_kernel_process_name(new_name) + mock_open.assert_called_with('/proc/self/comm', mock.ANY) + self.assertEqual(1, mock_libc.prctl.call_count) + self.assertEqual(starting_name, _get_kernel_process_name()) # No change. + + def test_make_process_name_useful(self): + test_name = 'hello.from.test' + with mock.patch('sys.argv', [test_name]): + command_name.make_process_name_useful() + self.assertProcessNameSimilarTo(test_name) + + +if __name__ == '__main__': + absltest.main() diff --git a/absl/tests/python_version_test.py b/absl/tests/python_version_test.py new file mode 100644 index 0000000..eebfff2 --- /dev/null +++ b/absl/tests/python_version_test.py @@ -0,0 +1,40 @@ +# Copyright 2021 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test that verifies the Python version used in bazel is expected.""" + +import sys +from absl import flags +from absl.testing import absltest + +_EXPECTED_VERSION = flags.DEFINE_string( + 'expected_version', + None, + 'The expected Python SemVer version, ' + 'can be major.minor or major.minor.patch.', +) + + +class PythonVersionTest(absltest.TestCase): + + def test_version(self): + version = _EXPECTED_VERSION.value + if not version: + self.skipTest( + 'Skipping version test since --expected_version is not specified') + num_parts = len(version.split('.')) + self.assertEqual('.'.join(map(str, sys.version_info[:num_parts])), version) + + +if __name__ == '__main__': + absltest.main() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..13b2353 --- /dev/null +++ b/setup.py @@ -0,0 +1,77 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Abseil setup configuration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +try: + import setuptools +except ImportError: + from ez_setup import use_setuptools + use_setuptools() + import setuptools + +if sys.version_info < (3, 6): + raise RuntimeError('Python version 3.6+ is required.') + +setuptools_version = tuple( + int(x) for x in setuptools.__version__.split('.')[:2]) + +additional_kwargs = {} +if setuptools_version >= (24, 2): + # `python_requires` was added in 24.2, see + # https://packaging.python.org/guides/distributing-packages-using-setuptools/#python-requires + additional_kwargs['python_requires'] = '>=3.6' + +_README_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), 'README.md') +with open(_README_PATH, 'rb') as fp: + LONG_DESCRIPTION = fp.read().decode('utf-8') + +setuptools.setup( + name='absl-py', + version='1.1.0', + description=( + 'Abseil Python Common Libraries, ' + 'see https://github.com/abseil/abseil-py.'), + long_description=LONG_DESCRIPTION, + long_description_content_type='text/markdown', + author='The Abseil Authors', + url='https://github.com/abseil/abseil-py', + packages=setuptools.find_packages(exclude=[ + '*.tests', '*.tests.*', 'tests.*', 'tests', + ]), + include_package_data=True, + license='Apache 2.0', + classifiers=[ + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Intended Audience :: Developers', + 'Topic :: Software Development :: Libraries :: Python Modules', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + ], + **additional_kwargs, +) diff --git a/smoke_tests/sample_app.py b/smoke_tests/sample_app.py new file mode 100644 index 0000000..532a11e --- /dev/null +++ b/smoke_tests/sample_app.py @@ -0,0 +1,41 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test helper for smoke_test.sh.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from absl import app +from absl import flags +from absl import logging + +FLAGS = flags.FLAGS + +flags.DEFINE_string('echo', None, 'Text to echo.') + + +def main(argv): + del argv # Unused. + + print('Running under Python {0[0]}.{0[1]}.{0[2]}'.format(sys.version_info), + file=sys.stderr) + logging.info('echo is %s.', FLAGS.echo) + + +if __name__ == '__main__': + app.run(main) diff --git a/smoke_tests/sample_test.py b/smoke_tests/sample_test.py new file mode 100644 index 0000000..713677a --- /dev/null +++ b/smoke_tests/sample_test.py @@ -0,0 +1,33 @@ +# Copyright 2018 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test helper for smoke_test.sh.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest + + +class SampleTest(absltest.TestCase): + + def test_subtest(self): + for i in (1, 2): + with self.subTest(i=i): + self.assertEqual(i, i) + print('msg_for_test') + +if __name__ == '__main__': + absltest.main() diff --git a/smoke_tests/smoke_test.sh b/smoke_tests/smoke_test.sh new file mode 100755 index 0000000..99307a4 --- /dev/null +++ b/smoke_tests/smoke_test.sh @@ -0,0 +1,70 @@ +#!/bin/bash +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Smoke test to verify setup.py works as expected. +# Note on Windows, this must run via msys. + +# Fail on any error. Treat unset variables an error. Print commands as executed. +set -eux + +if [[ "$#" -ne "2" ]]; then + echo 'Must specify the Python interpreter and virtualenv path.' + echo 'Usage:' + echo ' smoke_tests/smoke_test.sh [Python interpreter path] [virtualenv Path]' + exit 1 +fi + +ABSL_PYTHON=$1 +ABSL_VIRTUALENV=$2 +TMP_DIR=$(mktemp -d) +trap "{ rm -rf ${TMP_DIR}; }" EXIT +# Do not bootstrap pip/setuptools, they are manually installed with get-pip.py +# inside the virtualenv. +if ${ABSL_VIRTUALENV} --help | grep '\--no-site-packages'; then + no_site_packages_flag="--no-site-packages" +else + # --no-site-packages becomes the default in version 20 and is no longer a + # flag. + no_site_packages_flag="" +fi +${ABSL_VIRTUALENV} ${no_site_packages_flag} --no-pip --no-setuptools --no-wheel \ + -p ${ABSL_PYTHON} ${TMP_DIR} + +# Temporarily disable unbound variable errors to activate virtualenv. +set +u +if [[ $(uname -s) == MSYS* ]]; then + source ${TMP_DIR}/Scripts/activate +else + source ${TMP_DIR}/bin/activate +fi +set -u + +trap 'deactivate' EXIT + +# When running macOS <= 10.12, pip 9.0.3 is required to connect to PyPI. +# So we need to manually use the latest pip to install absl-py. See: +# https://mail.python.org/pipermail/distutils-sig/2018-April/032114.html +if [[ "$(python -c "import sys; print(sys.version_info.major, sys.version_info.minor)")" == "3 6" ]]; then + # Latest get-pip.py no longer supports Python 3.6. + curl https://bootstrap.pypa.io/pip/3.6/get-pip.py | python +else + curl https://bootstrap.pypa.io/get-pip.py | python +fi +pip --version + +python --version +python setup.py install +python smoke_tests/sample_app.py --echo smoke 2>&1 |grep 'echo is smoke.' +python smoke_tests/sample_test.py 2>&1 | grep 'msg_for_test' |