diff options
author | Jason Fried <fried@fb.com> | 2019-11-20 16:27:51 -0800 |
---|---|---|
committer | Chris Withers <chris@withers.org> | 2020-01-29 19:12:17 +0000 |
commit | 30cc1881168041d1ddede3ecd15b69407a564e7f (patch) | |
tree | 55da56958743a00ac23f829d6421c667dd766519 | |
parent | 0fda551a97c2a99f6dd3e426b4ceedcde4d38d1b (diff) | |
download | mock-30cc1881168041d1ddede3ecd15b69407a564e7f.tar.gz |
bpo-38857: AsyncMock fix for awaitable values and StopIteration fix [3.8] (GH-17269)
Backports: 046442d02bcc6e848e71e93e47f6cde9e279e993
Signed-off-by: Chris Withers <chris@withers.org>
-rw-r--r-- | NEWS.d/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst | 4 | ||||
-rw-r--r-- | NEWS.d/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst | 3 | ||||
-rw-r--r-- | mock/mock.py | 62 | ||||
-rw-r--r-- | mock/tests/testasync.py | 75 |
4 files changed, 102 insertions, 42 deletions
diff --git a/NEWS.d/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst b/NEWS.d/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst new file mode 100644 index 0000000..f28df28 --- /dev/null +++ b/NEWS.d/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst @@ -0,0 +1,4 @@ +AsyncMock fix for return values that are awaitable types. This also covers +side_effect iterable values that happend to be awaitable, and wraps +callables that return an awaitable type. Before these awaitables were being +awaited instead of being returned as is. diff --git a/NEWS.d/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst b/NEWS.d/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst new file mode 100644 index 0000000..c059539 --- /dev/null +++ b/NEWS.d/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst @@ -0,0 +1,3 @@ +AsyncMock now returns StopAsyncIteration on the exaustion of a side_effects +iterable. Since PEP-479 its Impossible to raise a StopIteration exception +from a coroutine. diff --git a/mock/mock.py b/mock/mock.py index fb96c83..807d7b7 100644 --- a/mock/mock.py +++ b/mock/mock.py @@ -1148,8 +1148,8 @@ class CallableMixin(Base): def _execute_mock_call(_mock_self, *args, **kwargs): self = _mock_self - # seperate from _increment_mock_call so that awaited functions are - # executed seperately from their call + # separate from _increment_mock_call so that awaited functions are + # executed separately from their call, also AsyncMock overrides this method effect = self.side_effect if effect is not None: @@ -2150,30 +2150,46 @@ class AsyncMockMixin(Base): code_mock.co_flags = inspect.CO_COROUTINE self.__dict__['__code__'] = code_mock - async def _mock_call(_mock_self, *args, **kwargs): + async def _execute_mock_call(_mock_self, *args, **kwargs): self = _mock_self - try: - result = super()._mock_call(*args, **kwargs) - except (BaseException, StopIteration) as e: - side_effect = self.side_effect - if side_effect is not None and not callable(side_effect): - raise - return await _raise(e) + # This is nearly just like super(), except for sepcial handling + # of coroutines _call = self.call_args + self.await_count += 1 + self.await_args = _call + self.await_args_list.append(_call) - async def proxy(): - try: - if inspect.isawaitable(result): - return await result - else: - return result - finally: - self.await_count += 1 - self.await_args = _call - self.await_args_list.append(_call) + effect = self.side_effect + if effect is not None: + if _is_exception(effect): + raise effect + elif not _callable(effect): + try: + result = next(effect) + except StopIteration: + # It is impossible to propogate a StopIteration + # through coroutines because of PEP 479 + raise StopAsyncIteration + if _is_exception(result): + raise result + elif asyncio.iscoroutinefunction(effect): + result = await effect(*args, **kwargs) + else: + result = effect(*args, **kwargs) - return await proxy() + if result is not DEFAULT: + return result + + if self._mock_return_value is not DEFAULT: + return self.return_value + + if self._mock_wraps is not None: + if asyncio.iscoroutinefunction(self._mock_wraps): + return await self._mock_wraps(*args, **kwargs) + return self._mock_wraps(*args, **kwargs) + + return self.return_value def assert_awaited(_mock_self): """ @@ -2880,10 +2896,6 @@ def seal(mock): seal(m) -async def _raise(exception): - raise exception - - class _AsyncIterator: """ Wraps an iterator in an asynchronous iterator. diff --git a/mock/tests/testasync.py b/mock/tests/testasync.py index 2eda1e5..e839fca 100644 --- a/mock/tests/testasync.py +++ b/mock/tests/testasync.py @@ -8,7 +8,6 @@ from mock import (ANY, call, AsyncMock, patch, MagicMock, Mock, create_autospec, sentinel) from mock.mock import _CallList - try: from asyncio import run except ImportError: @@ -372,42 +371,84 @@ class AsyncSpecSetTest(unittest.TestCase): self.assertIsInstance(cm, MagicMock) -class AsyncArguments(unittest.TestCase): - def test_add_return_value(self): +class AsyncArguments(unittest.IsolatedAsyncioTestCase): + async def test_add_return_value(self): async def addition(self, var): return var + 1 mock = AsyncMock(addition, return_value=10) - output = run(mock(5)) + output = await mock(5) self.assertEqual(output, 10) - def test_add_side_effect_exception(self): + async def test_add_side_effect_exception(self): async def addition(var): return var + 1 mock = AsyncMock(addition, side_effect=Exception('err')) with self.assertRaises(Exception): - run(mock(5)) + await mock(5) - def test_add_side_effect_function(self): + async def test_add_side_effect_function(self): async def addition(var): return var + 1 mock = AsyncMock(side_effect=addition) - result = run(mock(5)) + result = await mock(5) self.assertEqual(result, 6) - def test_add_side_effect_iterable(self): + async def test_add_side_effect_iterable(self): vals = [1, 2, 3] mock = AsyncMock(side_effect=vals) for item in vals: - self.assertEqual(item, run(mock())) - - with self.assertRaises(RuntimeError) as e: - run(mock()) - self.assertEqual( - e.exception, - RuntimeError('coroutine raised StopIteration') - ) + self.assertEqual(item, await mock()) + + with self.assertRaises(StopAsyncIteration) as e: + await mock() + + async def test_return_value_AsyncMock(self): + value = AsyncMock(return_value=10) + mock = AsyncMock(return_value=value) + result = await mock() + self.assertIs(result, value) + + async def test_return_value_awaitable(self): + fut = asyncio.Future() + fut.set_result(None) + mock = AsyncMock(return_value=fut) + result = await mock() + self.assertIsInstance(result, asyncio.Future) + + async def test_side_effect_awaitable_values(self): + fut = asyncio.Future() + fut.set_result(None) + + mock = AsyncMock(side_effect=[fut]) + result = await mock() + self.assertIsInstance(result, asyncio.Future) + + with self.assertRaises(StopAsyncIteration): + await mock() + + async def test_side_effect_is_AsyncMock(self): + effect = AsyncMock(return_value=10) + mock = AsyncMock(side_effect=effect) + + result = await mock() + self.assertEqual(result, 10) + + async def test_wraps_coroutine(self): + value = asyncio.Future() + + ran = False + async def inner(): + nonlocal ran + ran = True + return value + + mock = AsyncMock(wraps=inner) + result = await mock() + self.assertEqual(result, value) + mock.assert_awaited() + self.assertTrue(ran) class AsyncMagicMethods(unittest.TestCase): def test_async_magic_methods_return_async_mocks(self): |