diff options
author | Yilei Yang <yileiyang@google.com> | 2017-09-19 14:25:01 -0700 |
---|---|---|
committer | Yilei Yang <yileiyang@google.com> | 2017-09-19 14:25:01 -0700 |
commit | 1c6972eef4d4dc774fa50f29248f776fa6628e9e (patch) | |
tree | 49f251967b459d7ec3c72042d132b77adb7d1b1a /absl/testing/flagsaver.py | |
download | absl-py-1c6972eef4d4dc774fa50f29248f776fa6628e9e.tar.gz |
Initial commit: Abseil Python Common Libraries.
Diffstat (limited to 'absl/testing/flagsaver.py')
-rwxr-xr-x | absl/testing/flagsaver.py | 183 |
1 files changed, 183 insertions, 0 deletions
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py new file mode 100755 index 0000000..a95b742 --- /dev/null +++ b/absl/testing/flagsaver.py @@ -0,0 +1,183 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decorator and context manager for saving and restoring flag values. + +There are many ways to save and restore. Always use the most convenient method +for a given use case. + +Here are examples of each method. They all call do_stuff() while FLAGS.someflag +is temporarily set to 'foo'. + + # Use a decorator which can optionally override flags via arguments. + @flagsaver.flagsaver(someflag='foo') + def some_func(): + do_stuff() + + # Use a decorator which does not override flags itself. + @flagsaver.flagsaver + def some_func(): + FLAGS.someflag = 'foo' + do_stuff() + + # Use a context manager which can optionally override flags via arguments. + with flagsaver.flagsaver(someflag='foo'): + do_stuff() + + # Save and restore the flag values yourself. + saved_flag_values = flagsaver.save_flag_values() + try: + FLAGS.someflag = 'foo' + do_stuff() + finally: + flagsaver.restore_flag_values(saved_flag_values) + +We save and restore a shallow copy of each Flag object's __dict__ attribute. +This preserves all attributes of the flag, such as whether or not it was +overridden from its default value. + +WARNING: Currently a flag that is saved and then deleted cannot be restored. An +exception will be raised. However if you *add* a flag after saving flag values, +and then restore flag values, the added flag will be deleted with no errors. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import inspect + +from absl import flags +import six + +FLAGS = flags.FLAGS + + +def flagsaver(*args, **kwargs): + """The main flagsaver interface. See module doc for usage.""" + if not args: + return _FlagOverrider(**kwargs) + elif len(args) == 1: + if kwargs: + raise ValueError( + "It's invalid to specify both positional and keyword parameters.") + func = args[0] + if inspect.isclass(func): + raise TypeError('@flagsaver.flagsaver cannot be applied to a class.') + return _wrap(func, {}) + else: + raise ValueError( + "It's invalid to specify more than one positional parameters.") + + +def save_flag_values(flag_values=FLAGS): + """Returns copy of flag values as a dict. + + Args: + flag_values: FlagValues, the FlagValues instance with which the flag will + be saved. This should almost never need to be overridden. + Returns: + Dictionary mapping keys to values. Keys are flag names, values are + corresponding __dict__ members. E.g. {'key': value_dict, ...}. + """ + return {name: _copy_flag_dict(flag_values[name]) for name in flag_values} + + +def restore_flag_values(saved_flag_values, flag_values=FLAGS): + """Restores flag values based on the dictionary of flag values. + + Args: + saved_flag_values: {'flag_name': value_dict, ...} + flag_values: FlagValues, the FlagValues instance from which the flag will + be restored. This should almost never need to be overridden. + """ + new_flag_names = list(flag_values) + for name in new_flag_names: + saved = saved_flag_values.get(name) + if saved is None: + # If __dict__ was not saved delete "new" flag. + delattr(flag_values, name) + else: + if flag_values[name].value != saved['_value']: + flag_values[name].value = saved['_value'] # Ensure C++ value is set. + flag_values[name].__dict__ = saved + + +def _wrap(func, overrides): + """Creates a wrapper function that saves/restores flag values. + + Args: + func: function object - This will be called between saving flags and + restoring flags. + overrides: {str: object} - Flag names mapped to their values. These flags + will be set after saving the original flag state. + + Returns: + return value from func() + """ + @functools.wraps(func) + def _flagsaver_wrapper(*args, **kwargs): + """Wrapper function that saves and restores flags.""" + with _FlagOverrider(**overrides): + return func(*args, **kwargs) + return _flagsaver_wrapper + + +class _FlagOverrider(object): + """Overrides flags for the duration of the decorated function call. + + It also restores all original values of flags after decorated method + completes. + """ + + def __init__(self, **overrides): + self._overrides = overrides + self._saved_flag_values = None + + def __call__(self, func): + if inspect.isclass(func): + raise TypeError('flagsaver cannot be applied to a class.') + return _wrap(func, self._overrides) + + def __enter__(self): + self._saved_flag_values = save_flag_values(FLAGS) + try: + for name, value in six.iteritems(self._overrides): + setattr(FLAGS, name, value) + except: + # It may fail because of flag validators. + restore_flag_values(self._saved_flag_values, FLAGS) + raise + + def __exit__(self, exc_type, exc_value, traceback): + restore_flag_values(self._saved_flag_values, FLAGS) + + +def _copy_flag_dict(flag): + """Returns a copy of the flag object's __dict__. + + It's mostly a shallow copy of the __dict__, except it also does a shallow + copy of the validator list. + + Args: + flag: flags.Flag, the flag to copy. + + Returns: + A copy of the flag object's __dict__. + """ + copy = flag.__dict__.copy() + copy['_value'] = flag.value # Ensure correct restore for C++ flags. + copy['validators'] = list(flag.validators) + return copy |