diff options
Diffstat (limited to 'catapult/telemetry/telemetry/testing/simple_mock.py')
-rw-r--r-- | catapult/telemetry/telemetry/testing/simple_mock.py | 98 |
1 files changed, 98 insertions, 0 deletions
diff --git a/catapult/telemetry/telemetry/testing/simple_mock.py b/catapult/telemetry/telemetry/testing/simple_mock.py new file mode 100644 index 00000000..dbd02b68 --- /dev/null +++ b/catapult/telemetry/telemetry/testing/simple_mock.py @@ -0,0 +1,98 @@ +# Copyright 2012 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. +"""A very very simple mock object harness.""" + +DONT_CARE = '' + +class MockFunctionCall(object): + def __init__(self, name): + self.name = name + self.args = tuple() + self.return_value = None + self.when_called_handlers = [] + + def WithArgs(self, *args): + self.args = args + return self + + def WillReturn(self, value): + self.return_value = value + return self + + def WhenCalled(self, handler): + self.when_called_handlers.append(handler) + + def VerifyEquals(self, got): + if self.name != got.name: + raise Exception('Self %s, got %s' % (repr(self), repr(got))) + if len(self.args) != len(got.args): + raise Exception('Self %s, got %s' % (repr(self), repr(got))) + for i in range(len(self.args)): + self_a = self.args[i] + got_a = got.args[i] + if self_a == DONT_CARE: + continue + if self_a != got_a: + raise Exception('Self %s, got %s' % (repr(self), repr(got))) + + def __repr__(self): + def arg_to_text(a): + if a == DONT_CARE: + return '_' + return repr(a) + args_text = ', '.join([arg_to_text(a) for a in self.args]) + if self.return_value in (None, DONT_CARE): + return '%s(%s)' % (self.name, args_text) + return '%s(%s)->%s' % (self.name, args_text, repr(self.return_value)) + +class MockTrace(object): + def __init__(self): + self.expected_calls = [] + self.next_call_index = 0 + +class MockObject(object): + def __init__(self, parent_mock=None): + if parent_mock: + self._trace = parent_mock._trace # pylint: disable=protected-access + else: + self._trace = MockTrace() + + def __setattr__(self, name, value): + if (not hasattr(self, '_trace') or + hasattr(value, 'is_hook')): + object.__setattr__(self, name, value) + return + assert isinstance(value, MockObject) + object.__setattr__(self, name, value) + + def SetAttribute(self, name, value): + setattr(self, name, value) + + def ExpectCall(self, func_name, *args): + assert self._trace.next_call_index == 0 + if not hasattr(self, func_name): + self._install_hook(func_name) + + call = MockFunctionCall(func_name) + self._trace.expected_calls.append(call) + call.WithArgs(*args) + return call + + def _install_hook(self, func_name): + def handler(*args, **_): + got_call = MockFunctionCall( + func_name).WithArgs(*args).WillReturn(DONT_CARE) + if self._trace.next_call_index >= len(self._trace.expected_calls): + raise Exception( + 'Call to %s was not expected, at end of programmed trace.' % + repr(got_call)) + expected_call = self._trace.expected_calls[ + self._trace.next_call_index] + expected_call.VerifyEquals(got_call) + self._trace.next_call_index += 1 + for h in expected_call.when_called_handlers: + h(*args) + return expected_call.return_value + handler.is_hook = True + setattr(self, func_name, handler) |