aboutsummaryrefslogtreecommitdiff
path: root/absl/testing
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2020-12-14 13:02:02 -0800
committerCopybara-Service <copybara-worker@google.com>2020-12-14 13:02:30 -0800
commit94670de3dcf271f95d000cd8cdad754014a4cb5d (patch)
tree05d8ab4ffa1a70fdd88ce16a60f0ffb15fb0a5f4 /absl/testing
parentd61b0b6bda1902f645e5bbbc3f138c142767befa (diff)
downloadabsl-py-94670de3dcf271f95d000cd8cdad754014a4cb5d.tar.gz
Support using flagholder in flagsaver.
`(HOLDER, value)` pairs can now be specified in positional arguments. It is equivalent to specifying `**{HOLDER.name: value}` We can mix and match holder and non-holder overrides. So usages like `flagsaver((HOLDER1, value1), (HOLDER2, value2), flag_name=value)` are legal as well. PiperOrigin-RevId: 347450631 Change-Id: I45bdf7bd56ad1d65d62ff34536ef09e47fce7ae8
Diffstat (limited to 'absl/testing')
-rw-r--r--absl/testing/flagsaver.py23
-rw-r--r--absl/testing/tests/flagsaver_test.py63
2 files changed, 77 insertions, 9 deletions
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py
index c33d56a..7fe95fe 100644
--- a/absl/testing/flagsaver.py
+++ b/absl/testing/flagsaver.py
@@ -27,6 +27,11 @@ is temporarily set to 'foo'.
def some_func():
do_stuff()
+ # Use a decorator which can optionally override flags with flagholders.
+ @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23))
+ def some_func():
+ do_stuff()
+
# Use a decorator which does not override flags itself.
@flagsaver.flagsaver
def some_func():
@@ -70,7 +75,8 @@ def flagsaver(*args, **kwargs):
"""The main flagsaver interface. See module doc for usage."""
if not args:
return _FlagOverrider(**kwargs)
- elif len(args) == 1:
+ # args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)`
+ if len(args) == 1 and callable(args[0]):
if kwargs:
raise ValueError(
"It's invalid to specify both positional and keyword parameters.")
@@ -78,9 +84,18 @@ def flagsaver(*args, **kwargs):
if inspect.isclass(func):
raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
return _wrap(func, {})
- else:
- raise ValueError(
- "It's invalid to specify more than one positional parameters.")
+ # args can be a list of (FlagHolder, value) pairs.
+ # In which case they augment any specified kwargs.
+ for arg in args:
+ if not isinstance(arg, tuple) or len(arg) != 2:
+ raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
+ holder, value = arg
+ if not isinstance(holder, flags.FlagHolder):
+ raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
+ if holder.name in kwargs:
+ raise ValueError('Cannot set --%s multiple times' % holder.name)
+ kwargs[holder.name] = value
+ return _FlagOverrider(**kwargs)
def save_flag_values(flag_values=FLAGS):
diff --git a/absl/testing/tests/flagsaver_test.py b/absl/testing/tests/flagsaver_test.py
index ed428df..3439a32 100644
--- a/absl/testing/tests/flagsaver_test.py
+++ b/absl/testing/tests/flagsaver_test.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Tests for flagsaver."""
from __future__ import absolute_import
@@ -31,6 +30,11 @@ flags.register_validator('flagsaver_test_validated_flag', lambda x: not x)
flags.DEFINE_string('flagsaver_test_validated_flag1', None, 'flag to test with')
flags.DEFINE_string('flagsaver_test_validated_flag2', None, 'flag to test with')
+INT_FLAG = flags.DEFINE_integer(
+ 'flagsaver_test_int_flag', default=1, help='help')
+STR_FLAG = flags.DEFINE_string(
+ 'flagsaver_test_str_flag', default='str default', help='help')
+
@flags.multi_flags_validator(
('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
@@ -65,6 +69,24 @@ class FlagSaverTest(absltest.TestCase):
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
+ def test_context_manager_with_flagholders(self):
+ with flagsaver.flagsaver((INT_FLAG, 3), (STR_FLAG, 'new value')):
+ self.assertEqual('new value', STR_FLAG.value)
+ self.assertEqual(3, INT_FLAG.value)
+ FLAGS.flagsaver_test_flag1 = 'another value'
+ self.assertEqual(INT_FLAG.value, INT_FLAG.default)
+ self.assertEqual(STR_FLAG.value, STR_FLAG.default)
+ self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
+
+ def test_context_manager_with_overrides_and_flagholders(self):
+ with flagsaver.flagsaver((INT_FLAG, 3), flagsaver_test_flag0='new value'):
+ self.assertEqual(STR_FLAG.default, STR_FLAG.value)
+ self.assertEqual(3, INT_FLAG.value)
+ FLAGS.flagsaver_test_flag0 = 'new value'
+ self.assertEqual(INT_FLAG.value, INT_FLAG.default)
+ self.assertEqual(STR_FLAG.value, STR_FLAG.default)
+ self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
+
def test_context_manager_with_cross_validated_overrides_set_together(self):
# When the flags are set in the same flagsaver call their validators will
# be triggered only once the setting is done.
@@ -135,8 +157,7 @@ class FlagSaverTest(absltest.TestCase):
# mutate_flags returns the flag value before it gets restored by
# the flagsaver decorator. So we check that flag value was
# actually changed in the method's scope.
- self.assertEqual('new value',
- mutate_flags('new value'))
+ self.assertEqual('new value', mutate_flags('new value'))
# But... notice that the flag is now unchanged0.
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
@@ -288,8 +309,8 @@ class FlagSaverTest(absltest.TestCase):
self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2)
modify_validators()
- self.assertEqual(
- original_validators, FLAGS['flagsaver_test_flag0'].validators)
+ self.assertEqual(original_validators,
+ FLAGS['flagsaver_test_flag0'].validators)
class FlagSaverDecoratorUsageTest(absltest.TestCase):
@@ -409,6 +430,38 @@ class FlagSaverBadUsageTest(absltest.TestCase):
func_a = lambda: None
flagsaver.flagsaver(func_a, flagsaver_test_flag0='new value')
+ def test_duplicate_holder_parameters(self):
+ with self.assertRaises(ValueError):
+ flagsaver.flagsaver((INT_FLAG, 45), (INT_FLAG, 45))
+
+ def test_duplicate_holder_and_kw_parameter(self):
+ with self.assertRaises(ValueError):
+ flagsaver.flagsaver((INT_FLAG, 45), **{INT_FLAG.name: 45})
+
+ def test_both_positional_and_holder_parameters(self):
+ with self.assertRaises(ValueError):
+ func_a = lambda: None
+ flagsaver.flagsaver(func_a, (INT_FLAG, 45))
+
+ def test_holder_parameters_wrong_shape(self):
+ with self.assertRaises(ValueError):
+ flagsaver.flagsaver(INT_FLAG)
+
+ def test_holder_parameters_tuple_too_long(self):
+ with self.assertRaises(ValueError):
+ # Even if it is a bool flag, it should be a tuple
+ flagsaver.flagsaver((INT_FLAG, 4, 5))
+
+ def test_holder_parameters_tuple_wrong_type(self):
+ with self.assertRaises(ValueError):
+ # Even if it is a bool flag, it should be a tuple
+ flagsaver.flagsaver((4, INT_FLAG))
+
+ def test_both_wrong_positional_parameters(self):
+ with self.assertRaises(ValueError):
+ func_a = lambda: None
+ flagsaver.flagsaver(func_a, STR_FLAG, '45')
+
if __name__ == '__main__':
absltest.main()