aboutsummaryrefslogtreecommitdiff
path: root/absl/testing/flagsaver.py
diff options
context:
space:
mode:
authorYilei Yang <yileiyang@google.com>2017-09-19 14:25:01 -0700
committerYilei Yang <yileiyang@google.com>2017-09-19 14:25:01 -0700
commit1c6972eef4d4dc774fa50f29248f776fa6628e9e (patch)
tree49f251967b459d7ec3c72042d132b77adb7d1b1a /absl/testing/flagsaver.py
downloadabsl-py-1c6972eef4d4dc774fa50f29248f776fa6628e9e.tar.gz
Initial commit: Abseil Python Common Libraries.
Diffstat (limited to 'absl/testing/flagsaver.py')
-rwxr-xr-xabsl/testing/flagsaver.py183
1 files changed, 183 insertions, 0 deletions
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py
new file mode 100755
index 0000000..a95b742
--- /dev/null
+++ b/absl/testing/flagsaver.py
@@ -0,0 +1,183 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Decorator and context manager for saving and restoring flag values.
+
+There are many ways to save and restore. Always use the most convenient method
+for a given use case.
+
+Here are examples of each method. They all call do_stuff() while FLAGS.someflag
+is temporarily set to 'foo'.
+
+ # Use a decorator which can optionally override flags via arguments.
+ @flagsaver.flagsaver(someflag='foo')
+ def some_func():
+ do_stuff()
+
+ # Use a decorator which does not override flags itself.
+ @flagsaver.flagsaver
+ def some_func():
+ FLAGS.someflag = 'foo'
+ do_stuff()
+
+ # Use a context manager which can optionally override flags via arguments.
+ with flagsaver.flagsaver(someflag='foo'):
+ do_stuff()
+
+ # Save and restore the flag values yourself.
+ saved_flag_values = flagsaver.save_flag_values()
+ try:
+ FLAGS.someflag = 'foo'
+ do_stuff()
+ finally:
+ flagsaver.restore_flag_values(saved_flag_values)
+
+We save and restore a shallow copy of each Flag object's __dict__ attribute.
+This preserves all attributes of the flag, such as whether or not it was
+overridden from its default value.
+
+WARNING: Currently a flag that is saved and then deleted cannot be restored. An
+exception will be raised. However if you *add* a flag after saving flag values,
+and then restore flag values, the added flag will be deleted with no errors.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import inspect
+
+from absl import flags
+import six
+
+FLAGS = flags.FLAGS
+
+
+def flagsaver(*args, **kwargs):
+ """The main flagsaver interface. See module doc for usage."""
+ if not args:
+ return _FlagOverrider(**kwargs)
+ elif len(args) == 1:
+ if kwargs:
+ raise ValueError(
+ "It's invalid to specify both positional and keyword parameters.")
+ func = args[0]
+ if inspect.isclass(func):
+ raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
+ return _wrap(func, {})
+ else:
+ raise ValueError(
+ "It's invalid to specify more than one positional parameters.")
+
+
+def save_flag_values(flag_values=FLAGS):
+ """Returns copy of flag values as a dict.
+
+ Args:
+ flag_values: FlagValues, the FlagValues instance with which the flag will
+ be saved. This should almost never need to be overridden.
+ Returns:
+ Dictionary mapping keys to values. Keys are flag names, values are
+ corresponding __dict__ members. E.g. {'key': value_dict, ...}.
+ """
+ return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}
+
+
+def restore_flag_values(saved_flag_values, flag_values=FLAGS):
+ """Restores flag values based on the dictionary of flag values.
+
+ Args:
+ saved_flag_values: {'flag_name': value_dict, ...}
+ flag_values: FlagValues, the FlagValues instance from which the flag will
+ be restored. This should almost never need to be overridden.
+ """
+ new_flag_names = list(flag_values)
+ for name in new_flag_names:
+ saved = saved_flag_values.get(name)
+ if saved is None:
+ # If __dict__ was not saved delete "new" flag.
+ delattr(flag_values, name)
+ else:
+ if flag_values[name].value != saved['_value']:
+ flag_values[name].value = saved['_value'] # Ensure C++ value is set.
+ flag_values[name].__dict__ = saved
+
+
+def _wrap(func, overrides):
+ """Creates a wrapper function that saves/restores flag values.
+
+ Args:
+ func: function object - This will be called between saving flags and
+ restoring flags.
+ overrides: {str: object} - Flag names mapped to their values. These flags
+ will be set after saving the original flag state.
+
+ Returns:
+ return value from func()
+ """
+ @functools.wraps(func)
+ def _flagsaver_wrapper(*args, **kwargs):
+ """Wrapper function that saves and restores flags."""
+ with _FlagOverrider(**overrides):
+ return func(*args, **kwargs)
+ return _flagsaver_wrapper
+
+
+class _FlagOverrider(object):
+ """Overrides flags for the duration of the decorated function call.
+
+ It also restores all original values of flags after decorated method
+ completes.
+ """
+
+ def __init__(self, **overrides):
+ self._overrides = overrides
+ self._saved_flag_values = None
+
+ def __call__(self, func):
+ if inspect.isclass(func):
+ raise TypeError('flagsaver cannot be applied to a class.')
+ return _wrap(func, self._overrides)
+
+ def __enter__(self):
+ self._saved_flag_values = save_flag_values(FLAGS)
+ try:
+ for name, value in six.iteritems(self._overrides):
+ setattr(FLAGS, name, value)
+ except:
+ # It may fail because of flag validators.
+ restore_flag_values(self._saved_flag_values, FLAGS)
+ raise
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ restore_flag_values(self._saved_flag_values, FLAGS)
+
+
+def _copy_flag_dict(flag):
+ """Returns a copy of the flag object's __dict__.
+
+ It's mostly a shallow copy of the __dict__, except it also does a shallow
+ copy of the validator list.
+
+ Args:
+ flag: flags.Flag, the flag to copy.
+
+ Returns:
+ A copy of the flag object's __dict__.
+ """
+ copy = flag.__dict__.copy()
+ copy['_value'] = flag.value # Ensure correct restore for C++ flags.
+ copy['validators'] = list(flag.validators)
+ return copy