aboutsummaryrefslogtreecommitdiff
path: root/catapult/telemetry/telemetry/testing/simple_mock.py
blob: dbd02b680f47aeb746cff99d0b1476d0eaf61880 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""A very very simple mock object harness."""

# Wildcard marker for expected arguments: MockFunctionCall.VerifyEquals skips
# any expected argument that compares equal to this value, and __repr__ prints
# such arguments as '_'. Because the marker is the empty string and is compared
# with ==, an expected argument of '' is treated as a wildcard as well.
DONT_CARE = ''

class MockFunctionCall(object):
  """One function call: a name, positional arguments and a return value.

  Used both for calls that are expected and for calls that were observed.
  WithArgs, WillReturn and WhenCalled all return the call itself so that an
  expectation can be configured in a single chained expression.
  """

  def __init__(self, name):
    """Creates a call to |name| with no arguments and a None return value."""
    self.name = name
    self.args = tuple()
    self.return_value = None
    # Callbacks registered through WhenCalled(), in registration order.
    self.when_called_handlers = []

  def WithArgs(self, *args):
    """Sets the positional arguments of this call. Returns self."""
    self.args = args
    return self

  def WillReturn(self, value):
    """Sets the value associated with this call. Returns self."""
    self.return_value = value
    return self

  def WhenCalled(self, handler):
    """Registers |handler| in when_called_handlers. Returns self."""
    self.when_called_handlers.append(handler)
    # Return self like the other setters so that chaining keeps working.
    return self

  def VerifyEquals(self, got):
    """Checks that |got| matches this (expected) call.

    The names and the number of arguments must be equal. Arguments are
    compared pairwise; an expected argument equal to DONT_CARE matches any
    value.

    Args:
      got: the MockFunctionCall to compare against this one.

    Raises:
      Exception: if the name, the argument count or any argument differs.
    """
    if self.name != got.name:
      raise Exception('Self %s, got %s' % (repr(self), repr(got)))
    if len(self.args) != len(got.args):
      raise Exception('Self %s, got %s' % (repr(self), repr(got)))
    # Lengths are known to be equal here, so zip() drops nothing.
    for expected_arg, got_arg in zip(self.args, got.args):
      if expected_arg == DONT_CARE:
        continue
      if expected_arg != got_arg:
        raise Exception('Self %s, got %s' % (repr(self), repr(got)))

  def __repr__(self):
    """Returns 'name(args)' or 'name(args)->value'; wildcards print as '_'."""
    def arg_to_text(a):
      if a == DONT_CARE:
        return '_'
      return repr(a)
    args_text = ', '.join(arg_to_text(a) for a in self.args)
    # A None or DONT_CARE return value carries no information; leave it out.
    if self.return_value in (None, DONT_CARE):
      return '%s(%s)' % (self.name, args_text)
    return '%s(%s)->%s' % (self.name, args_text, repr(self.return_value))

class MockTrace(object):
  """Holds a list of expected calls and the index of the next one to match."""

  def __init__(self):
    # Index into expected_calls; starts at the first entry.
    self.next_call_index = 0
    # Filled in order as expectations are added.
    self.expected_calls = list()

class MockObject(object):
  """A mock whose method calls are checked against a programmed trace.

  Expectations are added with ExpectCall(). Each expected function name gets
  a hook installed as an attribute; calling the hook checks the call against
  the next entry of the trace. A mock built with a parent shares the parent's
  trace, so calls on both are checked against a single ordered list.
  """

  def __init__(self, parent_mock=None):
    """Shares |parent_mock|'s trace if it is truthy, else starts a new one."""
    # pylint: disable=protected-access
    self._trace = parent_mock._trace if parent_mock else MockTrace()

  def __setattr__(self, name, value):
    """Stores |value|; once _trace exists only hooks or MockObjects pass."""
    initialized = hasattr(self, '_trace')
    if initialized and not hasattr(value, 'is_hook'):
      assert isinstance(value, MockObject)
    object.__setattr__(self, name, value)

  def SetAttribute(self, name, value):
    """Sets attribute |name| to |value|, subject to the __setattr__ checks."""
    setattr(self, name, value)

  def ExpectCall(self, func_name, *args):
    """Appends an expected call to |func_name| with |args| to the trace.

    Asserts that no call has been consumed from the trace yet. Installs a hook
    for |func_name| when the mock has no attribute of that name.

    Returns:
      The new MockFunctionCall, so it can be configured further.
    """
    assert self._trace.next_call_index == 0
    already_present = hasattr(self, func_name)
    if not already_present:
      self._install_hook(func_name)

    expectation = MockFunctionCall(func_name)
    self._trace.expected_calls.append(expectation)
    return expectation.WithArgs(*args)

  def _install_hook(self, func_name):
    """Installs a callable attribute |func_name| that verifies each call."""
    def handler(*args, **_):
      # Keyword arguments are accepted but take no part in the check.
      trace = self._trace
      got_call = MockFunctionCall(func_name)
      got_call.WithArgs(*args)
      got_call.WillReturn(DONT_CARE)
      if not trace.next_call_index < len(trace.expected_calls):
        raise Exception(
          'Call to %s was not expected, at end of programmed trace.' %
          repr(got_call))
      expected_call = trace.expected_calls[trace.next_call_index]
      expected_call.VerifyEquals(got_call)
      # Advance only after the call has been verified.
      trace.next_call_index += 1
      for callback in expected_call.when_called_handlers:
        callback(*args)
      return expected_call.return_value
    # Marker that lets __setattr__ accept the hook on an initialized mock.
    handler.is_hook = True
    setattr(self, func_name, handler)