summaryrefslogtreecommitdiff
path: root/mock/mock.py
diff options
context:
space:
mode:
Diffstat (limited to 'mock/mock.py')
-rw-r--r--mock/mock.py1030
1 files changed, 676 insertions, 354 deletions
diff --git a/mock/mock.py b/mock/mock.py
index 2d39253..4766672 100644
--- a/mock/mock.py
+++ b/mock/mock.py
@@ -1,41 +1,10 @@
# mock.py
# Test tools for mocking and patching.
-# E-mail: fuzzyman AT voidspace DOT org DOT uk
-#
-# http://www.voidspace.org.uk/python/mock/
-#
-# Copyright (c) 2007-2013, Michael Foord & the mock team
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following
-# disclaimer in the documentation and/or other materials provided
-# with the distribution.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-from __future__ import absolute_import
+# Maintained by Michael Foord
+# Backport for other versions of Python available from
+# https://pypi.org/project/mock
__all__ = (
- '__version__',
- 'version_info',
'Mock',
'MagicMock',
'patch',
@@ -44,8 +13,8 @@ __all__ = (
'ANY',
'call',
'create_autospec',
+ 'AsyncMock',
'FILTER_DIR',
- 'CallableMixin',
'NonCallableMock',
'NonCallableMagicMock',
'mock_open',
@@ -54,79 +23,44 @@ __all__ = (
)
-from functools import partial
+import asyncio
+import contextlib
import io
import inspect
import pprint
import sys
-try:
- import builtins
-except ImportError:
- import __builtin__ as builtins
-from types import ModuleType, MethodType
+import builtins
+from asyncio import iscoroutinefunction
+from types import CodeType, ModuleType, MethodType
from unittest.util import safe_repr
+from functools import wraps, partial
-import six
-from six import wraps
-
-__version__ = '3.0.5'
-version_info = tuple(int(p) for p in __version__.split('.'))
-
-import mock
-
-try:
- inspectsignature = inspect.signature
-except AttributeError:
- import funcsigs
- inspectsignature = funcsigs.signature
-
-
-# TODO: use six.
-try:
- unicode
-except NameError:
- # Python 3
- basestring = unicode = str
-
-try:
- long
-except NameError:
- # Python 3
- long = int
-
-if six.PY2:
- # Python 2's next() can't handle a non-iterator with a __next__ method.
- _next = next
- def next(obj, _next=_next):
- if getattr(obj, '__next__', None):
- return obj.__next__()
- return _next(obj)
-
- del _next
-
+from mock import IS_PYPY
+from .backports import iscoroutinefunction
_builtins = {name for name in dir(builtins) if not name.startswith('_')}
-try:
- _isidentifier = str.isidentifier
-except AttributeError:
- # Python 2.X
- import keyword
- import re
- regex = re.compile(r'^[a-z_][a-z0-9_]*$', re.I)
- def _isidentifier(string):
- if string in keyword.kwlist:
- return False
- return regex.match(string)
-
-
-# NOTE: This FILTER_DIR is not used. The binding in mock.FILTER_DIR is.
FILTER_DIR = True
-# Workaround for Python issue #12370
+# Workaround for issue #12370
# Without this, the __class__ properties wouldn't be set correctly
_safe_super = super
+def _is_async_obj(obj):
+ if _is_instance_mock(obj) and not isinstance(obj, AsyncMock):
+ return False
+ if hasattr(obj, '__func__'):
+ obj = getattr(obj, '__func__')
+ return iscoroutinefunction(obj) or inspect.isawaitable(obj)
+
+
+def _is_async_func(func):
+ if getattr(func, '__code__', None):
+ return iscoroutinefunction(func)
+ else:
+ return False
+
+
def _is_instance_mock(obj):
# can't use isinstance on Mock objects because they override __class__
# The base class for all mocks is NonCallableMock
@@ -136,20 +70,17 @@ def _is_instance_mock(obj):
def _is_exception(obj):
return (
isinstance(obj, BaseException) or
- isinstance(obj, ClassTypes) and issubclass(obj, BaseException)
+ isinstance(obj, type) and issubclass(obj, BaseException)
)
-class _slotted(object):
- __slots__ = ['a']
-
-
-# Do not use this tuple. It was never documented as a public API.
-# It will be removed. It has no obvious signs of users on github.
-DescriptorTypes = (
- type(_slotted.a),
- property,
-)
+def _extract_mock(obj):
+ # Autospecced functions will return a FunctionType with "mock" attribute
+ # which is the actual mock object that needs to be used.
+ if isinstance(obj, FunctionTypes) and hasattr(obj, 'mock'):
+ return obj.mock
+ else:
+ return obj
def _get_signature_object(func, as_instance, eat_self):
@@ -158,12 +89,9 @@ def _get_signature_object(func, as_instance, eat_self):
signature object.
Return a (reduced func, signature) tuple, or None.
"""
- if isinstance(func, ClassTypes) and not as_instance:
+ if isinstance(func, type) and not as_instance:
# If it's a type and should be modelled as a type, use __init__.
- try:
- func = func.__init__
- except AttributeError:
- return None
+ func = func.__init__
# Skip the `self` argument in __init__
eat_self = True
elif not isinstance(func, FunctionTypes):
@@ -177,9 +105,8 @@ def _get_signature_object(func, as_instance, eat_self):
sig_func = partial(func, None)
else:
sig_func = func
-
try:
- return func, inspectsignature(sig_func)
+ return func, inspect.signature(sig_func)
except ValueError:
# Certain callable types are not supported by inspect.signature()
return None
@@ -208,15 +135,10 @@ def _copy_func_details(func, funcopy):
setattr(funcopy, attribute, getattr(func, attribute))
except AttributeError:
pass
- if six.PY2:
- try:
- funcopy.func_defaults = func.func_defaults
- except AttributeError:
- pass
def _callable(obj):
- if isinstance(obj, ClassTypes):
+ if isinstance(obj, type):
return True
if isinstance(obj, (staticmethod, classmethod, MethodType)):
return _callable(obj.__func__)
@@ -234,25 +156,15 @@ def _is_list(obj):
def _instance_callable(obj):
"""Given an object, return True if the object is callable.
For classes, return True if instances would be callable."""
- if not isinstance(obj, ClassTypes):
+ if not isinstance(obj, type):
# already an instance
return getattr(obj, '__call__', None) is not None
- if six.PY3:
- # *could* be broken by a class overriding __mro__ or __dict__ via
- # a metaclass
- for base in (obj,) + obj.__mro__:
- if base.__dict__.get('__call__') is not None:
- return True
- else:
- klass = obj
- # uses __bases__ instead of __mro__ so that we work with old style classes
- if klass.__dict__.get('__call__') is not None:
+ # *could* be broken by a class overriding __mro__ or __dict__ via
+ # a metaclass
+ for base in (obj,) + obj.__mro__:
+ if base.__dict__.get('__call__') is not None:
return True
-
- for base in klass.__bases__:
- if _instance_callable(base):
- return True
return False
@@ -261,7 +173,7 @@ def _set_signature(mock, original, instance=False):
# mock. It still does signature checking by calling a lambda with the same
# signature as the original.
- skipfirst = isinstance(original, ClassTypes)
+ skipfirst = isinstance(original, type)
result = _get_signature_object(original, instance, skipfirst)
if result is None:
return mock
@@ -271,13 +183,13 @@ def _set_signature(mock, original, instance=False):
_copy_func_details(func, checksig)
name = original.__name__
- if not _isidentifier(name):
+ if not name.isidentifier():
name = 'funcopy'
context = {'_checksig_': checksig, 'mock': mock}
src = """def %s(*args, **kwargs):
_checksig_(*args, **kwargs)
return mock(*args, **kwargs)""" % name
- six.exec_(src, context)
+ exec (src, context)
funcopy = context[name]
_setup_func(funcopy, mock, sig)
return funcopy
@@ -286,14 +198,14 @@ def _set_signature(mock, original, instance=False):
def _setup_func(funcopy, mock, sig):
funcopy.mock = mock
+ def assert_called_with(*args, **kwargs):
+ return mock.assert_called_with(*args, **kwargs)
def assert_called(*args, **kwargs):
return mock.assert_called(*args, **kwargs)
def assert_not_called(*args, **kwargs):
return mock.assert_not_called(*args, **kwargs)
def assert_called_once(*args, **kwargs):
return mock.assert_called_once(*args, **kwargs)
- def assert_called_with(*args, **kwargs):
- return mock.assert_called_with(*args, **kwargs)
def assert_called_once_with(*args, **kwargs):
return mock.assert_called_once_with(*args, **kwargs)
def assert_has_calls(*args, **kwargs):
@@ -332,6 +244,33 @@ def _setup_func(funcopy, mock, sig):
mock._mock_delegate = funcopy
+def _setup_async_mock(mock):
+ mock._is_coroutine = asyncio.coroutines._is_coroutine
+ mock.await_count = 0
+ mock.await_args = None
+ mock.await_args_list = _CallList()
+
+ # Mock is not configured yet so the attributes are set
+ # to a function and then the corresponding mock helper function
+ # is called when the helper is accessed similar to _setup_func.
+ def wrapper(attr, *args, **kwargs):
+ return getattr(mock.mock, attr)(*args, **kwargs)
+
+ for attribute in ('assert_awaited',
+ 'assert_awaited_once',
+ 'assert_awaited_with',
+ 'assert_awaited_once_with',
+ 'assert_any_await',
+ 'assert_has_awaits',
+ 'assert_not_awaited'):
+
+ # setattr(mock, attribute, wrapper) causes late binding
+ # hence attribute will always be the last value in the loop
+ # Use partial(wrapper, attribute) to ensure the attribute is bound
+ # correctly.
+ setattr(mock, attribute, partial(wrapper, attribute))
+
+
def _is_magic(name):
return '__%s__' % name[2:-2] == name
@@ -345,11 +284,7 @@ class _SentinelObject(object):
return 'sentinel.%s' % self.name
def __reduce__(self):
- return _unpickle_sentinel, (self.name, )
-
-
-def _unpickle_sentinel(name):
- return getattr(sentinel, name)
+ return 'sentinel.%s' % self.name
class _Sentinel(object):
@@ -363,6 +298,9 @@ class _Sentinel(object):
raise AttributeError
return self._sentinels.setdefault(name, _SentinelObject(name))
+ def __reduce__(self):
+ return 'sentinel'
+
sentinel = _Sentinel()
@@ -371,15 +309,6 @@ _missing = sentinel.MISSING
_deleted = sentinel.DELETED
-class OldStyleClass:
- pass
-ClassType = type(OldStyleClass)
-
-
-ClassTypes = (type,)
-if six.PY2:
- ClassTypes = (type, ClassType)
-
_allowed_names = {
'return_value', '_mock_return_value', 'side_effect',
'_mock_side_effect', '_mock_parent', '_mock_new_parent',
@@ -427,13 +356,7 @@ class _CallList(list):
def _check_and_set_parent(parent, value, name, new_name):
- # function passed to create_autospec will have mock
- # attribute attached to which parent must be set
- if isinstance(value, FunctionTypes):
- try:
- value = value.mock
- except AttributeError:
- pass
+ value = _extract_mock(value)
if not _is_instance_mock(value):
return False
@@ -480,8 +403,15 @@ class NonCallableMock(Base):
# every instance has its own class
# so we can create magic methods on the
# class without stomping on other mocks
- new = type(cls.__name__, (cls,), {'__doc__': cls.__doc__})
- instance = object.__new__(new)
+ bases = (cls,)
+ if not issubclass(cls, AsyncMockMixin):
+ # Check if spec is an async object or function
+ bound_args = _MOCK_SIG.bind_partial(cls, *args, **kw).arguments
+ spec_arg = bound_args.get('spec_set', bound_args.get('spec'))
+ if spec_arg and _is_async_obj(spec_arg):
+ bases = (AsyncMockMixin, cls)
+ new = type(cls.__name__, bases, {'__doc__': cls.__doc__})
+ instance = _safe_super(NonCallableMock, cls).__new__(new)
return instance
@@ -535,10 +465,12 @@ class NonCallableMock(Base):
Attach a mock as an attribute of this one, replacing its name and
parent. Calls to the attached mock will be recorded in the
`method_calls` and `mock_calls` attributes of this one."""
- mock._mock_parent = None
- mock._mock_new_parent = None
- mock._mock_name = ''
- mock._mock_new_name = None
+ inner_mock = _extract_mock(mock)
+
+ inner_mock._mock_parent = None
+ inner_mock._mock_new_parent = None
+ inner_mock._mock_name = ''
+ inner_mock._mock_new_name = None
setattr(self, attribute, mock)
@@ -556,9 +488,14 @@ class NonCallableMock(Base):
_eat_self=False):
_spec_class = None
_spec_signature = None
+ _spec_asyncs = []
+
+ for attr in dir(spec):
+ if iscoroutinefunction(getattr(spec, attr, None)):
+ _spec_asyncs.append(attr)
if spec is not None and not _is_list(spec):
- if isinstance(spec, ClassTypes):
+ if isinstance(spec, type):
_spec_class = spec
else:
_spec_class = type(spec)
@@ -573,7 +510,7 @@ class NonCallableMock(Base):
__dict__['_spec_set'] = spec_set
__dict__['_spec_signature'] = _spec_signature
__dict__['_mock_methods'] = spec
-
+ __dict__['_spec_asyncs'] = _spec_asyncs
def __get_return_value(self):
ret = self._mock_return_value
@@ -635,7 +572,7 @@ class NonCallableMock(Base):
side_effect = property(__get_side_effect, __set_side_effect)
- def reset_mock(self, visited=None, return_value=False, side_effect=False):
+ def reset_mock(self, visited=None,*, return_value=False, side_effect=False):
"Restore the mock object to its initial state."
if visited is None:
visited = []
@@ -658,7 +595,7 @@ class NonCallableMock(Base):
for child in self._mock_children.values():
if isinstance(child, _SpecState) or child is _deleted:
continue
- child.reset_mock(visited)
+ child.reset_mock(visited, return_value=return_value, side_effect=side_effect)
ret = self._mock_return_value
if _is_instance_mock(ret) and ret is not self:
@@ -688,7 +625,7 @@ class NonCallableMock(Base):
def __getattr__(self, name):
- if name in ('_mock_methods', '_mock_unsafe'):
+ if name in {'_mock_methods', '_mock_unsafe'}:
raise AttributeError(name)
elif self._mock_methods is not None:
if name not in self._mock_methods or name in _all_magics:
@@ -697,7 +634,8 @@ class NonCallableMock(Base):
raise AttributeError(name)
if not self._mock_unsafe:
if name.startswith(('assert', 'assret')):
- raise AttributeError(name)
+ raise AttributeError("Attributes cannot start with 'assert' "
+ "or 'assret'")
result = self._mock_children.get(name)
if result is _deleted:
@@ -765,7 +703,7 @@ class NonCallableMock(Base):
if self._spec_set:
spec_string = ' spec_set=%r'
spec_string = spec_string % self._spec_class.__name__
- return "<{}{}{} id='{}'>".format(
+ return "<%s%s%s id='%s'>" % (
type(self).__name__,
name_string,
spec_string,
@@ -775,8 +713,7 @@ class NonCallableMock(Base):
def __dir__(self):
"""Filter the output of `dir(mock)` to only useful members."""
- if not mock.FILTER_DIR and getattr(object, '__dir__', None):
- # object.__dir__ is not in 2.7
+ if not FILTER_DIR:
return object.__dir__(self)
extras = self._mock_methods or []
@@ -786,12 +723,9 @@ class NonCallableMock(Base):
m_name for m_name, m_value in self._mock_children.items()
if m_value is not _deleted]
- if mock.FILTER_DIR:
- # object.__dir__ is not in 2.7
- from_type = [e for e in from_type if not e.startswith('_')]
- from_dict = [e for e in from_dict if not e.startswith('_') or
- _is_magic(e)]
-
+ from_type = [e for e in from_type if not e.startswith('_')]
+ from_dict = [e for e in from_dict if not e.startswith('_') or
+ _is_magic(e)]
return sorted(set(extras + from_type + from_dict + from_child_mocks))
@@ -828,8 +762,8 @@ class NonCallableMock(Base):
self._mock_children[name] = value
if self._mock_sealed and not hasattr(self, name):
- mock_name = self._extract_mock_name()+'.'+name
- raise AttributeError('Cannot set '+mock_name)
+ mock_name = f'{self._extract_mock_name()}.{name}'
+ raise AttributeError(f'Cannot set {mock_name}')
return object.__setattr__(self, name, value)
@@ -857,12 +791,45 @@ class NonCallableMock(Base):
return _format_call_signature(name, args, kwargs)
- def _format_mock_failure_message(self, args, kwargs):
- message = 'expected call not found.\nExpected: %s\nActual: %s'
+ def _format_mock_failure_message(self, args, kwargs, action='call'):
+ message = 'expected %s not found.\nExpected: %s\nActual: %s'
expected_string = self._format_mock_call_signature(args, kwargs)
call_args = self.call_args
actual_string = self._format_mock_call_signature(*call_args)
- return message % (expected_string, actual_string)
+ return message % (action, expected_string, actual_string)
+
+
+ def _get_call_signature_from_name(self, name):
+ """
+ * If call objects are asserted against a method/function like obj.meth1
+ then there could be no name for the call object to lookup. Hence just
+ return the spec_signature of the method/function being asserted against.
+ * If the name is not empty then remove () and split by '.' to get
+ list of names to iterate through the children until a potential
+ match is found. A child mock is created only during attribute access
+ so if we get a _SpecState then no attributes of the spec were accessed
+ and can be safely exited.
+ """
+ if not name:
+ return self._spec_signature
+
+ sig = None
+ names = name.replace('()', '').split('.')
+ children = self._mock_children
+
+ for name in names:
+ child = children.get(name)
+ if child is None or isinstance(child, _SpecState):
+ break
+ else:
+ # If an autospecced object is attached using attach_mock the
+ # child would be a function with mock object as attribute from
+ # which signature has to be derived.
+ child = _extract_mock(child)
+ children = child._mock_children
+ sig = child._spec_signature
+
+ return sig
def _call_matcher(self, _call):
@@ -872,7 +839,12 @@ class NonCallableMock(Base):
This is a best effort method which relies on the spec's signature,
if available, or falls back on the arguments themselves.
"""
- sig = self._spec_signature
+
+ if isinstance(_call, tuple) and len(_call) > 2:
+ sig = self._get_call_signature_from_name(_call[0])
+ else:
+ sig = self._spec_signature
+
if sig is not None:
if len(_call) == 2:
name = ''
@@ -880,10 +852,10 @@ class NonCallableMock(Base):
else:
name, args, kwargs = _call
try:
- return name, sig.bind(*args, **kwargs)
+ bound_call = sig.bind(*args, **kwargs)
+ return call(name, bound_call.args, bound_call.kwargs)
except TypeError as e:
- e.__traceback__ = None
- return e
+ return e.with_traceback(None)
else:
return _call
@@ -904,7 +876,7 @@ class NonCallableMock(Base):
self = _mock_self
if self.call_count == 0:
msg = ("Expected '%s' to have been called." %
- self._mock_name or 'mock')
+ (self._mock_name or 'mock'))
raise AssertionError(msg)
def assert_called_once(_mock_self):
@@ -919,7 +891,7 @@ class NonCallableMock(Base):
raise AssertionError(msg)
def assert_called_with(_mock_self, *args, **kwargs):
- """assert that the mock was called with the specified arguments.
+ """assert that the last call was made with the specified arguments.
Raises an AssertionError if the args and keyword args passed in are
different to the last call to the mock."""
@@ -928,20 +900,17 @@ class NonCallableMock(Base):
expected = self._format_mock_call_signature(args, kwargs)
actual = 'not called.'
error_message = ('expected call not found.\nExpected: %s\nActual: %s'
- % (expected, actual))
+ % (expected, actual))
raise AssertionError(error_message)
- def _error_message(cause):
+ def _error_message():
msg = self._format_mock_failure_message(args, kwargs)
- if six.PY2 and cause is not None:
- # Tack on some diagnostics for Python without __cause__
- msg = '{}\n{}'.format(msg, str(cause))
return msg
- expected = self._call_matcher((args, kwargs))
+ expected = self._call_matcher(_Call((args, kwargs), two=True))
actual = self._call_matcher(self.call_args)
- if expected != actual:
+ if actual != expected:
cause = expected if isinstance(expected, Exception) else None
- six.raise_from(AssertionError(_error_message(cause)), cause)
+ raise AssertionError(_error_message()) from cause
def assert_called_once_with(_mock_self, *args, **kwargs):
@@ -968,14 +937,22 @@ class NonCallableMock(Base):
If `any_order` is True then the calls can be in any order, but
they must all appear in `mock_calls`."""
expected = [self._call_matcher(c) for c in calls]
- cause = expected if isinstance(expected, Exception) else None
+ cause = next((e for e in expected if isinstance(e, Exception)), None)
all_calls = _CallList(self._call_matcher(c) for c in self.mock_calls)
if not any_order:
if expected not in all_calls:
- six.raise_from(AssertionError(
- 'Calls not found.\nExpected: %r%s'
- % (_CallList(calls), self._calls_repr(prefix="Actual"))
- ), cause)
+ if cause is None:
+ problem = 'Calls not found.'
+ else:
+ problem = ('Error processing expected calls.\n'
+ 'Errors: {}').format(
+ [e if isinstance(e, Exception) else None
+ for e in expected])
+ raise AssertionError(
+ f'{problem}\n'
+ f'Expected: {_CallList(calls)}'
+ f'{self._calls_repr(prefix="Actual").rstrip(".")}'
+ ) from cause
return
all_calls = list(all_calls)
@@ -987,11 +964,11 @@ class NonCallableMock(Base):
except ValueError:
not_found.append(kall)
if not_found:
- six.raise_from(AssertionError(
+ raise AssertionError(
'%r does not contain all of %r in its call list, '
'found %r instead' % (self._mock_name or 'mock',
tuple(not_found), all_calls)
- ), cause)
+ ) from cause
def assert_any_call(self, *args, **kwargs):
@@ -1000,14 +977,14 @@ class NonCallableMock(Base):
The assert passes if the mock has *ever* been called, unlike
`assert_called_with` and `assert_called_once_with` that only pass if
the call is the most recent one."""
- expected = self._call_matcher((args, kwargs))
+ expected = self._call_matcher(_Call((args, kwargs), two=True))
+ cause = expected if isinstance(expected, Exception) else None
actual = [self._call_matcher(c) for c in self.call_args_list]
- if expected not in actual:
- cause = expected if isinstance(expected, Exception) else None
+ if cause or expected not in _AnyComparer(actual):
expected_string = self._format_mock_call_signature(args, kwargs)
- six.raise_from(AssertionError(
+ raise AssertionError(
'%s call not found' % expected_string
- ), cause)
+ ) from cause
def _get_child_mock(self, **kw):
@@ -1018,11 +995,25 @@ class NonCallableMock(Base):
For non-callable mocks the callable variant will be used (rather than
any custom subclass)."""
+ _new_name = kw.get("_new_name")
+ if _new_name in self.__dict__['_spec_asyncs']:
+ return AsyncMock(**kw)
+
_type = type(self)
- if not issubclass(_type, CallableMixin):
+ if issubclass(_type, MagicMock) and _new_name in _async_method_magics:
+ # Any asynchronous magic becomes an AsyncMock
+ klass = AsyncMock
+ elif issubclass(_type, AsyncMockMixin):
+ if (_new_name in _all_sync_magics or
+ self._mock_methods and _new_name in self._mock_methods):
+ # Any synchronous method on AsyncMock becomes a MagicMock
+ klass = MagicMock
+ else:
+ klass = AsyncMock
+ elif not issubclass(_type, CallableMixin):
if issubclass(_type, NonCallableMagicMock):
klass = MagicMock
- elif issubclass(_type, NonCallableMock) :
+ elif issubclass(_type, NonCallableMock):
klass = Mock
else:
klass = _type.__mro__[1]
@@ -1045,9 +1036,27 @@ class NonCallableMock(Base):
"""
if not self.mock_calls:
return ""
- return "\n"+prefix+": "+safe_repr(self.mock_calls)+"."
+ return f"\n{prefix}: {safe_repr(self.mock_calls)}."
+_MOCK_SIG = inspect.signature(NonCallableMock.__init__)
+
+
+class _AnyComparer(list):
+ """A list which checks if it contains a call which may have an
+ argument of ANY, flipping the components of item and self from
+ their traditional locations so that ANY is guaranteed to be on
+ the left."""
+ def __contains__(self, item):
+ for _call in self:
+ assert len(item) == len(_call)
+ if all([
+ expected == actual
+ for expected, actual in zip(item, _call)
+ ]):
+ return True
+ return False
+
def _try_iter(obj):
if obj is None:
@@ -1064,14 +1073,12 @@ def _try_iter(obj):
return obj
-
class CallableMixin(Base):
def __init__(self, spec=None, side_effect=None, return_value=DEFAULT,
wraps=None, name=None, spec_set=None, parent=None,
_spec_state=None, _new_name='', _new_parent=None, **kwargs):
self.__dict__['_mock_return_value'] = return_value
-
_safe_super(CallableMixin, self).__init__(
spec, wraps, name, spec_set, parent,
_spec_state, _new_name, _new_parent, **kwargs
@@ -1089,15 +1096,21 @@ class CallableMixin(Base):
# can't use self in-case a function / method we are mocking uses self
# in the signature
_mock_self._mock_check_sig(*args, **kwargs)
+ _mock_self._increment_mock_call(*args, **kwargs)
return _mock_self._mock_call(*args, **kwargs)
def _mock_call(_mock_self, *args, **kwargs):
+ return _mock_self._execute_mock_call(*args, **kwargs)
+
+ def _increment_mock_call(_mock_self, *args, **kwargs):
self = _mock_self
self.called = True
self.call_count += 1
# handle call_args
+ # needs to be set here so assertions on call arguments pass before
+ # execution in the case of awaited calls
_call = _Call((args, kwargs), two=True)
self.call_args = _call
self.call_args_list.append(_call)
@@ -1137,6 +1150,11 @@ class CallableMixin(Base):
# follow the parental chain:
_new_parent = _new_parent._mock_new_parent
+ def _execute_mock_call(_mock_self, *args, **kwargs):
+ self = _mock_self
+ # separate from _increment_mock_call so that awaited functions are
+ # executed separately from their call, also AsyncMock overrides this method
+
effect = self.side_effect
if effect is not None:
if _is_exception(effect):
@@ -1186,9 +1204,6 @@ class Mock(CallableMixin, NonCallableMock):
arguments as the mock, and unless it returns `DEFAULT`, the return
value of this function is used as the return value.
- Alternatively `side_effect` can be an exception class or instance. In
- this case the exception will be raised when the mock is called.
-
If `side_effect` is an iterable then each call to the mock will return
the next value from the iterable. If any of the members of the iterable
are exceptions they will be raised instead of returned.
@@ -1216,7 +1231,6 @@ class Mock(CallableMixin, NonCallableMock):
"""
-
def _dot_lookup(thing, comp, import_path):
try:
return getattr(thing, comp)
@@ -1287,8 +1301,10 @@ class _patch(object):
def __call__(self, func):
- if isinstance(func, ClassTypes):
+ if isinstance(func, type):
return self.decorate_class(func)
+ if inspect.iscoroutinefunction(func):
+ return self.decorate_async_callable(func)
return self.decorate_callable(func)
@@ -1306,41 +1322,68 @@ class _patch(object):
return klass
+ @contextlib.contextmanager
+ def decoration_helper(self, patched, args, keywargs):
+ extra_args = []
+ entered_patchers = []
+ patching = None
+
+ exc_info = tuple()
+ try:
+ for patching in patched.patchings:
+ arg = patching.__enter__()
+ entered_patchers.append(patching)
+ if patching.attribute_name is not None:
+ keywargs.update(arg)
+ elif patching.new is DEFAULT:
+ extra_args.append(arg)
+
+ args += tuple(extra_args)
+ yield (args, keywargs)
+ except:
+ if (patching not in entered_patchers and
+ _is_started(patching)):
+ # the patcher may have been started, but an exception
+ # raised whilst entering one of its additional_patchers
+ entered_patchers.append(patching)
+ # Pass the exception to __exit__
+ exc_info = sys.exc_info()
+ # re-raise the exception
+ raise
+ finally:
+ for patching in reversed(entered_patchers):
+ patching.__exit__(*exc_info)
+
+
def decorate_callable(self, func):
+ # NB. Keep the method in sync with decorate_async_callable()
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func
@wraps(func)
def patched(*args, **keywargs):
- extra_args = []
- entered_patchers = []
+ with self.decoration_helper(patched,
+ args,
+ keywargs) as (newargs, newkeywargs):
+ return func(*newargs, **newkeywargs)
- exc_info = tuple()
- try:
- for patching in patched.patchings:
- arg = patching.__enter__()
- entered_patchers.append(patching)
- if patching.attribute_name is not None:
- keywargs.update(arg)
- elif patching.new is DEFAULT:
- extra_args.append(arg)
-
- args += tuple(extra_args)
- return func(*args, **keywargs)
- except:
- if (patching not in entered_patchers and
- _is_started(patching)):
- # the patcher may have been started, but an exception
- # raised whilst entering one of its additional_patchers
- entered_patchers.append(patching)
- # Pass the exception to __exit__
- exc_info = sys.exc_info()
- # re-raise the exception
- raise
- finally:
- for patching in reversed(entered_patchers):
- patching.__exit__(*exc_info)
+ patched.patchings = [self]
+ return patched
+
+
+ def decorate_async_callable(self, func):
+ # NB. Keep the method in sync with decorate_callable()
+ if hasattr(func, 'patchings'):
+ func.patchings.append(self)
+ return func
+
+ @wraps(func)
+ async def patched(*args, **keywargs):
+ with self.decoration_helper(patched,
+ args,
+ keywargs) as (newargs, newkeywargs):
+ return await func(*newargs, **newkeywargs)
patched.patchings = [self]
return patched
@@ -1365,7 +1408,7 @@ class _patch(object):
if not self.create and original is DEFAULT:
raise AttributeError(
- "{} does not have the attribute {!r}".format(target, name)
+ "%s does not have the attribute %r" % (target, name)
)
return original, local
@@ -1411,11 +1454,13 @@ class _patch(object):
if spec is not None or spec_set is not None:
if original is DEFAULT:
raise TypeError("Can't use 'spec' with create=True")
- if isinstance(original, ClassTypes):
+ if isinstance(original, type):
# If we're patching out a class and there is a spec
inherit = True
-
- Klass = MagicMock
+ if spec is None and _is_async_obj(original):
+ Klass = AsyncMock
+ else:
+ Klass = MagicMock
_kwargs = {}
if new_callable is not None:
Klass = new_callable
@@ -1426,8 +1471,10 @@ class _patch(object):
if _is_list(this_spec):
not_callable = '__call__' not in this_spec
else:
- not_callable = not _callable(this_spec)
- if not_callable:
+ not_callable = not callable(this_spec)
+ if _is_async_obj(this_spec):
+ Klass = AsyncMock
+ elif not_callable:
Klass = NonCallableMagicMock
if spec is not None:
@@ -1567,6 +1614,10 @@ def _patch_object(
When used as a class decorator `patch.object` honours `patch.TEST_PREFIX`
for choosing which methods to wrap.
"""
+ if type(target) is str:
+ raise TypeError(
+ f"{target!r} must be the actual object to be patched, not a str"
+ )
getter = lambda: target
return _patch(
getter, attribute, new, spec, create,
@@ -1596,7 +1647,7 @@ def _patch_multiple(target, spec=None, create=False, spec_set=None,
When used as a class decorator `patch.multiple` honours `patch.TEST_PREFIX`
for choosing which methods to wrap.
"""
- if type(target) in (unicode, str):
+ if type(target) is str:
getter = lambda: _importer(target)
else:
getter = lambda: target
@@ -1633,8 +1684,9 @@ def patch(
is patched with a `new` object. When the function/with statement exits
the patch is undone.
- If `new` is omitted, then the target is replaced with a
- `MagicMock`. If `patch` is used as a decorator and `new` is
+ If `new` is omitted, then the target is replaced with an
+ `AsyncMock if the patched object is an async function or a
+ `MagicMock` otherwise. If `patch` is used as a decorator and `new` is
omitted, the created mock is passed in as an extra argument to the
decorated function. If `patch` is used as a context manager the created
mock is returned by the context manager.
@@ -1652,8 +1704,8 @@ def patch(
patch to pass in the object being mocked as the spec/spec_set object.
`new_callable` allows you to specify a different class, or callable object,
- that will be called to create the `new` object. By default `MagicMock` is
- used.
+ that will be called to create the `new` object. By default `AsyncMock` is
+ used for async functions and `MagicMock` for the rest.
A more powerful form of `spec` is `autospec`. If you set `autospec=True`
then the mock will be created with a spec from the object being replaced.
@@ -1687,7 +1739,8 @@ def patch(
"as"; very useful if `patch` is creating a mock object for you.
`patch` takes arbitrary keyword arguments. These will be passed to
- the `Mock` (or `new_callable`) on construction.
+ `AsyncMock` if the patched object is asynchronous, to `MagicMock`
+ otherwise or to `new_callable` if specified.
`patch.dict(...)`, `patch.multiple(...)` and `patch.object(...)` are
available for alternate use-cases.
@@ -1738,7 +1791,7 @@ class _patch_dict(object):
def __call__(self, f):
- if isinstance(f, ClassTypes):
+ if isinstance(f, type):
return self.decorate_class(f)
@wraps(f)
def _inner(*args, **kw):
@@ -1765,11 +1818,12 @@ class _patch_dict(object):
def __enter__(self):
"""Patch the dict."""
self._patch_dict()
+ return self.in_dict
def _patch_dict(self):
values = self.values
- if isinstance(self.in_dict, basestring):
+ if isinstance(self.in_dict, str):
self.in_dict = _importer(self.in_dict)
in_dict = self.in_dict
clear = self.clear
@@ -1810,11 +1864,27 @@ class _patch_dict(object):
def __exit__(self, *args):
"""Unpatch the dict."""
- self._unpatch_dict()
+ if self._original is not None:
+ self._unpatch_dict()
return False
- start = __enter__
- stop = __exit__
+
+ def start(self):
+ """Activate a patch, returning any created mock."""
+ result = self.__enter__()
+ _patch._active_patches.append(self)
+ return result
+
+
+ def stop(self):
+ """Stop an active patch."""
+ try:
+ _patch._active_patches.remove(self)
+ except ValueError:
+ # If the patch hasn't been started this will fail
+ pass
+
+ return self.__exit__()
def _clear_dict(in_dict):
@@ -1849,29 +1919,26 @@ magic_methods = (
"divmod rdivmod neg pos abs invert "
"complex int float index "
"round trunc floor ceil "
+ "bool next "
+ "fspath "
+ "aiter "
)
+if IS_PYPY:
+ # PyPy has no __sizeof__: http://doc.pypy.org/en/latest/cpython_differences.html
+ magic_methods = magic_methods.replace('sizeof ', '')
+
numerics = (
- "add sub mul matmul div floordiv mod lshift rshift and xor or pow"
+ "add sub mul matmul div floordiv mod lshift rshift and xor or pow truediv"
)
-if six.PY3:
- numerics += ' truediv'
inplace = ' '.join('i%s' % n for n in numerics.split())
right = ' '.join('r%s' % n for n in numerics.split())
-extra = ''
-if six.PY3:
- extra = 'bool next '
- if sys.version_info >= (3, 6):
- extra += 'fspath '
-else:
- extra = 'unicode long nonzero oct hex truediv rtruediv '
# not including __prepare__, __instancecheck__, __subclasscheck__
# (as they are metaclass methods)
# __del__ is not supported at all as it causes problems if it exists
_non_defaults = {
- '__cmp__', '__getslice__', '__setslice__', '__coerce__', # <3.x
'__get__', '__set__', '__delete__', '__reversed__', '__missing__',
'__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__',
'__getstate__', '__setstate__', '__getformat__', '__setformat__',
@@ -1890,10 +1957,17 @@ def _get_method(name, func):
_magics = {
'__%s__' % method for method in
- ' '.join([magic_methods, numerics, inplace, right, extra]).split()
+ ' '.join([magic_methods, numerics, inplace, right]).split()
}
-_all_magics = _magics | _non_defaults
+# Magic methods used for async `with` statements
+_async_method_magics = {"__aenter__", "__aexit__", "__anext__"}
+# Magic methods that are only used with async calls but are synchronous functions themselves
+_sync_async_magics = {"__aiter__"}
+_async_magics = _async_method_magics | _sync_async_magics
+
+_all_sync_magics = _magics | _non_defaults
+_all_magics = _all_sync_magics | _async_magics
_unsupported_magics = {
'__getattr__', '__setattr__',
@@ -1906,8 +1980,7 @@ _calculate_return_value = {
'__hash__': lambda self: object.__hash__(self),
'__str__': lambda self: object.__str__(self),
'__sizeof__': lambda self: object.__sizeof__(self),
- '__unicode__': lambda self: unicode(object.__str__(self)),
- '__fspath__': lambda self: type(self).__name__+'/'+self._extract_mock_name()+'/'+str(id(self)),
+ '__fspath__': lambda self: f"{type(self).__name__}/{self._extract_mock_name()}/{id(self)}",
}
_return_values = {
@@ -1922,11 +1995,8 @@ _return_values = {
'__complex__': 1j,
'__float__': 1.0,
'__bool__': True,
- '__nonzero__': True,
- '__oct__': '1',
- '__hex__': '0x1',
- '__long__': long(1),
'__index__': 1,
+ '__aexit__': False,
}
@@ -1959,29 +2029,38 @@ def _get_iter(self):
return iter(ret_val)
return __iter__
+def _get_async_iter(self):
+ def __aiter__():
+ ret_val = self.__aiter__._mock_return_value
+ if ret_val is DEFAULT:
+ return _AsyncIterator(iter([]))
+ return _AsyncIterator(iter(ret_val))
+ return __aiter__
+
_side_effect_methods = {
'__eq__': _get_eq,
'__ne__': _get_ne,
'__iter__': _get_iter,
+ '__aiter__': _get_async_iter
}
def _set_return_value(mock, method, name):
+ # If _mock_wraps is present then attach it so that wrapped object
+ # is used for return value is used when called.
+ if mock._mock_wraps is not None:
+ method._mock_wraps = getattr(mock._mock_wraps, name)
+ return
+
fixed = _return_values.get(name, DEFAULT)
if fixed is not DEFAULT:
method.return_value = fixed
return
- return_calulator = _calculate_return_value.get(name)
- if return_calulator is not None:
- try:
- return_value = return_calulator(mock)
- except AttributeError:
- # XXXX why do we return AttributeError here?
- # set it as a side_effect instead?
- # Answer: it makes magic mocks work on pypy?!
- return_value = AttributeError(name)
+ return_calculator = _calculate_return_value.get(name)
+ if return_calculator is not None:
+ return_value = return_calculator(mock)
method.return_value = return_value
return
@@ -1991,7 +2070,7 @@ def _set_return_value(mock, method, name):
-class MagicMixin(object):
+class MagicMixin(Base):
def __init__(self, *args, **kw):
self._mock_set_magics() # make magic work for kwargs in init
_safe_super(MagicMixin, self).__init__(*args, **kw)
@@ -1999,13 +2078,14 @@ class MagicMixin(object):
def _mock_set_magics(self):
- these_magics = _magics
+ orig_magics = _magics | _async_method_magics
+ these_magics = orig_magics
if getattr(self, "_mock_methods", None) is not None:
- these_magics = _magics.intersection(self._mock_methods)
+ these_magics = orig_magics.intersection(self._mock_methods)
remove_magics = set()
- remove_magics = _magics - these_magics
+ remove_magics = orig_magics - these_magics
for entry in remove_magics:
if entry in type(self).__dict__:
@@ -2033,6 +2113,12 @@ class NonCallableMagicMock(MagicMixin, NonCallableMock):
self._mock_set_magics()
+class AsyncMagicMixin(MagicMixin):
+ def __init__(self, *args, **kw):
+ self._mock_set_magics() # make magic work for kwargs in init
+ _safe_super(AsyncMagicMixin, self).__init__(*args, **kw)
+ self._mock_set_magics() # fix magic broken by upper level init
+
class MagicMock(MagicMixin, Mock):
"""
@@ -2056,7 +2142,7 @@ class MagicMock(MagicMixin, Mock):
-class MagicProxy(object):
+class MagicProxy(Base):
def __init__(self, name, parent):
self.name = name
self.parent = parent
@@ -2074,6 +2160,239 @@ class MagicProxy(object):
return self.create_mock()
+class AsyncMockMixin(Base):
+ await_count = _delegating_property('await_count')
+ await_args = _delegating_property('await_args')
+ await_args_list = _delegating_property('await_args_list')
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # iscoroutinefunction() checks _is_coroutine property to say if an
+ # object is a coroutine. Without this check it looks to see if it is a
+ # function/method, which in this case it is not (since it is an
+ # AsyncMock).
+ # It is set through __dict__ because when spec_set is True, this
+ # attribute is likely undefined.
+ self.__dict__['_is_coroutine'] = asyncio.coroutines._is_coroutine
+ self.__dict__['_mock_await_count'] = 0
+ self.__dict__['_mock_await_args'] = None
+ self.__dict__['_mock_await_args_list'] = _CallList()
+ code_mock = NonCallableMock(spec_set=CodeType)
+ code_mock.co_flags = inspect.CO_COROUTINE
+ self.__dict__['__code__'] = code_mock
+
+ async def _execute_mock_call(_mock_self, *args, **kwargs):
+ self = _mock_self
+ # This is nearly just like super(), except for special handling
+ # of coroutines
+
+ _call = _Call((args, kwargs), two=True)
+ self.await_count += 1
+ self.await_args = _call
+ self.await_args_list.append(_call)
+
+ effect = self.side_effect
+ if effect is not None:
+ if _is_exception(effect):
+ raise effect
+ elif not _callable(effect):
+ try:
+ result = next(effect)
+ except StopIteration:
+ # It is impossible to propogate a StopIteration
+ # through coroutines because of PEP 479
+ raise StopAsyncIteration
+ if _is_exception(result):
+ raise result
+ elif iscoroutinefunction(effect):
+ result = await effect(*args, **kwargs)
+ else:
+ result = effect(*args, **kwargs)
+
+ if result is not DEFAULT:
+ return result
+
+ if self._mock_return_value is not DEFAULT:
+ return self.return_value
+
+ if self._mock_wraps is not None:
+ if iscoroutinefunction(self._mock_wraps):
+ return await self._mock_wraps(*args, **kwargs)
+ return self._mock_wraps(*args, **kwargs)
+
+ return self.return_value
+
+ def assert_awaited(_mock_self):
+ """
+ Assert that the mock was awaited at least once.
+ """
+ self = _mock_self
+ if self.await_count == 0:
+ msg = f"Expected {self._mock_name or 'mock'} to have been awaited."
+ raise AssertionError(msg)
+
+ def assert_awaited_once(_mock_self):
+ """
+ Assert that the mock was awaited exactly once.
+ """
+ self = _mock_self
+ if not self.await_count == 1:
+ msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
+ f" Awaited {self.await_count} times.")
+ raise AssertionError(msg)
+
+ def assert_awaited_with(_mock_self, *args, **kwargs):
+ """
+ Assert that the last await was with the specified arguments.
+ """
+ self = _mock_self
+ if self.await_args is None:
+ expected = self._format_mock_call_signature(args, kwargs)
+ raise AssertionError(f'Expected await: {expected}\nNot awaited')
+
+ def _error_message():
+ msg = self._format_mock_failure_message(args, kwargs, action='await')
+ return msg
+
+ expected = self._call_matcher(_Call((args, kwargs), two=True))
+ actual = self._call_matcher(self.await_args)
+ if actual != expected:
+ cause = expected if isinstance(expected, Exception) else None
+ raise AssertionError(_error_message()) from cause
+
+ def assert_awaited_once_with(_mock_self, *args, **kwargs):
+ """
+ Assert that the mock was awaited exactly once and with the specified
+ arguments.
+ """
+ self = _mock_self
+ if not self.await_count == 1:
+ msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
+ f" Awaited {self.await_count} times.")
+ raise AssertionError(msg)
+ return self.assert_awaited_with(*args, **kwargs)
+
+ def assert_any_await(_mock_self, *args, **kwargs):
+ """
+ Assert the mock has ever been awaited with the specified arguments.
+ """
+ self = _mock_self
+ expected = self._call_matcher(_Call((args, kwargs), two=True))
+ cause = expected if isinstance(expected, Exception) else None
+ actual = [self._call_matcher(c) for c in self.await_args_list]
+ if cause or expected not in _AnyComparer(actual):
+ expected_string = self._format_mock_call_signature(args, kwargs)
+ raise AssertionError(
+ '%s await not found' % expected_string
+ ) from cause
+
+ def assert_has_awaits(_mock_self, calls, any_order=False):
+ """
+ Assert the mock has been awaited with the specified calls.
+ The :attr:`await_args_list` list is checked for the awaits.
+
+ If `any_order` is False (the default) then the awaits must be
+ sequential. There can be extra calls before or after the
+ specified awaits.
+
+ If `any_order` is True then the awaits can be in any order, but
+ they must all appear in :attr:`await_args_list`.
+ """
+ self = _mock_self
+ expected = [self._call_matcher(c) for c in calls]
+ cause = cause = next((e for e in expected if isinstance(e, Exception)), None)
+ all_awaits = _CallList(self._call_matcher(c) for c in self.await_args_list)
+ if not any_order:
+ if expected not in all_awaits:
+ if cause is None:
+ problem = 'Awaits not found.'
+ else:
+ problem = ('Error processing expected awaits.\n'
+ 'Errors: {}').format(
+ [e if isinstance(e, Exception) else None
+ for e in expected])
+ raise AssertionError(
+ f'{problem}\n'
+ f'Expected: {_CallList(calls)}\n'
+ f'Actual: {self.await_args_list}'
+ ) from cause
+ return
+
+ all_awaits = list(all_awaits)
+
+ not_found = []
+ for kall in expected:
+ try:
+ all_awaits.remove(kall)
+ except ValueError:
+ not_found.append(kall)
+ if not_found:
+ raise AssertionError(
+ '%r not all found in await list' % (tuple(not_found),)
+ ) from cause
+
+ def assert_not_awaited(_mock_self):
+ """
+ Assert that the mock was never awaited.
+ """
+ self = _mock_self
+ if self.await_count != 0:
+ msg = (f"Expected {self._mock_name or 'mock'} to not have been awaited."
+ f" Awaited {self.await_count} times.")
+ raise AssertionError(msg)
+
+ def reset_mock(self, *args, **kwargs):
+ """
+ See :func:`.Mock.reset_mock()`
+ """
+ super().reset_mock(*args, **kwargs)
+ self.await_count = 0
+ self.await_args = None
+ self.await_args_list = _CallList()
+
+
+class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock):
+ """
+ Enhance :class:`Mock` with features allowing to mock
+ an async function.
+
+ The :class:`AsyncMock` object will behave so the object is
+ recognized as an async function, and the result of a call is an awaitable:
+
+ >>> mock = AsyncMock()
+ >>> iscoroutinefunction(mock)
+ True
+ >>> inspect.isawaitable(mock())
+ True
+
+
+ The result of ``mock()`` is an async function which will have the outcome
+ of ``side_effect`` or ``return_value``:
+
+ - if ``side_effect`` is a function, the async function will return the
+ result of that function,
+ - if ``side_effect`` is an exception, the async function will raise the
+ exception,
+ - if ``side_effect`` is an iterable, the async function will return the
+ next value of the iterable, however, if the sequence of result is
+ exhausted, ``StopIteration`` is raised immediately,
+ - if ``side_effect`` is not defined, the async function will return the
+ value defined by ``return_value``, hence, by default, the async function
+ returns a new :class:`AsyncMock` object.
+
+ If the outcome of ``side_effect`` or ``return_value`` is an async function,
+ the mock async function obtained when the mock object is called will be this
+ async function itself (and not an async function returning an async
+ function).
+
+ The test author can also specify a wrapped object with ``wraps``. In this
+ case, the :class:`Mock` object behavior is the same as with an
+ :class:`.Mock` object: the wrapped object may have methods
+ defined as async function functions.
+
+ Based on Martin Richard's asynctest project.
+ """
+
class _ANY(object):
"A helper object that compares equal to everything."
@@ -2087,8 +2406,6 @@ class _ANY(object):
def __repr__(self):
return '<ANY>'
- __hash__ = None
-
ANY = _ANY()
@@ -2097,15 +2414,8 @@ def _format_call_signature(name, args, kwargs):
message = '%s(%%s)' % name
formatted_args = ''
args_string = ', '.join([repr(arg) for arg in args])
-
- def encode_item(item):
- if six.PY2 and isinstance(item, unicode):
- return item.encode("utf-8")
- else:
- return item
-
kwargs_string = ', '.join([
- '{}={!r}'.format(encode_item(key), value) for key, value in sorted(kwargs.items())
+ '%s=%r' % (key, value) for key, value in kwargs.items()
])
if args_string:
formatted_args = args_string
@@ -2146,7 +2456,7 @@ class _Call(tuple):
name, args, kwargs = value
elif _len == 2:
first, second = value
- if isinstance(first, basestring):
+ if isinstance(first, str):
name = first
if isinstance(second, tuple):
args = second
@@ -2156,7 +2466,7 @@ class _Call(tuple):
args, kwargs = first, second
elif _len == 1:
value, = value
- if isinstance(value, basestring):
+ if isinstance(value, str):
name = value
elif isinstance(value, tuple):
args = value
@@ -2177,12 +2487,10 @@ class _Call(tuple):
def __eq__(self, other):
- if other is ANY:
- return True
try:
len_other = len(other)
except TypeError:
- return False
+ return NotImplemented
self_name = ''
if len(self) == 2:
@@ -2204,7 +2512,7 @@ class _Call(tuple):
if isinstance(value, tuple):
other_args = value
other_kwargs = {}
- elif isinstance(value, basestring):
+ elif isinstance(value, str):
other_name = value
other_args, other_kwargs = (), {}
else:
@@ -2213,7 +2521,7 @@ class _Call(tuple):
elif len_other == 2:
# could be (name, args) or (name, kwargs) or (args, kwargs)
first, second = other
- if isinstance(first, basestring):
+ if isinstance(first, str):
other_name = first
if isinstance(second, tuple):
other_args, other_kwargs = second, {}
@@ -2231,10 +2539,8 @@ class _Call(tuple):
return (other_args, other_kwargs) == (self_args, self_kwargs)
- def __ne__(self, other):
- return not self.__eq__(other)
+ __ne__ = object.__ne__
- __hash__ = None
def __call__(self, *args, **kwargs):
if self._mock_name is None:
@@ -2247,15 +2553,15 @@ class _Call(tuple):
def __getattr__(self, attr):
if self._mock_name is None:
return _Call(name=attr, from_kall=False)
- name = '{}.{}'.format(self._mock_name, attr)
+ name = '%s.%s' % (self._mock_name, attr)
return _Call(name=name, parent=self, from_kall=False)
- def count(self, *args, **kwargs):
- return self.__getattr__('count')(*args, **kwargs)
+ def __getattribute__(self, attr):
+ if attr in tuple.__dict__:
+ raise AttributeError
+ return tuple.__getattribute__(self, attr)
- def index(self, *args, **kwargs):
- return self.__getattr__('index')(*args, **kwargs)
def _get_call_arguments(self):
if len(self) == 2:
@@ -2310,7 +2616,6 @@ class _Call(tuple):
call = _Call(from_kall=False)
-
def create_autospec(spec, spec_set=False, instance=False, _parent=None,
_name=None, **kwargs):
"""Create a mock object using another object as a spec. Attributes on the
@@ -2335,8 +2640,8 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
# interpreted as a list of strings
spec = type(spec)
- is_type = isinstance(spec, ClassTypes)
-
+ is_type = isinstance(spec, type)
+ is_async_func = _is_async_func(spec)
_kwargs = {'spec': spec}
if spec_set:
_kwargs = {'spec_set': spec}
@@ -2353,6 +2658,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
# descriptors don't have a spec
# because we don't know what type they return
_kwargs = {}
+ elif is_async_func:
+ if instance:
+ raise RuntimeError("Instance can not be True when create_autospec "
+ "is mocking an async function")
+ Klass = AsyncMock
elif not _callable(spec):
Klass = NonCallableMagicMock
elif is_type and instance and not _instance_callable(spec):
@@ -2372,6 +2682,8 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
# should only happen at the top level because we don't
# recurse for functions
mock = _set_signature(mock, spec)
+ if is_async_func:
+ _setup_async_mock(mock)
else:
_check_signature(spec, mock, is_type, instance)
@@ -2383,12 +2695,6 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
_name='()', _parent=mock)
for entry in dir(spec):
-
- # This are __ and so treated as magic on Py3, on Py2 we need to
- # explicitly ignore them:
- if six.PY2 and (entry.startswith('im_') or entry.startswith('func_')):
- continue
-
if _is_magic(entry):
# MagicMock already does the useful magic methods for us
continue
@@ -2421,9 +2727,13 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
skipfirst = _must_skip(spec, entry, is_type)
kwargs['_eat_self'] = skipfirst
- new = MagicMock(parent=parent, name=entry, _new_name=entry,
- _new_parent=parent,
- **kwargs)
+ if iscoroutinefunction(original):
+ child_klass = AsyncMock
+ else:
+ child_klass = MagicMock
+ new = child_klass(parent=parent, name=entry, _new_name=entry,
+ _new_parent=parent,
+ **kwargs)
mock._mock_children[entry] = new
_check_signature(original, new, skipfirst=skipfirst)
@@ -2442,14 +2752,11 @@ def _must_skip(spec, entry, is_type):
Return whether we should skip the first argument on spec's `entry`
attribute.
"""
- if not isinstance(spec, ClassTypes):
+ if not isinstance(spec, type):
if entry in getattr(spec, '__dict__', {}):
# instance attribute - shouldn't skip
return False
spec = spec.__class__
- if not hasattr(spec, '__mro__'):
- # old style class: can't have descriptors anyway
- return is_type
for klass in spec.__mro__:
result = klass.__dict__.get(entry, DEFAULT)
@@ -2457,7 +2764,7 @@ def _must_skip(spec, entry, is_type):
continue
if isinstance(result, (staticmethod, classmethod)):
return False
- elif isinstance(getattr(result, '__get__', None), MethodWrapperTypes):
+ elif isinstance(result, FunctionTypes):
# Normal method => skip if looked up on type
# (if looked up on instance, self is already skipped)
return is_type
@@ -2487,10 +2794,6 @@ FunctionTypes = (
type(ANY.__eq__),
)
-MethodWrapperTypes = (
- type(ANY.__eq__.__get__),
-)
-
file_spec = None
@@ -2527,9 +2830,8 @@ def mock_open(mock=None, read_data=''):
return handle.read.return_value
return _state[0].read(*args, **kwargs)
- def _readline_side_effect(*args, **kwargs):
- for item in _iter_side_effect():
- yield item
+ def _readline_side_effect(*args, **kwargs):
+ yield from _iter_side_effect()
while True:
yield _state[0].readline(*args, **kwargs)
@@ -2540,14 +2842,15 @@ def mock_open(mock=None, read_data=''):
for line in _state[0]:
yield line
+ def _next_side_effect():
+ if handle.readline.return_value is not None:
+ return handle.readline.return_value
+ return next(_state[0])
+
global file_spec
if file_spec is None:
- # set on first use
- if six.PY3:
- import _io
- file_spec = list(set(dir(_io.TextIOWrapper)).union(set(dir(_io.BytesIO))))
- else:
- file_spec = file
+ import _io
+ file_spec = list(set(dir(_io.TextIOWrapper)).union(set(dir(_io.BytesIO))))
if mock is None:
mock = MagicMock(name='open', spec=open)
@@ -2565,6 +2868,7 @@ def mock_open(mock=None, read_data=''):
handle.readline.side_effect = _state[1]
handle.readlines.side_effect = _readlines_side_effect
handle.__iter__.side_effect = _iter_side_effect
+ handle.__next__.side_effect = _next_side_effect
def reset_data(*args, **kwargs):
_state[0] = _to_stream(read_data)
@@ -2591,7 +2895,7 @@ class PropertyMock(Mock):
def _get_child_mock(self, **kwargs):
return MagicMock(**kwargs)
- def __get__(self, obj, obj_type):
+ def __get__(self, obj, obj_type=None):
return self()
def __set__(self, obj, val):
self(val)
@@ -2617,3 +2921,21 @@ def seal(mock):
continue
if m._mock_new_parent is mock:
seal(m)
+
+
+class _AsyncIterator:
+ """
+ Wraps an iterator in an asynchronous iterator.
+ """
+ def __init__(self, iterator):
+ self.iterator = iterator
+ code_mock = NonCallableMock(spec_set=CodeType)
+ code_mock.co_flags = inspect.CO_ITERABLE_COROUTINE
+ self.__dict__['__code__'] = code_mock
+
+ async def __anext__(self):
+ try:
+ return next(self.iterator)
+ except StopIteration:
+ pass
+ raise StopAsyncIteration