aboutsummaryrefslogtreecommitdiff
path: root/catapult/devil/devil/utils/mock_calls.py
blob: 2b359385f689be8cb42635bcfde1120545629eca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# Copyright 2014 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""
A test facility to assert call sequences while mocking their behavior.
"""

import unittest

from devil import devil_env

with devil_env.SysPath(devil_env.PYMOCK_PATH):
  import mock  # pylint: disable=import-error


class TestCase(unittest.TestCase):
  """Adds assertCalls to TestCase objects.

  Expected calls are written using |self.call|, a mock.call instance rooted
  at the name "self"; e.g. self.call.adb.Shell('ls') stands for a call to
  self.adb.Shell('ls') on the test case object.
  """

  class _AssertCalls(object):
    """Context manager which patches watched targets and checks their calls.

    While active, every watched target is replaced by a mock whose side effect
    compares the received call against the next expected call, and then
    performs the action paired with that expected call.
    """

    def __init__(self, test_case, expected_calls, watched):
      """Prepare (but do not yet apply) the patches for all watched targets.

      Args:
        test_case: the TestCase used to report failures and to build patches.
        expected_calls: a sequence whose items are either a pair
            (expected_call, action) or a bare expected_call.
        watched: a dict mapping target names to the call instances that
            identify additional targets to watch; it is not modified.
      """
      def call_action(pair):
        # A bare expected call carries no action; pair it up with None.
        if isinstance(pair, type(mock.call)):
          return (pair, None)
        return pair

      def do_check(call):
        # Build the side effect for the mock that replaces the target of |call|.
        def side_effect(*args, **kwargs):
          received_call = call(*args, **kwargs)
          self._test_case.assertTrue(
              self._expected_calls,
              msg=('Unexpected call: %s' % str(received_call)))
          expected_call, action = self._expected_calls.pop(0)
          self._test_case.assertTrue(
              received_call == expected_call,
              msg=('Expected call mismatch:\n'
                   '  expected: %s\n'
                   '  received: %s\n' % (str(expected_call),
                                         str(received_call))))
          # A callable action computes the result; anything else is the result.
          if callable(action):
            return action(*args, **kwargs)
          return action

        return side_effect

      self._test_case = test_case
      self._expected_calls = [call_action(pair) for pair in expected_calls]
      watched = watched.copy()  # do not pollute the caller's dict
      # The parent of an expected call is the (uncalled) target to be patched.
      watched.update(
          (call.parent.name, call.parent) for call, _ in self._expected_calls)
      # values() rather than itervalues(): the latter only exists on Python 2.
      self._patched = [
          test_case.patch_call(call, side_effect=do_check(call))
          for call in watched.values()
      ]

    def __enter__(self):
      """Apply all patches; undo the ones applied so far if any of them fails."""
      started = []
      try:
        for patch in self._patched:
          patch.__enter__()
          started.append(patch)
      except Exception:
        # Do not leave targets patched when the context could not be entered.
        for patch in reversed(started):
          patch.__exit__(None, None, None)
        raise
      return self

    def __exit__(self, exc_type, exc_val, exc_tb):
      """Remove all patches and, if no error is in flight, check for leftovers.

      Raises:
        AssertionError if the block finished without raising but some of the
            expected calls were never received.
      """
      # Undo patches in the reverse of the order in which they were applied.
      for patch in reversed(self._patched):
        patch.__exit__(exc_type, exc_val, exc_tb)
      if exc_type is None:
        missing = ''.join(
            '  expected: %s\n' % str(call) for call, _ in self._expected_calls)
        self._test_case.assertFalse(
            missing, msg='Expected calls not found:\n' + missing)

  def __init__(self, *args, **kwargs):
    super(TestCase, self).__init__(*args, **kwargs)
    # Root used to spell out expected calls, e.g. self.call.adb.Shell('ls').
    self.call = mock.call.self
    # Maps target names to the call instances identifying watched targets.
    self._watched = {}

  def call_target(self, call):
    """Resolve a self.call instance to the target it represents.

    Args:
      call: a self.call instance, e.g. self.call.adb.Shell

    Returns:
      The target object represented by the call, e.g. self.adb.Shell

    Raises:
      ValueError if the path of the call does not start with "self", i.e. the
          target of the call is external to the self object.
      AttributeError if the path of the call does not specify a valid
          chain of attributes (without any calls) starting from "self".
    """
    path = call.name.split('.')
    if path.pop(0) != 'self':
      raise ValueError("Target %r outside of 'self' object" % call.name)
    target = self
    for attr in path:
      target = getattr(target, attr)
    return target

  def patch_call(self, call, **kwargs):
    """Patch the target of a mock.call instance.

    Targets whose name starts with "self." are patched as attributes of the
    object they hang from; a target which is a property of that object's class
    is patched on the class using a mock.PropertyMock. Any other target is
    patched by name using mock.patch.

    Args:
      call: a mock.call instance identifying a target to patch
      Extra keyword arguments are processed by mock.patch

    Returns:
      A context manager to mock/unmock the target of the call
    """
    if call.name.startswith('self.'):
      target = self.call_target(call.parent)
      _, attribute = call.name.rsplit('.', 1)
      if (hasattr(type(target), attribute)
          and isinstance(getattr(type(target), attribute), property)):
        return mock.patch.object(
            type(target), attribute, new_callable=mock.PropertyMock, **kwargs)
      else:
        return mock.patch.object(target, attribute, **kwargs)
    else:
      return mock.patch(call.name, **kwargs)

  def watchCalls(self, calls):
    """Add calls to the set of watched calls.

    Args:
      calls: a sequence of mock.call instances identifying targets to watch
    """
    self._watched.update((call.name, call) for call in calls)

  def watchMethodCalls(self, call, ignore=None):
    """Watch all public methods of the target identified by a self.call.

    Every name listed by dir() on the class of the target which does not start
    with an underscore, and is not in |ignore|, is added to the watched calls.

    Args:
      call: a self.call instance identifying an object
      ignore: a list of public methods to ignore when watching for calls
    """
    target = self.call_target(call)
    if ignore is None:
      ignore = []
    self.watchCalls(
        getattr(call, method) for method in dir(target.__class__)
        if not method.startswith('_') and method not in ignore)

  def clearWatched(self):
    """Clear the set of watched calls."""
    self._watched = {}

  def assertCalls(self, *calls):
    """A context manager to assert that a sequence of calls is made.

    During the assertion, a number of functions and methods will be "watched",
    and any calls made to them are expected to appear---in the exact same
    order, and with the exact same arguments---as specified by the argument
    |calls|.

    By default, the targets of all expected calls are watched. Further targets
    to watch may be added using watchCalls and watchMethodCalls.

    Optionally, each call may be accompanied by an action. If the action is a
    (non-callable) value, this value will be used as the return value given to
    the caller when the matching call is found. Alternatively, if the action is
    a callable, the action will be then called with the same arguments as the
    intercepted call, so that it can provide a return value or perform other
    side effects. If the action is missing, a return value of None is assumed.

    Note that mock.Mock objects are often convenient to use as a callable
    action, e.g. to raise exceptions or return other objects which are
    themselves callable.

    Args:
      calls: each argument is either a pair (expected_call, action) or just an
          expected_call, where expected_call is a mock.call instance.

    Raises:
      AssertionError if the watched targets do not receive the exact sequence
          of calls specified. Missing calls, extra calls, and calls with
          mismatching arguments, all cause the assertion to fail.
    """
    return self._AssertCalls(self, calls, self._watched)

  def assertCall(self, call, action=None):
    """A context manager to assert that a single call is made.

    Shorthand for assertCalls((call, action)).

    Args:
      call: the expected call, a mock.call instance.
      action: optional action for the call, interpreted as in assertCalls.
    """
    return self.assertCalls((call, action))