diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/unit/test_grpc_helpers.py | 41 |
1 files changed, 36 insertions, 5 deletions
diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py index c37c3ee..1fec64f 100644 --- a/tests/unit/test_grpc_helpers.py +++ b/tests/unit/test_grpc_helpers.py @@ -129,24 +129,55 @@ def test_wrap_stream_errors_invocation(): assert exc_info.value.response == grpc_error +def test_wrap_stream_empty_iterator(): + expected_responses = [] + callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses)) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + + got_iterator = wrapped_callable() + + responses = list(got_iterator) + + callable_.assert_called_once_with() + assert responses == expected_responses + + class RpcResponseIteratorImpl(object): - def __init__(self, exception): - self._exception = exception + def __init__(self, iterable): + self._iterable = iter(iterable) def next(self): - raise self._exception + next_item = next(self._iterable) + if isinstance(next_item, RpcErrorImpl): + raise next_item + return next_item __next__ = next -def test_wrap_stream_errors_iterator(): +def test_wrap_stream_errors_iterator_initialization(): grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE) - response_iter = RpcResponseIteratorImpl(grpc_error) + response_iter = RpcResponseIteratorImpl([grpc_error]) callable_ = mock.Mock(spec=["__call__"], return_value=response_iter) wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + with pytest.raises(exceptions.ServiceUnavailable) as exc_info: + wrapped_callable(1, 2, three="four") + + callable_.assert_called_once_with(1, 2, three="four") + assert exc_info.value.response == grpc_error + + +def test_wrap_stream_errors_during_iteration(): + grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE) + response_iter = RpcResponseIteratorImpl([1, grpc_error]) + callable_ = mock.Mock(spec=["__call__"], return_value=response_iter) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) got_iterator = wrapped_callable(1, 2, three="four") + next(got_iterator) with pytest.raises(exceptions.ServiceUnavailable) as exc_info: next(got_iterator) |