aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--google/api_core/grpc_helpers.py18
-rw-r--r--tests/unit/test_grpc_helpers.py41
2 files changed, 54 insertions, 5 deletions
diff --git a/google/api_core/grpc_helpers.py b/google/api_core/grpc_helpers.py
index 4d63beb..c47b09f 100644
--- a/google/api_core/grpc_helpers.py
+++ b/google/api_core/grpc_helpers.py
@@ -65,6 +65,19 @@ class _StreamingResponseIterator(grpc.Call):
def __init__(self, wrapped):
self._wrapped = wrapped
+ # This iterator is used in a retry context, and returned outside after init.
+ # gRPC will not throw an exception until the stream is consumed, so we need
+ # to retrieve the first result, in order to fail, in order to trigger a retry.
+ try:
+ self._stored_first_result = six.next(self._wrapped)
+ except TypeError:
+ # It is possible the wrapped method isn't an iterable (a grpc.Call
+ # for instance). If this happens don't store the first result.
+ pass
+ except StopIteration:
+ # ignore stop iteration at this time. This should be handled outside of retry.
+ pass
+
def __iter__(self):
"""This iterator is also an iterable that returns itself."""
return self
@@ -76,8 +89,13 @@ class _StreamingResponseIterator(grpc.Call):
protobuf.Message: A single response from the stream.
"""
try:
+ if hasattr(self, "_stored_first_result"):
+ result = self._stored_first_result
+ del self._stored_first_result
+ return result
return six.next(self._wrapped)
except grpc.RpcError as exc:
+ # If the stream has already returned data, we cannot recover here.
six.raise_from(exceptions.from_grpc_error(exc), exc)
# Alias needed for Python 2/3 support.
diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py
index c37c3ee..1fec64f 100644
--- a/tests/unit/test_grpc_helpers.py
+++ b/tests/unit/test_grpc_helpers.py
@@ -129,24 +129,55 @@ def test_wrap_stream_errors_invocation():
assert exc_info.value.response == grpc_error
+def test_wrap_stream_empty_iterator():
+ expected_responses = []
+ callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses))
+
+ wrapped_callable = grpc_helpers._wrap_stream_errors(callable_)
+
+ got_iterator = wrapped_callable()
+
+ responses = list(got_iterator)
+
+ callable_.assert_called_once_with()
+ assert responses == expected_responses
+
+
class RpcResponseIteratorImpl(object):
- def __init__(self, exception):
- self._exception = exception
+ def __init__(self, iterable):
+ self._iterable = iter(iterable)
def next(self):
- raise self._exception
+ next_item = next(self._iterable)
+ if isinstance(next_item, RpcErrorImpl):
+ raise next_item
+ return next_item
__next__ = next
-def test_wrap_stream_errors_iterator():
+def test_wrap_stream_errors_iterator_initialization():
grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE)
- response_iter = RpcResponseIteratorImpl(grpc_error)
+ response_iter = RpcResponseIteratorImpl([grpc_error])
callable_ = mock.Mock(spec=["__call__"], return_value=response_iter)
wrapped_callable = grpc_helpers._wrap_stream_errors(callable_)
+ with pytest.raises(exceptions.ServiceUnavailable) as exc_info:
+ wrapped_callable(1, 2, three="four")
+
+ callable_.assert_called_once_with(1, 2, three="four")
+ assert exc_info.value.response == grpc_error
+
+
+def test_wrap_stream_errors_during_iteration():
+ grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE)
+ response_iter = RpcResponseIteratorImpl([1, grpc_error])
+ callable_ = mock.Mock(spec=["__call__"], return_value=response_iter)
+
+ wrapped_callable = grpc_helpers._wrap_stream_errors(callable_)
got_iterator = wrapped_callable(1, 2, three="four")
+ next(got_iterator)
with pytest.raises(exceptions.ServiceUnavailable) as exc_info:
next(got_iterator)