summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Fried <fried@fb.com>2019-11-20 16:27:51 -0800
committerChris Withers <chris@withers.org>2020-01-29 19:12:17 +0000
commit30cc1881168041d1ddede3ecd15b69407a564e7f (patch)
tree55da56958743a00ac23f829d6421c667dd766519
parent0fda551a97c2a99f6dd3e426b4ceedcde4d38d1b (diff)
downloadmock-30cc1881168041d1ddede3ecd15b69407a564e7f.tar.gz
bpo-38857: AsyncMock fix for awaitable values and StopIteration fix [3.8] (GH-17269)
Backports: 046442d02bcc6e848e71e93e47f6cde9e279e993 Signed-off-by: Chris Withers <chris@withers.org>
-rw-r--r--NEWS.d/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst4
-rw-r--r--NEWS.d/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst3
-rw-r--r--mock/mock.py62
-rw-r--r--mock/tests/testasync.py75
4 files changed, 102 insertions, 42 deletions
diff --git a/NEWS.d/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst b/NEWS.d/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst
new file mode 100644
index 0000000..f28df28
--- /dev/null
+++ b/NEWS.d/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst
@@ -0,0 +1,4 @@
+AsyncMock fix for return values that are awaitable types. This also covers
+side_effect iterable values that happend to be awaitable, and wraps
+callables that return an awaitable type. Before these awaitables were being
+awaited instead of being returned as is.
diff --git a/NEWS.d/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst b/NEWS.d/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst
new file mode 100644
index 0000000..c059539
--- /dev/null
+++ b/NEWS.d/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst
@@ -0,0 +1,3 @@
+AsyncMock now returns StopAsyncIteration on the exaustion of a side_effects
+iterable. Since PEP-479 its Impossible to raise a StopIteration exception
+from a coroutine.
diff --git a/mock/mock.py b/mock/mock.py
index fb96c83..807d7b7 100644
--- a/mock/mock.py
+++ b/mock/mock.py
@@ -1148,8 +1148,8 @@ class CallableMixin(Base):
def _execute_mock_call(_mock_self, *args, **kwargs):
self = _mock_self
- # seperate from _increment_mock_call so that awaited functions are
- # executed seperately from their call
+ # separate from _increment_mock_call so that awaited functions are
+ # executed separately from their call, also AsyncMock overrides this method
effect = self.side_effect
if effect is not None:
@@ -2150,30 +2150,46 @@ class AsyncMockMixin(Base):
code_mock.co_flags = inspect.CO_COROUTINE
self.__dict__['__code__'] = code_mock
- async def _mock_call(_mock_self, *args, **kwargs):
+ async def _execute_mock_call(_mock_self, *args, **kwargs):
self = _mock_self
- try:
- result = super()._mock_call(*args, **kwargs)
- except (BaseException, StopIteration) as e:
- side_effect = self.side_effect
- if side_effect is not None and not callable(side_effect):
- raise
- return await _raise(e)
+ # This is nearly just like super(), except for sepcial handling
+ # of coroutines
_call = self.call_args
+ self.await_count += 1
+ self.await_args = _call
+ self.await_args_list.append(_call)
- async def proxy():
- try:
- if inspect.isawaitable(result):
- return await result
- else:
- return result
- finally:
- self.await_count += 1
- self.await_args = _call
- self.await_args_list.append(_call)
+ effect = self.side_effect
+ if effect is not None:
+ if _is_exception(effect):
+ raise effect
+ elif not _callable(effect):
+ try:
+ result = next(effect)
+ except StopIteration:
+ # It is impossible to propogate a StopIteration
+ # through coroutines because of PEP 479
+ raise StopAsyncIteration
+ if _is_exception(result):
+ raise result
+ elif asyncio.iscoroutinefunction(effect):
+ result = await effect(*args, **kwargs)
+ else:
+ result = effect(*args, **kwargs)
- return await proxy()
+ if result is not DEFAULT:
+ return result
+
+ if self._mock_return_value is not DEFAULT:
+ return self.return_value
+
+ if self._mock_wraps is not None:
+ if asyncio.iscoroutinefunction(self._mock_wraps):
+ return await self._mock_wraps(*args, **kwargs)
+ return self._mock_wraps(*args, **kwargs)
+
+ return self.return_value
def assert_awaited(_mock_self):
"""
@@ -2880,10 +2896,6 @@ def seal(mock):
seal(m)
-async def _raise(exception):
- raise exception
-
-
class _AsyncIterator:
"""
Wraps an iterator in an asynchronous iterator.
diff --git a/mock/tests/testasync.py b/mock/tests/testasync.py
index 2eda1e5..e839fca 100644
--- a/mock/tests/testasync.py
+++ b/mock/tests/testasync.py
@@ -8,7 +8,6 @@ from mock import (ANY, call, AsyncMock, patch, MagicMock, Mock,
create_autospec, sentinel)
from mock.mock import _CallList
-
try:
from asyncio import run
except ImportError:
@@ -372,42 +371,84 @@ class AsyncSpecSetTest(unittest.TestCase):
self.assertIsInstance(cm, MagicMock)
-class AsyncArguments(unittest.TestCase):
- def test_add_return_value(self):
+class AsyncArguments(unittest.IsolatedAsyncioTestCase):
+ async def test_add_return_value(self):
async def addition(self, var):
return var + 1
mock = AsyncMock(addition, return_value=10)
- output = run(mock(5))
+ output = await mock(5)
self.assertEqual(output, 10)
- def test_add_side_effect_exception(self):
+ async def test_add_side_effect_exception(self):
async def addition(var):
return var + 1
mock = AsyncMock(addition, side_effect=Exception('err'))
with self.assertRaises(Exception):
- run(mock(5))
+ await mock(5)
- def test_add_side_effect_function(self):
+ async def test_add_side_effect_function(self):
async def addition(var):
return var + 1
mock = AsyncMock(side_effect=addition)
- result = run(mock(5))
+ result = await mock(5)
self.assertEqual(result, 6)
- def test_add_side_effect_iterable(self):
+ async def test_add_side_effect_iterable(self):
vals = [1, 2, 3]
mock = AsyncMock(side_effect=vals)
for item in vals:
- self.assertEqual(item, run(mock()))
-
- with self.assertRaises(RuntimeError) as e:
- run(mock())
- self.assertEqual(
- e.exception,
- RuntimeError('coroutine raised StopIteration')
- )
+ self.assertEqual(item, await mock())
+
+ with self.assertRaises(StopAsyncIteration) as e:
+ await mock()
+
+ async def test_return_value_AsyncMock(self):
+ value = AsyncMock(return_value=10)
+ mock = AsyncMock(return_value=value)
+ result = await mock()
+ self.assertIs(result, value)
+
+ async def test_return_value_awaitable(self):
+ fut = asyncio.Future()
+ fut.set_result(None)
+ mock = AsyncMock(return_value=fut)
+ result = await mock()
+ self.assertIsInstance(result, asyncio.Future)
+
+ async def test_side_effect_awaitable_values(self):
+ fut = asyncio.Future()
+ fut.set_result(None)
+
+ mock = AsyncMock(side_effect=[fut])
+ result = await mock()
+ self.assertIsInstance(result, asyncio.Future)
+
+ with self.assertRaises(StopAsyncIteration):
+ await mock()
+
+ async def test_side_effect_is_AsyncMock(self):
+ effect = AsyncMock(return_value=10)
+ mock = AsyncMock(side_effect=effect)
+
+ result = await mock()
+ self.assertEqual(result, 10)
+
+ async def test_wraps_coroutine(self):
+ value = asyncio.Future()
+
+ ran = False
+ async def inner():
+ nonlocal ran
+ ran = True
+ return value
+
+ mock = AsyncMock(wraps=inner)
+ result = await mock()
+ self.assertEqual(result, value)
+ mock.assert_awaited()
+ self.assertTrue(ran)
class AsyncMagicMethods(unittest.TestCase):
def test_async_magic_methods_return_async_mocks(self):