aboutsummaryrefslogtreecommitdiff
path: root/catapult/telemetry/telemetry/testing/simple_mock.py
diff options
context:
space:
mode:
Diffstat (limited to 'catapult/telemetry/telemetry/testing/simple_mock.py')
-rw-r--r--catapult/telemetry/telemetry/testing/simple_mock.py98
1 files changed, 98 insertions, 0 deletions
diff --git a/catapult/telemetry/telemetry/testing/simple_mock.py b/catapult/telemetry/telemetry/testing/simple_mock.py
new file mode 100644
index 00000000..dbd02b68
--- /dev/null
+++ b/catapult/telemetry/telemetry/testing/simple_mock.py
@@ -0,0 +1,98 @@
+# Copyright 2012 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+"""A very very simple mock object harness."""
+
+DONT_CARE = ''
+
+class MockFunctionCall(object):
+ def __init__(self, name):
+ self.name = name
+ self.args = tuple()
+ self.return_value = None
+ self.when_called_handlers = []
+
+ def WithArgs(self, *args):
+ self.args = args
+ return self
+
+ def WillReturn(self, value):
+ self.return_value = value
+ return self
+
+ def WhenCalled(self, handler):
+ self.when_called_handlers.append(handler)
+
+ def VerifyEquals(self, got):
+ if self.name != got.name:
+ raise Exception('Self %s, got %s' % (repr(self), repr(got)))
+ if len(self.args) != len(got.args):
+ raise Exception('Self %s, got %s' % (repr(self), repr(got)))
+ for i in range(len(self.args)):
+ self_a = self.args[i]
+ got_a = got.args[i]
+ if self_a == DONT_CARE:
+ continue
+ if self_a != got_a:
+ raise Exception('Self %s, got %s' % (repr(self), repr(got)))
+
+ def __repr__(self):
+ def arg_to_text(a):
+ if a == DONT_CARE:
+ return '_'
+ return repr(a)
+ args_text = ', '.join([arg_to_text(a) for a in self.args])
+ if self.return_value in (None, DONT_CARE):
+ return '%s(%s)' % (self.name, args_text)
+ return '%s(%s)->%s' % (self.name, args_text, repr(self.return_value))
+
+class MockTrace(object):
+ def __init__(self):
+ self.expected_calls = []
+ self.next_call_index = 0
+
+class MockObject(object):
+ def __init__(self, parent_mock=None):
+ if parent_mock:
+ self._trace = parent_mock._trace # pylint: disable=protected-access
+ else:
+ self._trace = MockTrace()
+
+ def __setattr__(self, name, value):
+ if (not hasattr(self, '_trace') or
+ hasattr(value, 'is_hook')):
+ object.__setattr__(self, name, value)
+ return
+ assert isinstance(value, MockObject)
+ object.__setattr__(self, name, value)
+
+ def SetAttribute(self, name, value):
+ setattr(self, name, value)
+
+ def ExpectCall(self, func_name, *args):
+ assert self._trace.next_call_index == 0
+ if not hasattr(self, func_name):
+ self._install_hook(func_name)
+
+ call = MockFunctionCall(func_name)
+ self._trace.expected_calls.append(call)
+ call.WithArgs(*args)
+ return call
+
+ def _install_hook(self, func_name):
+ def handler(*args, **_):
+ got_call = MockFunctionCall(
+ func_name).WithArgs(*args).WillReturn(DONT_CARE)
+ if self._trace.next_call_index >= len(self._trace.expected_calls):
+ raise Exception(
+ 'Call to %s was not expected, at end of programmed trace.' %
+ repr(got_call))
+ expected_call = self._trace.expected_calls[
+ self._trace.next_call_index]
+ expected_call.VerifyEquals(got_call)
+ self._trace.next_call_index += 1
+ for h in expected_call.when_called_handlers:
+ h(*args)
+ return expected_call.return_value
+ handler.is_hook = True
+ setattr(self, func_name, handler)