# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Decorator and context manager for saving and restoring flag values.

There are many ways to save and restore.  Always use the most convenient method
for a given use case.

Here are examples of each method.  They all call ``do_stuff()`` while
``FLAGS.someflag`` is temporarily set to ``'foo'``::

  from absl.testing import flagsaver

  # Use a decorator which can optionally override flags via arguments.
  @flagsaver.flagsaver(someflag='foo')
  def some_func():
    do_stuff()

  # Use a decorator which can optionally override flags with flagholders.
  @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23))
  def some_func():
    do_stuff()

  # Use a decorator which does not override flags itself.
  @flagsaver.flagsaver
  def some_func():
    FLAGS.someflag = 'foo'
    do_stuff()

  # Use a context manager which can optionally override flags via arguments.
  with flagsaver.flagsaver(someflag='foo'):
    do_stuff()

  # Save and restore the flag values yourself.
  saved_flag_values = flagsaver.save_flag_values()
  try:
    FLAGS.someflag = 'foo'
    do_stuff()
  finally:
    flagsaver.restore_flag_values(saved_flag_values)

We save and restore a shallow copy of each Flag object's ``__dict__`` attribute.
This preserves all attributes of the flag, such as whether or not it was
overridden from its default value.  A saved snapshot is not modified by
restoring it, so the same snapshot may be restored more than once.

WARNING: Currently a flag that is saved and then deleted cannot be restored:
``restore_flag_values`` only walks the flags that currently exist, so the
deleted flag simply stays deleted.  However if you *add* a flag after saving
flag values, and then restore flag values, the added flag will be deleted with
no errors.
"""

import functools
import inspect

from absl import flags

FLAGS = flags.FLAGS


def flagsaver(*args, **kwargs):
  """The main flagsaver interface. See module doc for usage.

  Args:
    *args: Either a single callable (bare ``@flagsaver`` decorator usage), or
      any number of ``(FlagHolder, value)`` pairs whose values override the
      corresponding flags.
    **kwargs: Flag names mapped to the values they are overridden with.

  Returns:
    The wrapped function when used as a bare decorator; otherwise a
    ``_FlagOverrider`` that can be used as a decorator or a context manager.

  Raises:
    ValueError: if a callable is combined with keyword overrides, if a
      positional argument is not a ``(FlagHolder, value)`` pair, or if the
      same flag is overridden more than once.
    TypeError: if used as a bare decorator on a class.
  """
  if not args:
    return _FlagOverrider(**kwargs)
  # args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)`
  if len(args) == 1 and callable(args[0]):
    if kwargs:
      raise ValueError(
          "It's invalid to specify both positional and keyword parameters.")
    func = args[0]
    if inspect.isclass(func):
      raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
    return _wrap(func, {})
  # args can be a list of (FlagHolder, value) pairs.
  # In which case they augment any specified kwargs.
  for arg in args:
    if not isinstance(arg, tuple) or len(arg) != 2:
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
    holder, value = arg
    if not isinstance(holder, flags.FlagHolder):
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
    if holder.name in kwargs:
      raise ValueError('Cannot set --%s multiple times' % holder.name)
    kwargs[holder.name] = value
  return _FlagOverrider(**kwargs)


def save_flag_values(flag_values=FLAGS):
  """Returns copy of flag values as a dict.

  Args:
    flag_values: FlagValues, the FlagValues instance with which the flag will
      be saved. This should almost never need to be overridden.

  Returns:
    Dictionary mapping keys to values. Keys are flag names, values are
    corresponding ``__dict__`` members. E.g. ``{'key': value_dict, ...}``.
  """
  return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}


def restore_flag_values(saved_flag_values, flag_values=FLAGS):
  """Restores flag values based on the dictionary of flag values.

  Flags that exist now but are absent from ``saved_flag_values`` are deleted.
  Each restored flag receives a fresh copy of its saved ``__dict__``, so
  ``saved_flag_values`` itself is never aliased or mutated and may be restored
  again later.

  Args:
    saved_flag_values: {'flag_name': value_dict, ...}, as returned by
      ``save_flag_values``.
    flag_values: FlagValues, the FlagValues instance from which the flag will
      be restored. This should almost never need to be overridden.
  """
  new_flag_names = list(flag_values)
  for name in new_flag_names:
    saved = saved_flag_values.get(name)
    if saved is None:
      # If __dict__ was not saved delete "new" flag.
      delattr(flag_values, name)
    else:
      if flag_values[name].value != saved['_value']:
        flag_values[name].value = saved['_value']  # Ensure C++ value is set.
      # Install a copy rather than `saved` itself: otherwise later assignments
      # to the flag would write straight into the snapshot, and restoring the
      # same snapshot again would silently be a no-op.
      restored = dict(saved)
      if 'validators' in restored:
        restored['validators'] = list(restored['validators'])
      flag_values[name].__dict__ = restored


def _wrap(func, overrides):
  """Creates a wrapper function that saves/restores flag values.

  Args:
    func: function object - This will be called between saving flags and
      restoring flags.
    overrides: {str: object} - Flag names mapped to their values. These flags
      will be set after saving the original flag state.

  Returns:
    A wrapper function (with ``func``'s metadata copied via ``functools.wraps``)
    that, on each call, saves flags, applies ``overrides``, calls ``func`` and
    returns its result, then restores the saved flags.
  """

  @functools.wraps(func)
  def _flagsaver_wrapper(*args, **kwargs):
    """Wrapper function that saves and restores flags."""
    # A fresh overrider per call, so each invocation takes its own snapshot.
    with _FlagOverrider(**overrides):
      return func(*args, **kwargs)

  return _flagsaver_wrapper


class _FlagOverrider(object):
  """Overrides flags for the duration of the decorated function call.

  It also restores all original values of flags after decorated method
  completes.
  """

  def __init__(self, **overrides):
    # Flag names mapped to the values applied on __enter__.
    self._overrides = overrides
    # Snapshot taken on __enter__ and restored on __exit__.
    self._saved_flag_values = None

  def __call__(self, func):
    if inspect.isclass(func):
      raise TypeError('flagsaver cannot be applied to a class.')
    return _wrap(func, self._overrides)

  def __enter__(self):
    self._saved_flag_values = save_flag_values(FLAGS)
    try:
      FLAGS._set_attributes(**self._overrides)
    except BaseException:  # Re-raised below; only here to roll back.
      # It may fail because of flag validators.
      restore_flag_values(self._saved_flag_values, FLAGS)
      raise

  def __exit__(self, exc_type, exc_value, traceback):
    # Returns None, so any exception from the body propagates.
    restore_flag_values(self._saved_flag_values, FLAGS)


def _copy_flag_dict(flag):
  """Returns a copy of the flag object's ``__dict__``.

  It's mostly a shallow copy of the ``__dict__``, except it also does a shallow
  copy of the validator list.

  Args:
    flag: flags.Flag, the flag to copy.

  Returns:
    A copy of the flag object's ``__dict__``.
  """
  flag_dict = flag.__dict__.copy()
  flag_dict['_value'] = flag.value  # Ensure correct restore for C++ flags.
  flag_dict['validators'] = list(flag.validators)
  return flag_dict