aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeter Lamut <plamut@users.noreply.github.com>2021-06-08 10:30:29 +0200
committerGitHub <noreply@github.com>2021-06-08 10:30:29 +0200
commit3487d68bdab6f20e2ab931c8283f63c94862cf31 (patch)
treedeec7de500fc831efabcbe25e2371f4326ce116b
parent641fbbf95c4cf72e48e2a58d563e41b2a1787bbf (diff)
downloadpython-api-core-3487d68bdab6f20e2ab931c8283f63c94862cf31.tar.gz
feat: add iterator capability to paged iterators (#200)
* feat: add iterator capability to *Iterator classes The *Iterator classes are only _iterables_, and this commit also makes them _iterators_, i.e. calling next(iterator) on them now works. * Make AsyncIterator an actual async iterator
-rw-r--r--google/api_core/page_iterator.py10
-rw-r--r--google/api_core/page_iterator_async.py7
-rw-r--r--tests/asyncio/test_page_iterator_async.py27
-rw-r--r--tests/unit/test_page_iterator.py20
4 files changed, 64 insertions, 0 deletions
diff --git a/google/api_core/page_iterator.py b/google/api_core/page_iterator.py
index fff3b55..49879bc 100644
--- a/google/api_core/page_iterator.py
+++ b/google/api_core/page_iterator.py
@@ -170,6 +170,8 @@ class Iterator(object):
max_results=None,
):
self._started = False
+ self.__active_iterator = None
+
self.client = client
"""Optional[Any]: The client that created this iterator."""
self.item_to_value = item_to_value
@@ -228,6 +230,14 @@ class Iterator(object):
self._started = True
return self._items_iter()
+ def __next__(self):
+ if self.__active_iterator is None:
+ self.__active_iterator = iter(self)
+ return next(self.__active_iterator)
+
+ # Preserve Python 2 compatibility.
+ next = __next__
+
def _page_iter(self, increment):
"""Generator of pages of API responses.
diff --git a/google/api_core/page_iterator_async.py b/google/api_core/page_iterator_async.py
index a0aa41a..c072575 100644
--- a/google/api_core/page_iterator_async.py
+++ b/google/api_core/page_iterator_async.py
@@ -101,6 +101,8 @@ class AsyncIterator(abc.ABC):
max_results=None,
):
self._started = False
+ self.__active_aiterator = None
+
self.client = client
"""Optional[Any]: The client that created this iterator."""
self.item_to_value = item_to_value
@@ -159,6 +161,11 @@ class AsyncIterator(abc.ABC):
self._started = True
return self._items_aiter()
+ async def __anext__(self):
+ if self.__active_aiterator is None:
+ self.__active_aiterator = self.__aiter__()
+ return await self.__active_aiterator.__anext__()
+
async def _page_aiter(self, increment):
"""Generator of pages of API responses.
diff --git a/tests/asyncio/test_page_iterator_async.py b/tests/asyncio/test_page_iterator_async.py
index 42fac2a..4abacc6 100644
--- a/tests/asyncio/test_page_iterator_async.py
+++ b/tests/asyncio/test_page_iterator_async.py
@@ -47,6 +47,33 @@ class TestAsyncIterator:
assert iterator.next_page_token == token
assert iterator.num_results == 0
+ @pytest.mark.asyncio
+ async def test_anext(self):
+ parent = mock.sentinel.parent
+ page_1 = page_iterator_async.Page(
+ parent, ("item 1.1", "item 1.2"), page_iterator_async._item_to_value_identity
+ )
+ page_2 = page_iterator_async.Page(
+ parent, ("item 2.1",), page_iterator_async._item_to_value_identity
+ )
+
+ async_iterator = PageAsyncIteratorImpl(None, None)
+ async_iterator._next_page = mock.AsyncMock(side_effect=[page_1, page_2, None])
+
+ # Consume items and check the state of the async_iterator.
+ assert async_iterator.num_results == 0
+ assert await async_iterator.__anext__() == "item 1.1"
+ assert async_iterator.num_results == 1
+
+ assert await async_iterator.__anext__() == "item 1.2"
+ assert async_iterator.num_results == 2
+
+ assert await async_iterator.__anext__() == "item 2.1"
+ assert async_iterator.num_results == 3
+
+ with pytest.raises(StopAsyncIteration):
+ await async_iterator.__anext__()
+
def test_pages_property_starts(self):
iterator = PageAsyncIteratorImpl(None, None)
diff --git a/tests/unit/test_page_iterator.py b/tests/unit/test_page_iterator.py
index 8359537..97b0657 100644
--- a/tests/unit/test_page_iterator.py
+++ b/tests/unit/test_page_iterator.py
@@ -109,6 +109,26 @@ class TestIterator(object):
assert iterator.next_page_token == token
assert iterator.num_results == 0
+ def test_next(self):
+ iterator = PageIteratorImpl(None, None)
+ page_1 = page_iterator.Page(
+ iterator, ("item 1.1", "item 1.2"), page_iterator._item_to_value_identity
+ )
+ page_2 = page_iterator.Page(
+ iterator, ("item 2.1",), page_iterator._item_to_value_identity
+ )
+ iterator._next_page = mock.Mock(side_effect=[page_1, page_2, None])
+
+ result = next(iterator)
+ assert result == "item 1.1"
+ result = next(iterator)
+ assert result == "item 1.2"
+ result = next(iterator)
+ assert result == "item 2.1"
+
+ with pytest.raises(StopIteration):
+ next(iterator)
+
def test_pages_property_starts(self):
iterator = PageIteratorImpl(None, None)