aboutsummaryrefslogtreecommitdiff
path: root/tests/unit
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit')
-rw-r--r--tests/unit/__init__.py0
-rw-r--r--tests/unit/future/__init__.py0
-rw-r--r--tests/unit/future/test__helpers.py37
-rw-r--r--tests/unit/future/test_polling.py242
-rw-r--r--tests/unit/gapic/test_client_info.py31
-rw-r--r--tests/unit/gapic/test_config.py94
-rw-r--r--tests/unit/gapic/test_method.py244
-rw-r--r--tests/unit/gapic/test_routing_header.py41
-rw-r--r--tests/unit/operations_v1/__init__.py0
-rw-r--r--tests/unit/operations_v1/test_operations_client.py98
-rw-r--r--tests/unit/operations_v1/test_operations_rest_client.py944
-rw-r--r--tests/unit/test_bidi.py869
-rw-r--r--tests/unit/test_client_info.py98
-rw-r--r--tests/unit/test_client_options.py117
-rw-r--r--tests/unit/test_datetime_helpers.py396
-rw-r--r--tests/unit/test_exceptions.py353
-rw-r--r--tests/unit/test_grpc_helpers.py860
-rw-r--r--tests/unit/test_iam.py382
-rw-r--r--tests/unit/test_operation.py326
-rw-r--r--tests/unit/test_page_iterator.py665
-rw-r--r--tests/unit/test_path_template.py389
-rw-r--r--tests/unit/test_protobuf_helpers.py518
-rw-r--r--tests/unit/test_rest_helpers.py77
-rw-r--r--tests/unit/test_retry.py458
-rw-r--r--tests/unit/test_timeout.py129
25 files changed, 7368 insertions, 0 deletions
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/unit/__init__.py
diff --git a/tests/unit/future/__init__.py b/tests/unit/future/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/unit/future/__init__.py
diff --git a/tests/unit/future/test__helpers.py b/tests/unit/future/test__helpers.py
new file mode 100644
index 0000000..98afc59
--- /dev/null
+++ b/tests/unit/future/test__helpers.py
@@ -0,0 +1,37 @@
+# Copyright 2017, Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import mock
+
+from google.api_core.future import _helpers
+
+
+@mock.patch("threading.Thread", autospec=True)
+def test_start_deamon_thread(unused_thread):
+ deamon_thread = _helpers.start_daemon_thread(target=mock.sentinel.target)
+ assert deamon_thread.daemon is True
+
+
+def test_safe_invoke_callback():
+ callback = mock.Mock(spec=["__call__"], return_value=42)
+ result = _helpers.safe_invoke_callback(callback, "a", b="c")
+ assert result == 42
+ callback.assert_called_once_with("a", b="c")
+
+
+def test_safe_invoke_callback_exception():
+ callback = mock.Mock(spec=["__call__"], side_effect=ValueError())
+ result = _helpers.safe_invoke_callback(callback, "a", b="c")
+ assert result is None
+ callback.assert_called_once_with("a", b="c")
diff --git a/tests/unit/future/test_polling.py b/tests/unit/future/test_polling.py
new file mode 100644
index 0000000..2381d03
--- /dev/null
+++ b/tests/unit/future/test_polling.py
@@ -0,0 +1,242 @@
+# Copyright 2017, Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import concurrent.futures
+import threading
+import time
+
+import mock
+import pytest
+
+from google.api_core import exceptions, retry
+from google.api_core.future import polling
+
+
+class PollingFutureImpl(polling.PollingFuture):
+ def done(self):
+ return False
+
+ def cancel(self):
+ return True
+
+ def cancelled(self):
+ return False
+
+ def running(self):
+ return True
+
+
+def test_polling_future_constructor():
+ future = PollingFutureImpl()
+ assert not future.done()
+ assert not future.cancelled()
+ assert future.running()
+ assert future.cancel()
+ with mock.patch.object(future, "done", return_value=True):
+ future.result()
+
+
+def test_set_result():
+ future = PollingFutureImpl()
+ callback = mock.Mock()
+
+ future.set_result(1)
+
+ assert future.result() == 1
+ future.add_done_callback(callback)
+ callback.assert_called_once_with(future)
+
+
+def test_set_exception():
+ future = PollingFutureImpl()
+ exception = ValueError("meep")
+
+ future.set_exception(exception)
+
+ assert future.exception() == exception
+ with pytest.raises(ValueError):
+ future.result()
+
+ callback = mock.Mock()
+ future.add_done_callback(callback)
+ callback.assert_called_once_with(future)
+
+
+def test_invoke_callback_exception():
+ future = PollingFutureImplWithPoll()
+ future.set_result(42)
+
+ # This should not raise, despite the callback causing an exception.
+ callback = mock.Mock(side_effect=ValueError)
+ future.add_done_callback(callback)
+ callback.assert_called_once_with(future)
+
+
+class PollingFutureImplWithPoll(PollingFutureImpl):
+ def __init__(self):
+ super(PollingFutureImplWithPoll, self).__init__()
+ self.poll_count = 0
+ self.event = threading.Event()
+
+ def done(self, retry=polling.DEFAULT_RETRY):
+ self.poll_count += 1
+ self.event.wait()
+ self.set_result(42)
+ return True
+
+
+def test_result_with_polling():
+ future = PollingFutureImplWithPoll()
+
+ future.event.set()
+ result = future.result()
+
+ assert result == 42
+ assert future.poll_count == 1
+ # Repeated calls should not cause additional polling
+ assert future.result() == result
+ assert future.poll_count == 1
+
+
+class PollingFutureImplTimeout(PollingFutureImplWithPoll):
+ def done(self, retry=polling.DEFAULT_RETRY):
+ time.sleep(1)
+ return False
+
+
+def test_result_timeout():
+ future = PollingFutureImplTimeout()
+ with pytest.raises(concurrent.futures.TimeoutError):
+ future.result(timeout=1)
+
+
+def test_exception_timeout():
+ future = PollingFutureImplTimeout()
+ with pytest.raises(concurrent.futures.TimeoutError):
+ future.exception(timeout=1)
+
+
+class PollingFutureImplTransient(PollingFutureImplWithPoll):
+ def __init__(self, errors):
+ super(PollingFutureImplTransient, self).__init__()
+ self._errors = errors
+
+ def done(self, retry=polling.DEFAULT_RETRY):
+ if self._errors:
+ error, self._errors = self._errors[0], self._errors[1:]
+ raise error("testing")
+ self.poll_count += 1
+ self.set_result(42)
+ return True
+
+
+def test_result_transient_error():
+ future = PollingFutureImplTransient(
+ (
+ exceptions.TooManyRequests,
+ exceptions.InternalServerError,
+ exceptions.BadGateway,
+ )
+ )
+ result = future.result()
+ assert result == 42
+ assert future.poll_count == 1
+ # Repeated calls should not cause additional polling
+ assert future.result() == result
+ assert future.poll_count == 1
+
+
+def test_callback_background_thread():
+ future = PollingFutureImplWithPoll()
+ callback = mock.Mock()
+
+ future.add_done_callback(callback)
+
+ assert future._polling_thread is not None
+
+ # Give the thread a second to poll
+ time.sleep(1)
+ assert future.poll_count == 1
+
+ future.event.set()
+ future._polling_thread.join()
+
+ callback.assert_called_once_with(future)
+
+
+def test_double_callback_background_thread():
+ future = PollingFutureImplWithPoll()
+ callback = mock.Mock()
+ callback2 = mock.Mock()
+
+ future.add_done_callback(callback)
+ current_thread = future._polling_thread
+ assert current_thread is not None
+
+ # only one polling thread should be created.
+ future.add_done_callback(callback2)
+ assert future._polling_thread is current_thread
+
+ future.event.set()
+ future._polling_thread.join()
+
+ assert future.poll_count == 1
+ callback.assert_called_once_with(future)
+ callback2.assert_called_once_with(future)
+
+
+class PollingFutureImplWithoutRetry(PollingFutureImpl):
+ def done(self):
+ return True
+
+ def result(self):
+ return super(PollingFutureImplWithoutRetry, self).result()
+
+ def _blocking_poll(self, timeout):
+ return super(PollingFutureImplWithoutRetry, self)._blocking_poll(
+ timeout=timeout
+ )
+
+
+class PollingFutureImplWith_done_or_raise(PollingFutureImpl):
+ def done(self):
+ return True
+
+ def _done_or_raise(self):
+ return super(PollingFutureImplWith_done_or_raise, self)._done_or_raise()
+
+
+def test_polling_future_without_retry():
+ custom_retry = retry.Retry(
+ predicate=retry.if_exception_type(exceptions.TooManyRequests)
+ )
+ future = PollingFutureImplWithoutRetry()
+ assert future.done()
+ assert future.running()
+ assert future.result() is None
+
+ with mock.patch.object(future, "done") as done_mock:
+ future._done_or_raise()
+ done_mock.assert_called_once_with()
+
+ with mock.patch.object(future, "done") as done_mock:
+ future._done_or_raise(retry=custom_retry)
+ done_mock.assert_called_once_with(retry=custom_retry)
+
+
+def test_polling_future_with__done_or_raise():
+ future = PollingFutureImplWith_done_or_raise()
+ assert future.done()
+ assert future.running()
+ assert future.result() is None
diff --git a/tests/unit/gapic/test_client_info.py b/tests/unit/gapic/test_client_info.py
new file mode 100644
index 0000000..2ca5c40
--- /dev/null
+++ b/tests/unit/gapic/test_client_info.py
@@ -0,0 +1,31 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+try:
+ import grpc # noqa: F401
+except ImportError:
+ pytest.skip("No GRPC", allow_module_level=True)
+
+
+from google.api_core.gapic_v1 import client_info
+
+
+def test_to_grpc_metadata():
+ info = client_info.ClientInfo()
+
+ metadata = info.to_grpc_metadata()
+
+ assert metadata == (client_info.METRICS_METADATA_KEY, info.to_user_agent())
diff --git a/tests/unit/gapic/test_config.py b/tests/unit/gapic/test_config.py
new file mode 100644
index 0000000..5e42fde
--- /dev/null
+++ b/tests/unit/gapic/test_config.py
@@ -0,0 +1,94 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+try:
+ import grpc # noqa: F401
+except ImportError:
+ pytest.skip("No GRPC", allow_module_level=True)
+
+from google.api_core import exceptions
+from google.api_core.gapic_v1 import config
+
+
+INTERFACE_CONFIG = {
+ "retry_codes": {
+ "idempotent": ["DEADLINE_EXCEEDED", "UNAVAILABLE"],
+ "other": ["FAILED_PRECONDITION"],
+ "non_idempotent": [],
+ },
+ "retry_params": {
+ "default": {
+ "initial_retry_delay_millis": 1000,
+ "retry_delay_multiplier": 2.5,
+ "max_retry_delay_millis": 120000,
+ "initial_rpc_timeout_millis": 120000,
+ "rpc_timeout_multiplier": 1.0,
+ "max_rpc_timeout_millis": 120000,
+ "total_timeout_millis": 600000,
+ },
+ "other": {
+ "initial_retry_delay_millis": 1000,
+ "retry_delay_multiplier": 1,
+ "max_retry_delay_millis": 1000,
+ "initial_rpc_timeout_millis": 1000,
+ "rpc_timeout_multiplier": 1,
+ "max_rpc_timeout_millis": 1000,
+ "total_timeout_millis": 1000,
+ },
+ },
+ "methods": {
+ "AnnotateVideo": {
+ "timeout_millis": 60000,
+ "retry_codes_name": "idempotent",
+ "retry_params_name": "default",
+ },
+ "Other": {
+ "timeout_millis": 60000,
+ "retry_codes_name": "other",
+ "retry_params_name": "other",
+ },
+ "Plain": {"timeout_millis": 30000},
+ },
+}
+
+
+def test_create_method_configs():
+ method_configs = config.parse_method_configs(INTERFACE_CONFIG)
+
+ retry, timeout = method_configs["AnnotateVideo"]
+ assert retry._predicate(exceptions.DeadlineExceeded(None))
+ assert retry._predicate(exceptions.ServiceUnavailable(None))
+ assert retry._initial == 1.0
+ assert retry._multiplier == 2.5
+ assert retry._maximum == 120.0
+ assert retry._deadline == 600.0
+ assert timeout._initial == 120.0
+ assert timeout._multiplier == 1.0
+ assert timeout._maximum == 120.0
+
+ retry, timeout = method_configs["Other"]
+ assert retry._predicate(exceptions.FailedPrecondition(None))
+ assert retry._initial == 1.0
+ assert retry._multiplier == 1.0
+ assert retry._maximum == 1.0
+ assert retry._deadline == 1.0
+ assert timeout._initial == 1.0
+ assert timeout._multiplier == 1.0
+ assert timeout._maximum == 1.0
+
+ retry, timeout = method_configs["Plain"]
+ assert retry is None
+ assert timeout._timeout == 30.0
diff --git a/tests/unit/gapic/test_method.py b/tests/unit/gapic/test_method.py
new file mode 100644
index 0000000..9778d23
--- /dev/null
+++ b/tests/unit/gapic/test_method.py
@@ -0,0 +1,244 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+
+import mock
+import pytest
+
+try:
+ import grpc # noqa: F401
+except ImportError:
+ pytest.skip("No GRPC", allow_module_level=True)
+
+
+from google.api_core import exceptions
+from google.api_core import retry
+from google.api_core import timeout
+import google.api_core.gapic_v1.client_info
+import google.api_core.gapic_v1.method
+import google.api_core.page_iterator
+
+
+def _utcnow_monotonic():
+ curr_value = datetime.datetime.min
+ delta = datetime.timedelta(seconds=0.5)
+ while True:
+ yield curr_value
+ curr_value += delta
+
+
+def test__determine_timeout():
+ # Check _determine_timeout always returns a Timeout object.
+ timeout_type_timeout = timeout.ConstantTimeout(600.0)
+ returned_timeout = google.api_core.gapic_v1.method._determine_timeout(
+ 600.0, 600.0, None
+ )
+ assert isinstance(returned_timeout, timeout.ConstantTimeout)
+ returned_timeout = google.api_core.gapic_v1.method._determine_timeout(
+ 600.0, timeout_type_timeout, None
+ )
+ assert isinstance(returned_timeout, timeout.ConstantTimeout)
+ returned_timeout = google.api_core.gapic_v1.method._determine_timeout(
+ timeout_type_timeout, 600.0, None
+ )
+ assert isinstance(returned_timeout, timeout.ConstantTimeout)
+ returned_timeout = google.api_core.gapic_v1.method._determine_timeout(
+ timeout_type_timeout, timeout_type_timeout, None
+ )
+ assert isinstance(returned_timeout, timeout.ConstantTimeout)
+
+
+def test_wrap_method_basic():
+ method = mock.Mock(spec=["__call__"], return_value=42)
+
+ wrapped_method = google.api_core.gapic_v1.method.wrap_method(method)
+
+ result = wrapped_method(1, 2, meep="moop")
+
+ assert result == 42
+ method.assert_called_once_with(1, 2, meep="moop", metadata=mock.ANY)
+
+ # Check that the default client info was specified in the metadata.
+ metadata = method.call_args[1]["metadata"]
+ assert len(metadata) == 1
+ client_info = google.api_core.gapic_v1.client_info.DEFAULT_CLIENT_INFO
+ user_agent_metadata = client_info.to_grpc_metadata()
+ assert user_agent_metadata in metadata
+
+
+def test_wrap_method_with_no_client_info():
+ method = mock.Mock(spec=["__call__"])
+
+ wrapped_method = google.api_core.gapic_v1.method.wrap_method(
+ method, client_info=None
+ )
+
+ wrapped_method(1, 2, meep="moop")
+
+ method.assert_called_once_with(1, 2, meep="moop")
+
+
+def test_wrap_method_with_custom_client_info():
+ client_info = google.api_core.gapic_v1.client_info.ClientInfo(
+ python_version=1,
+ grpc_version=2,
+ api_core_version=3,
+ gapic_version=4,
+ client_library_version=5,
+ )
+ method = mock.Mock(spec=["__call__"])
+
+ wrapped_method = google.api_core.gapic_v1.method.wrap_method(
+ method, client_info=client_info
+ )
+
+ wrapped_method(1, 2, meep="moop")
+
+ method.assert_called_once_with(1, 2, meep="moop", metadata=mock.ANY)
+
+ # Check that the custom client info was specified in the metadata.
+ metadata = method.call_args[1]["metadata"]
+ assert client_info.to_grpc_metadata() in metadata
+
+
+def test_invoke_wrapped_method_with_metadata():
+ method = mock.Mock(spec=["__call__"])
+
+ wrapped_method = google.api_core.gapic_v1.method.wrap_method(method)
+
+ wrapped_method(mock.sentinel.request, metadata=[("a", "b")])
+
+ method.assert_called_once_with(mock.sentinel.request, metadata=mock.ANY)
+ metadata = method.call_args[1]["metadata"]
+ # Metadata should have two items: the client info metadata and our custom
+ # metadata.
+ assert len(metadata) == 2
+ assert ("a", "b") in metadata
+
+
+def test_invoke_wrapped_method_with_metadata_as_none():
+ method = mock.Mock(spec=["__call__"])
+
+ wrapped_method = google.api_core.gapic_v1.method.wrap_method(method)
+
+ wrapped_method(mock.sentinel.request, metadata=None)
+
+ method.assert_called_once_with(mock.sentinel.request, metadata=mock.ANY)
+ metadata = method.call_args[1]["metadata"]
+ # Metadata should have just one items: the client info metadata.
+ assert len(metadata) == 1
+
+
+@mock.patch("time.sleep")
+def test_wrap_method_with_default_retry_and_timeout(unusued_sleep):
+ method = mock.Mock(
+ spec=["__call__"], side_effect=[exceptions.InternalServerError(None), 42]
+ )
+ default_retry = retry.Retry()
+ default_timeout = timeout.ConstantTimeout(60)
+ wrapped_method = google.api_core.gapic_v1.method.wrap_method(
+ method, default_retry, default_timeout
+ )
+
+ result = wrapped_method()
+
+ assert result == 42
+ assert method.call_count == 2
+ method.assert_called_with(timeout=60, metadata=mock.ANY)
+
+
+@mock.patch("time.sleep")
+def test_wrap_method_with_default_retry_and_timeout_using_sentinel(unusued_sleep):
+ method = mock.Mock(
+ spec=["__call__"], side_effect=[exceptions.InternalServerError(None), 42]
+ )
+ default_retry = retry.Retry()
+ default_timeout = timeout.ConstantTimeout(60)
+ wrapped_method = google.api_core.gapic_v1.method.wrap_method(
+ method, default_retry, default_timeout
+ )
+
+ result = wrapped_method(
+ retry=google.api_core.gapic_v1.method.DEFAULT,
+ timeout=google.api_core.gapic_v1.method.DEFAULT,
+ )
+
+ assert result == 42
+ assert method.call_count == 2
+ method.assert_called_with(timeout=60, metadata=mock.ANY)
+
+
+@mock.patch("time.sleep")
+def test_wrap_method_with_overriding_retry_and_timeout(unusued_sleep):
+ method = mock.Mock(spec=["__call__"], side_effect=[exceptions.NotFound(None), 42])
+ default_retry = retry.Retry()
+ default_timeout = timeout.ConstantTimeout(60)
+ wrapped_method = google.api_core.gapic_v1.method.wrap_method(
+ method, default_retry, default_timeout
+ )
+
+ result = wrapped_method(
+ retry=retry.Retry(retry.if_exception_type(exceptions.NotFound)),
+ timeout=timeout.ConstantTimeout(22),
+ )
+
+ assert result == 42
+ assert method.call_count == 2
+ method.assert_called_with(timeout=22, metadata=mock.ANY)
+
+
+@mock.patch("time.sleep")
+@mock.patch(
+ "google.api_core.datetime_helpers.utcnow",
+ side_effect=_utcnow_monotonic(),
+ autospec=True,
+)
+def test_wrap_method_with_overriding_retry_deadline(utcnow, unused_sleep):
+ method = mock.Mock(
+ spec=["__call__"],
+ side_effect=([exceptions.InternalServerError(None)] * 4) + [42],
+ )
+ default_retry = retry.Retry()
+ default_timeout = timeout.ExponentialTimeout(deadline=60)
+ wrapped_method = google.api_core.gapic_v1.method.wrap_method(
+ method, default_retry, default_timeout
+ )
+
+ # Overriding only the retry's deadline should also override the timeout's
+ # deadline.
+ result = wrapped_method(retry=default_retry.with_deadline(30))
+
+ assert result == 42
+ timeout_args = [call[1]["timeout"] for call in method.call_args_list]
+ assert timeout_args == [5.0, 10.0, 20.0, 26.0, 25.0]
+ assert utcnow.call_count == (
+ 1
+ + 5 # First to set the deadline.
+ + 5 # One for each min(timeout, maximum, (DEADLINE - NOW).seconds)
+ )
+
+
+def test_wrap_method_with_overriding_timeout_as_a_number():
+ method = mock.Mock(spec=["__call__"], return_value=42)
+ default_retry = retry.Retry()
+ default_timeout = timeout.ConstantTimeout(60)
+ wrapped_method = google.api_core.gapic_v1.method.wrap_method(
+ method, default_retry, default_timeout
+ )
+
+ result = wrapped_method(timeout=22)
+
+ assert result == 42
+ method.assert_called_once_with(timeout=22, metadata=mock.ANY)
diff --git a/tests/unit/gapic/test_routing_header.py b/tests/unit/gapic/test_routing_header.py
new file mode 100644
index 0000000..3037867
--- /dev/null
+++ b/tests/unit/gapic/test_routing_header.py
@@ -0,0 +1,41 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+try:
+ import grpc # noqa: F401
+except ImportError:
+ pytest.skip("No GRPC", allow_module_level=True)
+
+
+from google.api_core.gapic_v1 import routing_header
+
+
+def test_to_routing_header():
+ params = [("name", "meep"), ("book.read", "1")]
+ value = routing_header.to_routing_header(params)
+ assert value == "name=meep&book.read=1"
+
+
+def test_to_routing_header_with_slashes():
+ params = [("name", "me/ep"), ("book.read", "1&2")]
+ value = routing_header.to_routing_header(params)
+ assert value == "name=me/ep&book.read=1%262"
+
+
+def test_to_grpc_metadata():
+ params = [("name", "meep"), ("book.read", "1")]
+ metadata = routing_header.to_grpc_metadata(params)
+ assert metadata == (routing_header.ROUTING_METADATA_KEY, "name=meep&book.read=1")
diff --git a/tests/unit/operations_v1/__init__.py b/tests/unit/operations_v1/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/unit/operations_v1/__init__.py
diff --git a/tests/unit/operations_v1/test_operations_client.py b/tests/unit/operations_v1/test_operations_client.py
new file mode 100644
index 0000000..187f0be
--- /dev/null
+++ b/tests/unit/operations_v1/test_operations_client.py
@@ -0,0 +1,98 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+try:
+ import grpc # noqa: F401
+except ImportError:
+ pytest.skip("No GRPC", allow_module_level=True)
+
+from google.api_core import grpc_helpers
+from google.api_core import operations_v1
+from google.api_core import page_iterator
+from google.longrunning import operations_pb2
+from google.protobuf import empty_pb2
+
+
+def test_get_operation():
+ channel = grpc_helpers.ChannelStub()
+ client = operations_v1.OperationsClient(channel)
+ channel.GetOperation.response = operations_pb2.Operation(name="meep")
+
+ response = client.get_operation("name", metadata=[("header", "foo")])
+
+ assert ("header", "foo") in channel.GetOperation.calls[0].metadata
+ assert ("x-goog-request-params", "name=name") in channel.GetOperation.calls[
+ 0
+ ].metadata
+ assert len(channel.GetOperation.requests) == 1
+ assert channel.GetOperation.requests[0].name == "name"
+ assert response == channel.GetOperation.response
+
+
+def test_list_operations():
+ channel = grpc_helpers.ChannelStub()
+ client = operations_v1.OperationsClient(channel)
+ operations = [
+ operations_pb2.Operation(name="1"),
+ operations_pb2.Operation(name="2"),
+ ]
+ list_response = operations_pb2.ListOperationsResponse(operations=operations)
+ channel.ListOperations.response = list_response
+
+ response = client.list_operations("name", "filter", metadata=[("header", "foo")])
+
+ assert isinstance(response, page_iterator.Iterator)
+ assert list(response) == operations
+
+ assert ("header", "foo") in channel.ListOperations.calls[0].metadata
+ assert ("x-goog-request-params", "name=name") in channel.ListOperations.calls[
+ 0
+ ].metadata
+ assert len(channel.ListOperations.requests) == 1
+ request = channel.ListOperations.requests[0]
+ assert isinstance(request, operations_pb2.ListOperationsRequest)
+ assert request.name == "name"
+ assert request.filter == "filter"
+
+
+def test_delete_operation():
+ channel = grpc_helpers.ChannelStub()
+ client = operations_v1.OperationsClient(channel)
+ channel.DeleteOperation.response = empty_pb2.Empty()
+
+ client.delete_operation("name", metadata=[("header", "foo")])
+
+ assert ("header", "foo") in channel.DeleteOperation.calls[0].metadata
+ assert ("x-goog-request-params", "name=name") in channel.DeleteOperation.calls[
+ 0
+ ].metadata
+ assert len(channel.DeleteOperation.requests) == 1
+ assert channel.DeleteOperation.requests[0].name == "name"
+
+
+def test_cancel_operation():
+ channel = grpc_helpers.ChannelStub()
+ client = operations_v1.OperationsClient(channel)
+ channel.CancelOperation.response = empty_pb2.Empty()
+
+ client.cancel_operation("name", metadata=[("header", "foo")])
+
+ assert ("header", "foo") in channel.CancelOperation.calls[0].metadata
+ assert ("x-goog-request-params", "name=name") in channel.CancelOperation.calls[
+ 0
+ ].metadata
+ assert len(channel.CancelOperation.requests) == 1
+ assert channel.CancelOperation.requests[0].name == "name"
diff --git a/tests/unit/operations_v1/test_operations_rest_client.py b/tests/unit/operations_v1/test_operations_rest_client.py
new file mode 100644
index 0000000..dddf6b7
--- /dev/null
+++ b/tests/unit/operations_v1/test_operations_rest_client.py
@@ -0,0 +1,944 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+
+import mock
+import pytest
+
+try:
+ import grpc # noqa: F401
+except ImportError:
+ pytest.skip("No GRPC", allow_module_level=True)
+from requests import Response # noqa I201
+from requests.sessions import Session
+
+from google.api_core import client_options
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core.operations_v1 import AbstractOperationsClient
+from google.api_core.operations_v1 import pagers
+from google.api_core.operations_v1 import transports
+import google.auth
+from google.auth import credentials as ga_credentials
+from google.auth.exceptions import MutualTLSChannelError
+from google.longrunning import operations_pb2
+from google.oauth2 import service_account
+from google.protobuf import json_format # type: ignore
+from google.rpc import status_pb2 # type: ignore
+
+
+HTTP_OPTIONS = {
+ "google.longrunning.Operations.CancelOperation": [
+ {"method": "post", "uri": "/v3/{name=operations/*}:cancel", "body": "*"},
+ ],
+ "google.longrunning.Operations.DeleteOperation": [
+ {"method": "delete", "uri": "/v3/{name=operations/*}"},
+ ],
+ "google.longrunning.Operations.GetOperation": [
+ {"method": "get", "uri": "/v3/{name=operations/*}"},
+ ],
+ "google.longrunning.Operations.ListOperations": [
+ {"method": "get", "uri": "/v3/{name=operations}"},
+ ],
+}
+
+
+def client_cert_source_callback():
+ return b"cert bytes", b"key bytes"
+
+
+def _get_operations_client(http_options=HTTP_OPTIONS):
+ transport = transports.rest.OperationsRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(), http_options=http_options
+ )
+
+ return AbstractOperationsClient(transport=transport)
+
+
+# If default endpoint is localhost, then default mtls endpoint will be the same.
+# This method modifies the default endpoint so the client can produce a different
+# mtls endpoint for endpoint testing purposes.
+def modify_default_endpoint(client):
+ return (
+ "foo.googleapis.com"
+ if ("localhost" in client.DEFAULT_ENDPOINT)
+ else client.DEFAULT_ENDPOINT
+ )
+
+
+def test__get_default_mtls_endpoint():
+ api_endpoint = "example.googleapis.com"
+ api_mtls_endpoint = "example.mtls.googleapis.com"
+ sandbox_endpoint = "example.sandbox.googleapis.com"
+ sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com"
+ non_googleapi = "api.example.com"
+
+ assert AbstractOperationsClient._get_default_mtls_endpoint(None) is None
+ assert (
+ AbstractOperationsClient._get_default_mtls_endpoint(api_endpoint)
+ == api_mtls_endpoint
+ )
+ assert (
+ AbstractOperationsClient._get_default_mtls_endpoint(api_mtls_endpoint)
+ == api_mtls_endpoint
+ )
+ assert (
+ AbstractOperationsClient._get_default_mtls_endpoint(sandbox_endpoint)
+ == sandbox_mtls_endpoint
+ )
+ assert (
+ AbstractOperationsClient._get_default_mtls_endpoint(sandbox_mtls_endpoint)
+ == sandbox_mtls_endpoint
+ )
+ assert (
+ AbstractOperationsClient._get_default_mtls_endpoint(non_googleapi)
+ == non_googleapi
+ )
+
+
+@pytest.mark.parametrize("client_class", [AbstractOperationsClient])
+def test_operations_client_from_service_account_info(client_class):
+ creds = ga_credentials.AnonymousCredentials()
+ with mock.patch.object(
+ service_account.Credentials, "from_service_account_info"
+ ) as factory:
+ factory.return_value = creds
+ info = {"valid": True}
+ client = client_class.from_service_account_info(info)
+ assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
+
+ assert client.transport._host == "longrunning.googleapis.com:443"
+
+
+@pytest.mark.parametrize(
+ "transport_class,transport_name", [(transports.OperationsRestTransport, "rest")]
+)
+def test_operations_client_service_account_always_use_jwt(
+ transport_class, transport_name
+):
+ with mock.patch.object(
+ service_account.Credentials, "with_always_use_jwt_access", create=True
+ ) as use_jwt:
+ creds = service_account.Credentials(None, None, None)
+ transport_class(credentials=creds, always_use_jwt_access=True)
+ use_jwt.assert_called_once_with(True)
+
+ with mock.patch.object(
+ service_account.Credentials, "with_always_use_jwt_access", create=True
+ ) as use_jwt:
+ creds = service_account.Credentials(None, None, None)
+ transport_class(credentials=creds, always_use_jwt_access=False)
+ use_jwt.assert_not_called()
+
+
+@pytest.mark.parametrize("client_class", [AbstractOperationsClient])
+def test_operations_client_from_service_account_file(client_class):
+ creds = ga_credentials.AnonymousCredentials()
+ with mock.patch.object(
+ service_account.Credentials, "from_service_account_file"
+ ) as factory:
+ factory.return_value = creds
+ client = client_class.from_service_account_file("dummy/file/path.json")
+ assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
+
+ client = client_class.from_service_account_json("dummy/file/path.json")
+ assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
+
+ assert client.transport._host == "longrunning.googleapis.com:443"
+
+
+def test_operations_client_get_transport_class():
+ transport = AbstractOperationsClient.get_transport_class()
+ available_transports = [
+ transports.OperationsRestTransport,
+ ]
+ assert transport in available_transports
+
+ transport = AbstractOperationsClient.get_transport_class("rest")
+ assert transport == transports.OperationsRestTransport
+
+
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name",
+ [(AbstractOperationsClient, transports.OperationsRestTransport, "rest")],
+)
+@mock.patch.object(
+ AbstractOperationsClient,
+ "DEFAULT_ENDPOINT",
+ modify_default_endpoint(AbstractOperationsClient),
+)
+def test_operations_client_client_options(
+ client_class, transport_class, transport_name
+):
+ # Check that if channel is provided we won't create a new one.
+ with mock.patch.object(AbstractOperationsClient, "get_transport_class") as gtc:
+ transport = transport_class(credentials=ga_credentials.AnonymousCredentials())
+ client = client_class(transport=transport)
+ gtc.assert_not_called()
+
+ # Check that if channel is provided via str we will create a new one.
+ with mock.patch.object(AbstractOperationsClient, "get_transport_class") as gtc:
+ client = client_class(transport=transport_name)
+ gtc.assert_called()
+
+ # Check the case api_endpoint is provided.
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class(client_options=options)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host="squid.clam.whelk",
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
+ # "never".
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
+ # "always".
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_MTLS_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+ # unsupported value.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+ with pytest.raises(MutualTLSChannelError):
+ client = client_class()
+
+ # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+ ):
+ with pytest.raises(ValueError):
+ client = client_class()
+
+ # Check the case quota_project_id is provided
+ options = client_options.ClientOptions(quota_project_id="octopus")
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class(client_options=options)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id="octopus",
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name,use_client_cert_env",
+ [
+ (AbstractOperationsClient, transports.OperationsRestTransport, "rest", "true"),
+ (AbstractOperationsClient, transports.OperationsRestTransport, "rest", "false"),
+ ],
+)
+@mock.patch.object(
+ AbstractOperationsClient,
+ "DEFAULT_ENDPOINT",
+ modify_default_endpoint(AbstractOperationsClient),
+)
+@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
+def test_operations_client_mtls_env_auto(
+ client_class, transport_class, transport_name, use_client_cert_env
+):
+ # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
+ # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.
+
+ # Check the case client_cert_source is provided. Whether client cert is used depends on
+ # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+ ):
+ options = client_options.ClientOptions(
+ client_cert_source=client_cert_source_callback
+ )
+
+ def fake_init(client_cert_source_for_mtls=None, **kwargs):
+ """Invoke client_cert source if provided."""
+
+ if client_cert_source_for_mtls:
+ client_cert_source_for_mtls()
+ return None
+
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.side_effect = fake_init
+ client = client_class(client_options=options)
+
+ if use_client_cert_env == "false":
+ expected_client_cert_source = None
+ expected_host = client.DEFAULT_ENDPOINT
+ else:
+ expected_client_cert_source = client_cert_source_callback
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
+
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=expected_host,
+ scopes=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+ # Check the case ADC client cert is provided. Whether client cert is used depends on
+ # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+ ):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=True,
+ ):
+ with mock.patch(
+ "google.auth.transport.mtls.default_client_cert_source",
+ return_value=client_cert_source_callback,
+ ):
+ if use_client_cert_env == "false":
+ expected_host = client.DEFAULT_ENDPOINT
+ expected_client_cert_source = None
+ else:
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
+ expected_client_cert_source = client_cert_source_callback
+
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=expected_host,
+ scopes=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+ # Check the case client_cert_source and ADC client cert are not provided.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+ ):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=False,
+ ):
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name",
+ [(AbstractOperationsClient, transports.OperationsRestTransport, "rest")],
+)
+def test_operations_client_client_options_scopes(
+ client_class, transport_class, transport_name
+):
+ # Check the case scopes are provided.
+ options = client_options.ClientOptions(scopes=["1", "2"],)
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class(client_options=options)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=["1", "2"],
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name",
+ [(AbstractOperationsClient, transports.OperationsRestTransport, "rest")],
+)
+def test_operations_client_client_options_credentials_file(
+ client_class, transport_class, transport_name
+):
+ # Check the case credentials file is provided.
+ options = client_options.ClientOptions(credentials_file="credentials.json")
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class(client_options=options)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file="credentials.json",
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+
+def test_list_operations_rest(
+ transport: str = "rest", request_type=operations_pb2.ListOperationsRequest
+):
+ client = _get_operations_client()
+
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(Session, "request") as req:
+ # Designate an appropriate value for the returned response.
+ return_value = operations_pb2.ListOperationsResponse(
+ next_page_token="next_page_token_value",
+ )
+
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 200
+ json_return_value = json_format.MessageToJson(return_value)
+ response_value._content = json_return_value.encode("UTF-8")
+ req.return_value = response_value
+ response = client.list_operations(
+ name="operations", filter_="my_filter", page_size=10, page_token="abc"
+ )
+
+ actual_args = req.call_args
+ assert actual_args.args[0] == "GET"
+ assert (
+ actual_args.args[1]
+ == "https://longrunning.googleapis.com:443/v3/operations"
+ )
+ assert actual_args.kwargs["params"] == [
+ ("filter", "my_filter"),
+ ("pageSize", 10),
+ ("pageToken", "abc"),
+ ]
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListOperationsPager)
+ assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_operations_rest_failure():
+ client = _get_operations_client(http_options=None)
+
+ with mock.patch.object(Session, "request") as req:
+ response_value = Response()
+ response_value.status_code = 400
+ mock_request = mock.MagicMock()
+ mock_request.method = "GET"
+ mock_request.url = "https://longrunning.googleapis.com:443/v1/operations"
+ response_value.request = mock_request
+ req.return_value = response_value
+ with pytest.raises(core_exceptions.GoogleAPIError):
+ client.list_operations(name="operations")
+
+
+def test_list_operations_rest_pager():
+ client = AbstractOperationsClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(Session, "request") as req:
+ # TODO(kbandes): remove this mock unless there's a good reason for it.
+ # with mock.patch.object(path_template, 'transcode') as transcode:
+ # Set the response as a series of pages
+ response = (
+ operations_pb2.ListOperationsResponse(
+ operations=[
+ operations_pb2.Operation(),
+ operations_pb2.Operation(),
+ operations_pb2.Operation(),
+ ],
+ next_page_token="abc",
+ ),
+ operations_pb2.ListOperationsResponse(
+ operations=[], next_page_token="def",
+ ),
+ operations_pb2.ListOperationsResponse(
+ operations=[operations_pb2.Operation()], next_page_token="ghi",
+ ),
+ operations_pb2.ListOperationsResponse(
+ operations=[operations_pb2.Operation(), operations_pb2.Operation()],
+ ),
+ )
+ # Two responses for two calls
+ response = response + response
+
+ # Wrap the values into proper Response objs
+ response = tuple(json_format.MessageToJson(x) for x in response)
+ return_values = tuple(Response() for i in response)
+ for return_val, response_val in zip(return_values, response):
+ return_val._content = response_val.encode("UTF-8")
+ return_val.status_code = 200
+ req.side_effect = return_values
+
+ pager = client.list_operations(name="operations")
+
+ results = list(pager)
+ assert len(results) == 6
+ assert all(isinstance(i, operations_pb2.Operation) for i in results)
+
+ pages = list(client.list_operations(name="operations").pages)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.next_page_token == token
+
+
+def test_get_operation_rest(
+ transport: str = "rest", request_type=operations_pb2.GetOperationRequest
+):
+ client = _get_operations_client()
+
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(Session, "request") as req:
+ # Designate an appropriate value for the returned response.
+ return_value = operations_pb2.Operation(
+ name="operations/sample1", done=True, error=status_pb2.Status(code=411),
+ )
+
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 200
+ json_return_value = json_format.MessageToJson(return_value)
+ response_value._content = json_return_value.encode("UTF-8")
+ req.return_value = response_value
+ response = client.get_operation("operations/sample1")
+
+ actual_args = req.call_args
+ assert actual_args.args[0] == "GET"
+ assert (
+ actual_args.args[1]
+ == "https://longrunning.googleapis.com:443/v3/operations/sample1"
+ )
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, operations_pb2.Operation)
+ assert response.name == "operations/sample1"
+ assert response.done is True
+
+
+def test_get_operation_rest_failure():
+ client = _get_operations_client(http_options=None)
+
+ with mock.patch.object(Session, "request") as req:
+ response_value = Response()
+ response_value.status_code = 400
+ mock_request = mock.MagicMock()
+ mock_request.method = "GET"
+ mock_request.url = (
+ "https://longrunning.googleapis.com:443/v1/operations/sample1"
+ )
+ response_value.request = mock_request
+ req.return_value = response_value
+ with pytest.raises(core_exceptions.GoogleAPIError):
+ client.get_operation("operations/sample1")
+
+
+def test_delete_operation_rest(
+ transport: str = "rest", request_type=operations_pb2.DeleteOperationRequest
+):
+ client = _get_operations_client()
+
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(Session, "request") as req:
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 200
+ json_return_value = ""
+ response_value._content = json_return_value.encode("UTF-8")
+ req.return_value = response_value
+ client.delete_operation(name="operations/sample1")
+ assert req.call_count == 1
+ actual_args = req.call_args
+ assert actual_args.args[0] == "DELETE"
+ assert (
+ actual_args.args[1]
+ == "https://longrunning.googleapis.com:443/v3/operations/sample1"
+ )
+
+
+def test_delete_operation_rest_failure():
+ client = _get_operations_client(http_options=None)
+
+ with mock.patch.object(Session, "request") as req:
+ response_value = Response()
+ response_value.status_code = 400
+ mock_request = mock.MagicMock()
+ mock_request.method = "DELETE"
+ mock_request.url = (
+ "https://longrunning.googleapis.com:443/v1/operations/sample1"
+ )
+ response_value.request = mock_request
+ req.return_value = response_value
+ with pytest.raises(core_exceptions.GoogleAPIError):
+ client.delete_operation(name="operations/sample1")
+
+
+def test_cancel_operation_rest(transport: str = "rest"):
+ client = _get_operations_client()
+
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(Session, "request") as req:
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 200
+ json_return_value = ""
+ response_value._content = json_return_value.encode("UTF-8")
+ req.return_value = response_value
+ client.cancel_operation(name="operations/sample1")
+ assert req.call_count == 1
+ actual_args = req.call_args
+ assert actual_args.args[0] == "POST"
+ assert (
+ actual_args.args[1]
+ == "https://longrunning.googleapis.com:443/v3/operations/sample1:cancel"
+ )
+
+
+def test_cancel_operation_rest_failure():
+ client = _get_operations_client(http_options=None)
+
+ with mock.patch.object(Session, "request") as req:
+ response_value = Response()
+ response_value.status_code = 400
+ mock_request = mock.MagicMock()
+ mock_request.method = "POST"
+ mock_request.url = (
+ "https://longrunning.googleapis.com:443/v1/operations/sample1:cancel"
+ )
+ response_value.request = mock_request
+ req.return_value = response_value
+ with pytest.raises(core_exceptions.GoogleAPIError):
+ client.cancel_operation(name="operations/sample1")
+
+
+def test_credentials_transport_error():
+ # It is an error to provide credentials and a transport instance.
+ transport = transports.OperationsRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ with pytest.raises(ValueError):
+ AbstractOperationsClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # It is an error to provide a credentials file and a transport instance.
+ transport = transports.OperationsRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ with pytest.raises(ValueError):
+ AbstractOperationsClient(
+ client_options={"credentials_file": "credentials.json"},
+ transport=transport,
+ )
+
+ # It is an error to provide scopes and a transport instance.
+ transport = transports.OperationsRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ with pytest.raises(ValueError):
+ AbstractOperationsClient(
+ client_options={"scopes": ["1", "2"]}, transport=transport,
+ )
+
+
+def test_transport_instance():
+ # A client may be instantiated with a custom transport instance.
+ transport = transports.OperationsRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ client = AbstractOperationsClient(transport=transport)
+ assert client.transport is transport
+
+
+@pytest.mark.parametrize("transport_class", [transports.OperationsRestTransport])
+def test_transport_adc(transport_class):
+ # Test default credentials are used if not provided.
+ with mock.patch.object(google.auth, "default") as adc:
+ adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+ transport_class()
+ adc.assert_called_once()
+
+
+def test_operations_base_transport_error():
+ # Passing both a credentials object and credentials_file should raise an error
+ with pytest.raises(core_exceptions.DuplicateCredentialArgs):
+ transports.OperationsTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ credentials_file="credentials.json",
+ )
+
+
+def test_operations_base_transport():
+ # Instantiate the base transport.
+ with mock.patch(
+ "google.api_core.operations_v1.transports.OperationsTransport.__init__"
+ ) as Transport:
+ Transport.return_value = None
+ transport = transports.OperationsTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Every method on the transport should just blindly
+ # raise NotImplementedError.
+ methods = (
+ "list_operations",
+ "get_operation",
+ "delete_operation",
+ "cancel_operation",
+ )
+ for method in methods:
+ with pytest.raises(NotImplementedError):
+ getattr(transport, method)(request=object())
+
+ with pytest.raises(NotImplementedError):
+ transport.close()
+
+
+def test_operations_base_transport_with_credentials_file():
+ # Instantiate the base transport with a credentials file
+ with mock.patch.object(
+ google.auth, "load_credentials_from_file", autospec=True
+ ) as load_creds, mock.patch(
+ "google.api_core.operations_v1.transports.OperationsTransport._prep_wrapped_messages"
+ ) as Transport:
+ Transport.return_value = None
+ load_creds.return_value = (ga_credentials.AnonymousCredentials(), None)
+ transports.OperationsTransport(
+ credentials_file="credentials.json", quota_project_id="octopus",
+ )
+ load_creds.assert_called_once_with(
+ "credentials.json",
+ scopes=None,
+ default_scopes=(),
+ quota_project_id="octopus",
+ )
+
+
+def test_operations_base_transport_with_adc():
+ # Test the default credentials are used if credentials and credentials_file are None.
+ with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch(
+ "google.api_core.operations_v1.transports.OperationsTransport._prep_wrapped_messages"
+ ) as Transport:
+ Transport.return_value = None
+ adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+ transports.OperationsTransport()
+ adc.assert_called_once()
+
+
+def test_operations_auth_adc():
+ # If no credentials are provided, we should use ADC credentials.
+ with mock.patch.object(google.auth, "default", autospec=True) as adc:
+ adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+ AbstractOperationsClient()
+ adc.assert_called_once_with(
+ scopes=None, default_scopes=(), quota_project_id=None,
+ )
+
+
+def test_operations_http_transport_client_cert_source_for_mtls():
+ cred = ga_credentials.AnonymousCredentials()
+ with mock.patch(
+ "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel"
+ ) as mock_configure_mtls_channel:
+ transports.OperationsRestTransport(
+ credentials=cred, client_cert_source_for_mtls=client_cert_source_callback
+ )
+ mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
+
+
+def test_operations_host_no_port():
+ client = AbstractOperationsClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ client_options=client_options.ClientOptions(
+ api_endpoint="longrunning.googleapis.com"
+ ),
+ )
+ assert client.transport._host == "longrunning.googleapis.com:443"
+
+
+def test_operations_host_with_port():
+ client = AbstractOperationsClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ client_options=client_options.ClientOptions(
+ api_endpoint="longrunning.googleapis.com:8000"
+ ),
+ )
+ assert client.transport._host == "longrunning.googleapis.com:8000"
+
+
+def test_common_billing_account_path():
+ billing_account = "squid"
+ expected = "billingAccounts/{billing_account}".format(
+ billing_account=billing_account,
+ )
+ actual = AbstractOperationsClient.common_billing_account_path(billing_account)
+ assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+ expected = {
+ "billing_account": "clam",
+ }
+ path = AbstractOperationsClient.common_billing_account_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = AbstractOperationsClient.parse_common_billing_account_path(path)
+ assert expected == actual
+
+
+def test_common_folder_path():
+ folder = "whelk"
+ expected = "folders/{folder}".format(folder=folder,)
+ actual = AbstractOperationsClient.common_folder_path(folder)
+ assert expected == actual
+
+
+def test_parse_common_folder_path():
+ expected = {
+ "folder": "octopus",
+ }
+ path = AbstractOperationsClient.common_folder_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = AbstractOperationsClient.parse_common_folder_path(path)
+ assert expected == actual
+
+
+def test_common_organization_path():
+ organization = "oyster"
+ expected = "organizations/{organization}".format(organization=organization,)
+ actual = AbstractOperationsClient.common_organization_path(organization)
+ assert expected == actual
+
+
+def test_parse_common_organization_path():
+ expected = {
+ "organization": "nudibranch",
+ }
+ path = AbstractOperationsClient.common_organization_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = AbstractOperationsClient.parse_common_organization_path(path)
+ assert expected == actual
+
+
+def test_common_project_path():
+ project = "cuttlefish"
+ expected = "projects/{project}".format(project=project,)
+ actual = AbstractOperationsClient.common_project_path(project)
+ assert expected == actual
+
+
+def test_parse_common_project_path():
+ expected = {
+ "project": "mussel",
+ }
+ path = AbstractOperationsClient.common_project_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = AbstractOperationsClient.parse_common_project_path(path)
+ assert expected == actual
+
+
+def test_common_location_path():
+ project = "winkle"
+ location = "nautilus"
+ expected = "projects/{project}/locations/{location}".format(
+ project=project, location=location,
+ )
+ actual = AbstractOperationsClient.common_location_path(project, location)
+ assert expected == actual
+
+
+def test_parse_common_location_path():
+ expected = {
+ "project": "scallop",
+ "location": "abalone",
+ }
+ path = AbstractOperationsClient.common_location_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = AbstractOperationsClient.parse_common_location_path(path)
+ assert expected == actual
+
+
+def test_client_withDEFAULT_CLIENT_INFO():
+ client_info = gapic_v1.client_info.ClientInfo()
+
+ with mock.patch.object(
+ transports.OperationsTransport, "_prep_wrapped_messages"
+ ) as prep:
+ AbstractOperationsClient(
+ credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
+ )
+ prep.assert_called_once_with(client_info)
+
+ with mock.patch.object(
+ transports.OperationsTransport, "_prep_wrapped_messages"
+ ) as prep:
+ transport_class = AbstractOperationsClient.get_transport_class()
+ transport_class(
+ credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
+ )
+ prep.assert_called_once_with(client_info)
diff --git a/tests/unit/test_bidi.py b/tests/unit/test_bidi.py
new file mode 100644
index 0000000..7fb1620
--- /dev/null
+++ b/tests/unit/test_bidi.py
@@ -0,0 +1,869 @@
+# Copyright 2018, Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import logging
+import queue
+import threading
+
+import mock
+import pytest
+
+try:
+ import grpc
+except ImportError:
+ pytest.skip("No GRPC", allow_module_level=True)
+
+from google.api_core import bidi
+from google.api_core import exceptions
+
+
+class Test_RequestQueueGenerator(object):
+ def test_bounded_consume(self):
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = True
+
+ def queue_generator(rpc):
+ yield mock.sentinel.A
+ yield queue.Empty()
+ yield mock.sentinel.B
+ rpc.is_active.return_value = False
+ yield mock.sentinel.C
+
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = queue_generator(call)
+
+ generator = bidi._RequestQueueGenerator(q)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == [mock.sentinel.A, mock.sentinel.B]
+
+ def test_yield_initial_and_exit(self):
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = queue.Empty()
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = False
+
+ generator = bidi._RequestQueueGenerator(q, initial_request=mock.sentinel.A)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == [mock.sentinel.A]
+
+ def test_yield_initial_callable_and_exit(self):
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = queue.Empty()
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = False
+
+ generator = bidi._RequestQueueGenerator(
+ q, initial_request=lambda: mock.sentinel.A
+ )
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == [mock.sentinel.A]
+
+ def test_exit_when_inactive_with_item(self):
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = [mock.sentinel.A, queue.Empty()]
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = False
+
+ generator = bidi._RequestQueueGenerator(q)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == []
+ # Make sure it put the item back.
+ q.put.assert_called_once_with(mock.sentinel.A)
+
+ def test_exit_when_inactive_empty(self):
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = queue.Empty()
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = False
+
+ generator = bidi._RequestQueueGenerator(q)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == []
+
+ def test_exit_with_stop(self):
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = [None, queue.Empty()]
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = True
+
+ generator = bidi._RequestQueueGenerator(q)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == []
+
+
+class Test_Throttle(object):
+ def test_repr(self):
+ delta = datetime.timedelta(seconds=4.5)
+ instance = bidi._Throttle(access_limit=42, time_window=delta)
+ assert repr(instance) == "_Throttle(access_limit=42, time_window={})".format(
+ repr(delta)
+ )
+
+ def test_raises_error_on_invalid_init_arguments(self):
+ with pytest.raises(ValueError) as exc_info:
+ bidi._Throttle(access_limit=10, time_window=datetime.timedelta(seconds=0.0))
+ assert "time_window" in str(exc_info.value)
+ assert "must be a positive timedelta" in str(exc_info.value)
+
+ with pytest.raises(ValueError) as exc_info:
+ bidi._Throttle(access_limit=0, time_window=datetime.timedelta(seconds=10))
+ assert "access_limit" in str(exc_info.value)
+ assert "must be positive" in str(exc_info.value)
+
+ def test_does_not_delay_entry_attempts_under_threshold(self):
+ throttle = bidi._Throttle(
+ access_limit=3, time_window=datetime.timedelta(seconds=1)
+ )
+ entries = []
+
+ for _ in range(3):
+ with throttle as time_waited:
+ entry_info = {
+ "entered_at": datetime.datetime.now(),
+ "reported_wait": time_waited,
+ }
+ entries.append(entry_info)
+
+ # check the reported wait times ...
+ assert all(entry["reported_wait"] == 0.0 for entry in entries)
+
+ # .. and the actual wait times
+ delta = entries[1]["entered_at"] - entries[0]["entered_at"]
+ assert delta.total_seconds() < 0.1
+ delta = entries[2]["entered_at"] - entries[1]["entered_at"]
+ assert delta.total_seconds() < 0.1
+
+ def test_delays_entry_attempts_above_threshold(self):
+ throttle = bidi._Throttle(
+ access_limit=3, time_window=datetime.timedelta(seconds=1)
+ )
+ entries = []
+
+ for _ in range(6):
+ with throttle as time_waited:
+ entry_info = {
+ "entered_at": datetime.datetime.now(),
+ "reported_wait": time_waited,
+ }
+ entries.append(entry_info)
+
+ # For each group of 4 consecutive entries the time difference between
+ # the first and the last entry must have been greater than time_window,
+ # because a maximum of 3 are allowed in each time_window.
+ for i, entry in enumerate(entries[3:], start=3):
+ first_entry = entries[i - 3]
+ delta = entry["entered_at"] - first_entry["entered_at"]
+ assert delta.total_seconds() > 1.0
+
+ # check the reported wait times
+ # (NOTE: not using assert all(...), b/c the coverage check would complain)
+ for i, entry in enumerate(entries):
+ if i != 3:
+ assert entry["reported_wait"] == 0.0
+
+ # The delayed entry is expected to have been delayed for a significant
+ # chunk of the full second, and the actual and reported delay times
+ # should reflect that.
+ assert entries[3]["reported_wait"] > 0.7
+ delta = entries[3]["entered_at"] - entries[2]["entered_at"]
+ assert delta.total_seconds() > 0.7
+
+
+class _CallAndFuture(grpc.Call, grpc.Future):
+ pass
+
+
+def make_rpc():
+ """Makes a mock RPC used to test Bidi classes."""
+ call = mock.create_autospec(_CallAndFuture, instance=True)
+ rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
+
+ def rpc_side_effect(request, metadata=None):
+ call.is_active.return_value = True
+ call.request = request
+ call.metadata = metadata
+ return call
+
+ rpc.side_effect = rpc_side_effect
+
+ def cancel_side_effect():
+ call.is_active.return_value = False
+
+ call.cancel.side_effect = cancel_side_effect
+
+ return rpc, call
+
+
+class ClosedCall(object):
+ def __init__(self, exception):
+ self.exception = exception
+
+ def __next__(self):
+ raise self.exception
+
+ def is_active(self):
+ return False
+
+
+class TestBidiRpc(object):
+ def test_initial_state(self):
+ bidi_rpc = bidi.BidiRpc(None)
+
+ assert bidi_rpc.is_active is False
+
+ def test_done_callbacks(self):
+ bidi_rpc = bidi.BidiRpc(None)
+ callback = mock.Mock(spec=["__call__"])
+
+ bidi_rpc.add_done_callback(callback)
+ bidi_rpc._on_call_done(mock.sentinel.future)
+
+ callback.assert_called_once_with(mock.sentinel.future)
+
+ def test_metadata(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc, metadata=mock.sentinel.A)
+ assert bidi_rpc._rpc_metadata == mock.sentinel.A
+
+ bidi_rpc.open()
+ assert bidi_rpc.call == call
+ assert bidi_rpc.call.metadata == mock.sentinel.A
+
+ def test_open(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+
+ bidi_rpc.open()
+
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active
+ call.add_done_callback.assert_called_once_with(bidi_rpc._on_call_done)
+
+ def test_open_error_already_open(self):
+ rpc, _ = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+
+ bidi_rpc.open()
+
+ with pytest.raises(ValueError):
+ bidi_rpc.open()
+
+ def test_close(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+ bidi_rpc.open()
+
+ bidi_rpc.close()
+
+ call.cancel.assert_called_once()
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active is False
+ # ensure the request queue was signaled to stop.
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is None
+
+ def test_close_no_rpc(self):
+ bidi_rpc = bidi.BidiRpc(None)
+ bidi_rpc.close()
+
+ def test_send(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+ bidi_rpc.open()
+
+ bidi_rpc.send(mock.sentinel.request)
+
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is mock.sentinel.request
+
+ def test_send_not_open(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+
+ with pytest.raises(ValueError):
+ bidi_rpc.send(mock.sentinel.request)
+
+ def test_send_dead_rpc(self):
+ error = ValueError()
+ bidi_rpc = bidi.BidiRpc(None)
+ bidi_rpc.call = ClosedCall(error)
+
+ with pytest.raises(ValueError) as exc_info:
+ bidi_rpc.send(mock.sentinel.request)
+
+ assert exc_info.value == error
+
+ def test_recv(self):
+ bidi_rpc = bidi.BidiRpc(None)
+ bidi_rpc.call = iter([mock.sentinel.response])
+
+ response = bidi_rpc.recv()
+
+ assert response == mock.sentinel.response
+
+ def test_recv_not_open(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+
+ with pytest.raises(ValueError):
+ bidi_rpc.recv()
+
+
+class CallStub(object):
+ def __init__(self, values, active=True):
+ self.values = iter(values)
+ self._is_active = active
+ self.cancelled = False
+
+ def __next__(self):
+ item = next(self.values)
+ if isinstance(item, Exception):
+ self._is_active = False
+ raise item
+ return item
+
+ def is_active(self):
+ return self._is_active
+
+ def add_done_callback(self, callback):
+ pass
+
+ def cancel(self):
+ self.cancelled = True
+
+
+class TestResumableBidiRpc(object):
+ def test_ctor_defaults(self):
+ start_rpc = mock.Mock()
+ should_recover = mock.Mock()
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+ assert bidi_rpc.is_active is False
+ assert bidi_rpc._finalized is False
+ assert bidi_rpc._start_rpc is start_rpc
+ assert bidi_rpc._should_recover is should_recover
+ assert bidi_rpc._should_terminate is bidi._never_terminate
+ assert bidi_rpc._initial_request is None
+ assert bidi_rpc._rpc_metadata is None
+ assert bidi_rpc._reopen_throttle is None
+
+ def test_ctor_explicit(self):
+ start_rpc = mock.Mock()
+ should_recover = mock.Mock()
+ should_terminate = mock.Mock()
+ initial_request = mock.Mock()
+ metadata = {"x-foo": "bar"}
+ bidi_rpc = bidi.ResumableBidiRpc(
+ start_rpc,
+ should_recover,
+ should_terminate=should_terminate,
+ initial_request=initial_request,
+ metadata=metadata,
+ throttle_reopen=True,
+ )
+
+ assert bidi_rpc.is_active is False
+ assert bidi_rpc._finalized is False
+ assert bidi_rpc._should_recover is should_recover
+ assert bidi_rpc._should_terminate is should_terminate
+ assert bidi_rpc._initial_request is initial_request
+ assert bidi_rpc._rpc_metadata == metadata
+ assert isinstance(bidi_rpc._reopen_throttle, bidi._Throttle)
+
+ def test_done_callbacks_terminate(self):
+ cancellation = mock.Mock()
+ start_rpc = mock.Mock()
+ should_recover = mock.Mock(spec=["__call__"], return_value=True)
+ should_terminate = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(
+ start_rpc, should_recover, should_terminate=should_terminate
+ )
+ callback = mock.Mock(spec=["__call__"])
+
+ bidi_rpc.add_done_callback(callback)
+ bidi_rpc._on_call_done(cancellation)
+
+ should_terminate.assert_called_once_with(cancellation)
+ should_recover.assert_not_called()
+ callback.assert_called_once_with(cancellation)
+ assert not bidi_rpc.is_active
+
+ def test_done_callbacks_recoverable(self):
+ start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
+ should_recover = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+ callback = mock.Mock(spec=["__call__"])
+
+ bidi_rpc.add_done_callback(callback)
+ bidi_rpc._on_call_done(mock.sentinel.future)
+
+ callback.assert_not_called()
+ start_rpc.assert_called_once()
+ should_recover.assert_called_once_with(mock.sentinel.future)
+ assert bidi_rpc.is_active
+
+ def test_done_callbacks_non_recoverable(self):
+ start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
+ should_recover = mock.Mock(spec=["__call__"], return_value=False)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+ callback = mock.Mock(spec=["__call__"])
+
+ bidi_rpc.add_done_callback(callback)
+ bidi_rpc._on_call_done(mock.sentinel.future)
+
+ callback.assert_called_once_with(mock.sentinel.future)
+ should_recover.assert_called_once_with(mock.sentinel.future)
+ assert not bidi_rpc.is_active
+
+ def test_send_terminate(self):
+ cancellation = ValueError()
+ call_1 = CallStub([cancellation], active=False)
+ call_2 = CallStub([])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2]
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=False)
+ should_terminate = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(
+ start_rpc, should_recover, should_terminate=should_terminate
+ )
+
+ bidi_rpc.open()
+
+ bidi_rpc.send(mock.sentinel.request)
+
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is None
+
+ should_recover.assert_not_called()
+ should_terminate.assert_called_once_with(cancellation)
+ assert bidi_rpc.call == call_1
+ assert bidi_rpc.is_active is False
+ assert call_1.cancelled is True
+
+ def test_send_recover(self):
+ error = ValueError()
+ call_1 = CallStub([error], active=False)
+ call_2 = CallStub([])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2]
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+ bidi_rpc.open()
+
+ bidi_rpc.send(mock.sentinel.request)
+
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is mock.sentinel.request
+
+ should_recover.assert_called_once_with(error)
+ assert bidi_rpc.call == call_2
+ assert bidi_rpc.is_active is True
+
+ def test_send_failure(self):
+ error = ValueError()
+ call = CallStub([error], active=False)
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, return_value=call
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=False)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+ bidi_rpc.open()
+
+ with pytest.raises(ValueError) as exc_info:
+ bidi_rpc.send(mock.sentinel.request)
+
+ assert exc_info.value == error
+ should_recover.assert_called_once_with(error)
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active is False
+ assert call.cancelled is True
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is None
+
+ def test_recv_terminate(self):
+ cancellation = ValueError()
+ call = CallStub([cancellation])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, return_value=call
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=False)
+ should_terminate = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(
+ start_rpc, should_recover, should_terminate=should_terminate
+ )
+
+ bidi_rpc.open()
+
+ bidi_rpc.recv()
+
+ should_recover.assert_not_called()
+ should_terminate.assert_called_once_with(cancellation)
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active is False
+ assert call.cancelled is True
+
+ def test_recv_recover(self):
+ error = ValueError()
+ call_1 = CallStub([1, error])
+ call_2 = CallStub([2, 3])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2]
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+ bidi_rpc.open()
+
+ values = []
+ for n in range(3):
+ values.append(bidi_rpc.recv())
+
+ assert values == [1, 2, 3]
+ should_recover.assert_called_once_with(error)
+ assert bidi_rpc.call == call_2
+ assert bidi_rpc.is_active is True
+
+ def test_recv_recover_already_recovered(self):
+ call_1 = CallStub([])
+ call_2 = CallStub([])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2]
+ )
+ callback = mock.Mock()
+ callback.return_value = True
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, callback)
+
+ bidi_rpc.open()
+
+ bidi_rpc._reopen()
+
+ assert bidi_rpc.call is call_1
+ assert bidi_rpc.is_active is True
+
+ def test_recv_failure(self):
+ error = ValueError()
+ call = CallStub([error])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, return_value=call
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=False)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+ bidi_rpc.open()
+
+ with pytest.raises(ValueError) as exc_info:
+ bidi_rpc.recv()
+
+ assert exc_info.value == error
+ should_recover.assert_called_once_with(error)
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active is False
+ assert call.cancelled is True
+
+ def test_close(self):
+ call = mock.create_autospec(_CallAndFuture, instance=True)
+
+ def cancel_side_effect():
+ call.is_active.return_value = False
+
+ call.cancel.side_effect = cancel_side_effect
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, return_value=call
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=False)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+ bidi_rpc.open()
+
+ bidi_rpc.close()
+
+ should_recover.assert_not_called()
+ call.cancel.assert_called_once()
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active is False
+ # ensure the request queue was signaled to stop.
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is None
+ assert bidi_rpc._finalized
+
+ def test_reopen_failure_on_rpc_restart(self):
+ error1 = ValueError("1")
+ error2 = ValueError("2")
+ call = CallStub([error1])
+ # Invoking start RPC a second time will trigger an error.
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, side_effect=[call, error2]
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=True)
+ callback = mock.Mock(spec=["__call__"])
+
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+ bidi_rpc.add_done_callback(callback)
+
+ bidi_rpc.open()
+
+ with pytest.raises(ValueError) as exc_info:
+ bidi_rpc.recv()
+
+ assert exc_info.value == error2
+ should_recover.assert_called_once_with(error1)
+ assert bidi_rpc.call is None
+ assert bidi_rpc.is_active is False
+ callback.assert_called_once_with(error2)
+
+ def test_using_throttle_on_reopen_requests(self):
+ call = CallStub([])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, return_value=call
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(
+ start_rpc, should_recover, throttle_reopen=True
+ )
+
+ patcher = mock.patch.object(bidi_rpc._reopen_throttle.__class__, "__enter__")
+ with patcher as mock_enter:
+ bidi_rpc._reopen()
+
+ mock_enter.assert_called_once()
+
+ def test_send_not_open(self):
+ bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
+
+ with pytest.raises(ValueError):
+ bidi_rpc.send(mock.sentinel.request)
+
+ def test_recv_not_open(self):
+ bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
+
+ with pytest.raises(ValueError):
+ bidi_rpc.recv()
+
+ def test_finalize_idempotent(self):
+ error1 = ValueError("1")
+ error2 = ValueError("2")
+ callback = mock.Mock(spec=["__call__"])
+ should_recover = mock.Mock(spec=["__call__"], return_value=False)
+
+ bidi_rpc = bidi.ResumableBidiRpc(mock.sentinel.start_rpc, should_recover)
+
+ bidi_rpc.add_done_callback(callback)
+
+ bidi_rpc._on_call_done(error1)
+ bidi_rpc._on_call_done(error2)
+
+ callback.assert_called_once_with(error1)
+
+
+class TestBackgroundConsumer(object):
+ def test_consume_once_then_exit(self):
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+ bidi_rpc.recv.side_effect = [mock.sentinel.response_1]
+ recved = threading.Event()
+
+ def on_response(response):
+ assert response == mock.sentinel.response_1
+ bidi_rpc.is_active = False
+ recved.set()
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+ consumer.start()
+
+ recved.wait()
+
+ bidi_rpc.recv.assert_called_once()
+ assert bidi_rpc.is_active is False
+
+ consumer.stop()
+
+ bidi_rpc.close.assert_called_once()
+ assert consumer.is_active is False
+
+ def test_pause_resume_and_close(self):
+ # This test is relatively complex. It attempts to start the consumer,
+ # consume one item, pause the consumer, check the state of the world,
+ # then resume the consumer. Doing this in a deterministic fashion
+ # requires a bit more mocking and patching than usual.
+
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+
+ def close_side_effect():
+ bidi_rpc.is_active = False
+
+ bidi_rpc.close.side_effect = close_side_effect
+
+ # These are used to coordinate the two threads to ensure deterministic
+ # execution.
+ should_continue = threading.Event()
+ responses_and_events = {
+ mock.sentinel.response_1: threading.Event(),
+ mock.sentinel.response_2: threading.Event(),
+ }
+ bidi_rpc.recv.side_effect = [mock.sentinel.response_1, mock.sentinel.response_2]
+
+ recved_responses = []
+ consumer = None
+
+ def on_response(response):
+ if response == mock.sentinel.response_1:
+ consumer.pause()
+
+ recved_responses.append(response)
+ responses_and_events[response].set()
+ should_continue.wait()
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+ consumer.start()
+
+ # Wait for the first response to be recved.
+ responses_and_events[mock.sentinel.response_1].wait()
+
+ # Ensure only one item has been recved and that the consumer is paused.
+ assert recved_responses == [mock.sentinel.response_1]
+ assert consumer.is_paused is True
+ assert consumer.is_active is True
+
+ # Unpause the consumer, wait for the second item, then close the
+ # consumer.
+ should_continue.set()
+ consumer.resume()
+
+ responses_and_events[mock.sentinel.response_2].wait()
+
+ assert recved_responses == [mock.sentinel.response_1, mock.sentinel.response_2]
+
+ consumer.stop()
+
+ assert consumer.is_active is False
+
+ def test_wake_on_error(self):
+ should_continue = threading.Event()
+
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+ bidi_rpc.add_done_callback.side_effect = lambda _: should_continue.set()
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, mock.sentinel.on_response)
+
+ # Start the consumer paused, which should immediately put it into wait
+ # state.
+ consumer.pause()
+ consumer.start()
+
+ # Wait for add_done_callback to be called
+ should_continue.wait()
+ bidi_rpc.add_done_callback.assert_called_once_with(consumer._on_call_done)
+
+ # The consumer should now be blocked on waiting to be unpaused.
+ assert consumer.is_active
+ assert consumer.is_paused
+
+ # Trigger the done callback, it should unpause the consumer and cause
+ # it to exit.
+ bidi_rpc.is_active = False
+ consumer._on_call_done(bidi_rpc)
+
+ # It may take a few cycles for the thread to exit.
+ while consumer.is_active:
+ pass
+
+ def test_consumer_expected_error(self, caplog):
+ caplog.set_level(logging.DEBUG)
+
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+ bidi_rpc.recv.side_effect = exceptions.ServiceUnavailable("Gone away")
+
+ on_response = mock.Mock(spec=["__call__"])
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+ consumer.start()
+
+ # Wait for the consumer's thread to exit.
+ while consumer.is_active:
+ pass
+
+ on_response.assert_not_called()
+ bidi_rpc.recv.assert_called_once()
+ assert "caught error" in caplog.text
+
+ def test_consumer_unexpected_error(self, caplog):
+ caplog.set_level(logging.DEBUG)
+
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+ bidi_rpc.recv.side_effect = ValueError()
+
+ on_response = mock.Mock(spec=["__call__"])
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+ consumer.start()
+
+ # Wait for the consumer's thread to exit.
+ while consumer.is_active:
+ pass # pragma: NO COVER (race condition)
+
+ on_response.assert_not_called()
+ bidi_rpc.recv.assert_called_once()
+ assert "caught unexpected exception" in caplog.text
+
+ def test_double_stop(self, caplog):
+ caplog.set_level(logging.DEBUG)
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+ on_response = mock.Mock(spec=["__call__"])
+
+ def close_side_effect():
+ bidi_rpc.is_active = False
+
+ bidi_rpc.close.side_effect = close_side_effect
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+ consumer.start()
+ assert consumer.is_active is True
+
+ consumer.stop()
+ assert consumer.is_active is False
+
+ # calling stop twice should not result in an error.
+ consumer.stop()
diff --git a/tests/unit/test_client_info.py b/tests/unit/test_client_info.py
new file mode 100644
index 0000000..f5eebfb
--- /dev/null
+++ b/tests/unit/test_client_info.py
@@ -0,0 +1,98 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+try:
+ import grpc
+except ImportError:
+ grpc = None
+
+from google.api_core import client_info
+
+
+def test_constructor_defaults():
+ info = client_info.ClientInfo()
+
+ assert info.python_version is not None
+
+ if grpc is not None:
+ assert info.grpc_version is not None
+ else:
+ assert info.grpc_version is None
+
+ assert info.api_core_version is not None
+ assert info.gapic_version is None
+ assert info.client_library_version is None
+ assert info.rest_version is None
+
+
+def test_constructor_options():
+ info = client_info.ClientInfo(
+ python_version="1",
+ grpc_version="2",
+ api_core_version="3",
+ gapic_version="4",
+ client_library_version="5",
+ user_agent="6",
+ rest_version="7",
+ )
+
+ assert info.python_version == "1"
+ assert info.grpc_version == "2"
+ assert info.api_core_version == "3"
+ assert info.gapic_version == "4"
+ assert info.client_library_version == "5"
+ assert info.user_agent == "6"
+ assert info.rest_version == "7"
+
+
+def test_to_user_agent_minimal():
+ info = client_info.ClientInfo(
+ python_version="1", api_core_version="2", grpc_version=None
+ )
+
+ user_agent = info.to_user_agent()
+
+ assert user_agent == "gl-python/1 gax/2"
+
+
+def test_to_user_agent_full():
+ info = client_info.ClientInfo(
+ python_version="1",
+ grpc_version="2",
+ api_core_version="3",
+ gapic_version="4",
+ client_library_version="5",
+ user_agent="app-name/1.0",
+ )
+
+ user_agent = info.to_user_agent()
+
+ assert user_agent == "app-name/1.0 gl-python/1 grpc/2 gax/3 gapic/4 gccl/5"
+
+
+def test_to_user_agent_rest():
+ info = client_info.ClientInfo(
+ python_version="1",
+ grpc_version=None,
+ rest_version="2",
+ api_core_version="3",
+ gapic_version="4",
+ client_library_version="5",
+ user_agent="app-name/1.0",
+ )
+
+ user_agent = info.to_user_agent()
+
+ assert user_agent == "app-name/1.0 gl-python/1 rest/2 gax/3 gapic/4 gccl/5"
diff --git a/tests/unit/test_client_options.py b/tests/unit/test_client_options.py
new file mode 100644
index 0000000..38b9ad0
--- /dev/null
+++ b/tests/unit/test_client_options.py
@@ -0,0 +1,117 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from google.api_core import client_options
+
+
+def get_client_cert():
+ return b"cert", b"key"
+
+
+def get_client_encrypted_cert():
+ return "cert_path", "key_path", b"passphrase"
+
+
+def test_constructor():
+
+ options = client_options.ClientOptions(
+ api_endpoint="foo.googleapis.com",
+ client_cert_source=get_client_cert,
+ quota_project_id="quote-proj",
+ credentials_file="path/to/credentials.json",
+ scopes=[
+ "https://www.googleapis.com/auth/cloud-platform",
+ "https://www.googleapis.com/auth/cloud-platform.read-only",
+ ],
+ )
+
+ assert options.api_endpoint == "foo.googleapis.com"
+ assert options.client_cert_source() == (b"cert", b"key")
+ assert options.quota_project_id == "quote-proj"
+ assert options.credentials_file == "path/to/credentials.json"
+ assert options.scopes == [
+ "https://www.googleapis.com/auth/cloud-platform",
+ "https://www.googleapis.com/auth/cloud-platform.read-only",
+ ]
+
+
+def test_constructor_with_encrypted_cert_source():
+
+ options = client_options.ClientOptions(
+ api_endpoint="foo.googleapis.com",
+ client_encrypted_cert_source=get_client_encrypted_cert,
+ )
+
+ assert options.api_endpoint == "foo.googleapis.com"
+ assert options.client_encrypted_cert_source() == (
+ "cert_path",
+ "key_path",
+ b"passphrase",
+ )
+
+
+def test_constructor_with_both_cert_sources():
+ with pytest.raises(ValueError):
+ client_options.ClientOptions(
+ api_endpoint="foo.googleapis.com",
+ client_cert_source=get_client_cert,
+ client_encrypted_cert_source=get_client_encrypted_cert,
+ )
+
+
+def test_from_dict():
+ options = client_options.from_dict(
+ {
+ "api_endpoint": "foo.googleapis.com",
+ "client_cert_source": get_client_cert,
+ "quota_project_id": "quote-proj",
+ "credentials_file": "path/to/credentials.json",
+ "scopes": [
+ "https://www.googleapis.com/auth/cloud-platform",
+ "https://www.googleapis.com/auth/cloud-platform.read-only",
+ ],
+ }
+ )
+
+ assert options.api_endpoint == "foo.googleapis.com"
+ assert options.client_cert_source() == (b"cert", b"key")
+ assert options.quota_project_id == "quote-proj"
+ assert options.credentials_file == "path/to/credentials.json"
+ assert options.scopes == [
+ "https://www.googleapis.com/auth/cloud-platform",
+ "https://www.googleapis.com/auth/cloud-platform.read-only",
+ ]
+
+
+def test_from_dict_bad_argument():
+ with pytest.raises(ValueError):
+ client_options.from_dict(
+ {
+ "api_endpoint": "foo.googleapis.com",
+ "bad_arg": "1234",
+ "client_cert_source": get_client_cert,
+ }
+ )
+
+
+def test_repr():
+ options = client_options.ClientOptions(api_endpoint="foo.googleapis.com")
+
+ assert (
+ repr(options)
+ == "ClientOptions: {'api_endpoint': 'foo.googleapis.com', 'client_cert_source': None, 'client_encrypted_cert_source': None}"
+ or "ClientOptions: {'client_encrypted_cert_source': None, 'client_cert_source': None, 'api_endpoint': 'foo.googleapis.com'}"
+ )
diff --git a/tests/unit/test_datetime_helpers.py b/tests/unit/test_datetime_helpers.py
new file mode 100644
index 0000000..5f5470a
--- /dev/null
+++ b/tests/unit/test_datetime_helpers.py
@@ -0,0 +1,396 @@
+# Copyright 2017, Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import calendar
+import datetime
+
+import pytest
+
+from google.api_core import datetime_helpers
+from google.protobuf import timestamp_pb2
+
+
+ONE_MINUTE_IN_MICROSECONDS = 60 * 1e6
+
+
+def test_utcnow():
+ result = datetime_helpers.utcnow()
+ assert isinstance(result, datetime.datetime)
+
+
+def test_to_milliseconds():
+ dt = datetime.datetime(1970, 1, 1, 0, 0, 1, tzinfo=datetime.timezone.utc)
+ assert datetime_helpers.to_milliseconds(dt) == 1000
+
+
+def test_to_microseconds():
+ microseconds = 314159
+ dt = datetime.datetime(1970, 1, 1, 0, 0, 0, microsecond=microseconds)
+ assert datetime_helpers.to_microseconds(dt) == microseconds
+
+
+def test_to_microseconds_non_utc():
+ zone = datetime.timezone(datetime.timedelta(minutes=-1))
+ dt = datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=zone)
+ assert datetime_helpers.to_microseconds(dt) == ONE_MINUTE_IN_MICROSECONDS
+
+
+def test_to_microseconds_naive():
+ microseconds = 314159
+ dt = datetime.datetime(1970, 1, 1, 0, 0, 0, microsecond=microseconds, tzinfo=None)
+ assert datetime_helpers.to_microseconds(dt) == microseconds
+
+
+def test_from_microseconds():
+ five_mins_from_epoch_in_microseconds = 5 * ONE_MINUTE_IN_MICROSECONDS
+ five_mins_from_epoch_datetime = datetime.datetime(
+ 1970, 1, 1, 0, 5, 0, tzinfo=datetime.timezone.utc
+ )
+
+ result = datetime_helpers.from_microseconds(five_mins_from_epoch_in_microseconds)
+
+ assert result == five_mins_from_epoch_datetime
+
+
+def test_from_iso8601_date():
+ today = datetime.date.today()
+ iso_8601_today = today.strftime("%Y-%m-%d")
+
+ assert datetime_helpers.from_iso8601_date(iso_8601_today) == today
+
+
+def test_from_iso8601_time():
+ assert datetime_helpers.from_iso8601_time("12:09:42") == datetime.time(12, 9, 42)
+
+
+def test_from_rfc3339():
+ value = "2009-12-17T12:44:32.123456Z"
+ assert datetime_helpers.from_rfc3339(value) == datetime.datetime(
+ 2009, 12, 17, 12, 44, 32, 123456, datetime.timezone.utc
+ )
+
+
+def test_from_rfc3339_nanos():
+ value = "2009-12-17T12:44:32.123456Z"
+ assert datetime_helpers.from_rfc3339_nanos(value) == datetime.datetime(
+ 2009, 12, 17, 12, 44, 32, 123456, datetime.timezone.utc
+ )
+
+
+def test_from_rfc3339_without_nanos():
+ value = "2009-12-17T12:44:32Z"
+ assert datetime_helpers.from_rfc3339(value) == datetime.datetime(
+ 2009, 12, 17, 12, 44, 32, 0, datetime.timezone.utc
+ )
+
+
+def test_from_rfc3339_nanos_without_nanos():
+ value = "2009-12-17T12:44:32Z"
+ assert datetime_helpers.from_rfc3339_nanos(value) == datetime.datetime(
+ 2009, 12, 17, 12, 44, 32, 0, datetime.timezone.utc
+ )
+
+
+@pytest.mark.parametrize(
+ "truncated, micros",
+ [
+ ("12345678", 123456),
+ ("1234567", 123456),
+ ("123456", 123456),
+ ("12345", 123450),
+ ("1234", 123400),
+ ("123", 123000),
+ ("12", 120000),
+ ("1", 100000),
+ ],
+)
+def test_from_rfc3339_with_truncated_nanos(truncated, micros):
+ value = "2009-12-17T12:44:32.{}Z".format(truncated)
+ assert datetime_helpers.from_rfc3339(value) == datetime.datetime(
+ 2009, 12, 17, 12, 44, 32, micros, datetime.timezone.utc
+ )
+
+
+def test_from_rfc3339_nanos_is_deprecated():
+ value = "2009-12-17T12:44:32.123456Z"
+
+ result = datetime_helpers.from_rfc3339(value)
+ result_nanos = datetime_helpers.from_rfc3339_nanos(value)
+
+ assert result == result_nanos
+
+
+@pytest.mark.parametrize(
+ "truncated, micros",
+ [
+ ("12345678", 123456),
+ ("1234567", 123456),
+ ("123456", 123456),
+ ("12345", 123450),
+ ("1234", 123400),
+ ("123", 123000),
+ ("12", 120000),
+ ("1", 100000),
+ ],
+)
+def test_from_rfc3339_nanos_with_truncated_nanos(truncated, micros):
+ value = "2009-12-17T12:44:32.{}Z".format(truncated)
+ assert datetime_helpers.from_rfc3339_nanos(value) == datetime.datetime(
+ 2009, 12, 17, 12, 44, 32, micros, datetime.timezone.utc
+ )
+
+
+def test_from_rfc3339_wo_nanos_raise_exception():
+ value = "2009-12-17T12:44:32"
+ with pytest.raises(ValueError):
+ datetime_helpers.from_rfc3339(value)
+
+
+def test_from_rfc3339_w_nanos_raise_exception():
+ value = "2009-12-17T12:44:32.123456"
+ with pytest.raises(ValueError):
+ datetime_helpers.from_rfc3339(value)
+
+
+def test_to_rfc3339():
+ value = datetime.datetime(2016, 4, 5, 13, 30, 0)
+ expected = "2016-04-05T13:30:00.000000Z"
+ assert datetime_helpers.to_rfc3339(value) == expected
+
+
+def test_to_rfc3339_with_utc():
+ value = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=datetime.timezone.utc)
+ expected = "2016-04-05T13:30:00.000000Z"
+ assert datetime_helpers.to_rfc3339(value, ignore_zone=False) == expected
+
+
+def test_to_rfc3339_with_non_utc():
+ zone = datetime.timezone(datetime.timedelta(minutes=-60))
+ value = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=zone)
+ expected = "2016-04-05T14:30:00.000000Z"
+ assert datetime_helpers.to_rfc3339(value, ignore_zone=False) == expected
+
+
+def test_to_rfc3339_with_non_utc_ignore_zone():
+ zone = datetime.timezone(datetime.timedelta(minutes=-60))
+ value = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=zone)
+ expected = "2016-04-05T13:30:00.000000Z"
+ assert datetime_helpers.to_rfc3339(value, ignore_zone=True) == expected
+
+
+class Test_DateTimeWithNanos(object):
+ @staticmethod
+ def test_ctor_wo_nanos():
+ stamp = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, 123456
+ )
+ assert stamp.year == 2016
+ assert stamp.month == 12
+ assert stamp.day == 20
+ assert stamp.hour == 21
+ assert stamp.minute == 13
+ assert stamp.second == 47
+ assert stamp.microsecond == 123456
+ assert stamp.nanosecond == 0
+
+ @staticmethod
+ def test_ctor_w_nanos():
+ stamp = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, nanosecond=123456789
+ )
+ assert stamp.year == 2016
+ assert stamp.month == 12
+ assert stamp.day == 20
+ assert stamp.hour == 21
+ assert stamp.minute == 13
+ assert stamp.second == 47
+ assert stamp.microsecond == 123456
+ assert stamp.nanosecond == 123456789
+
+ @staticmethod
+ def test_ctor_w_micros_positional_and_nanos():
+ with pytest.raises(TypeError):
+ datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, 123456, nanosecond=123456789
+ )
+
+ @staticmethod
+ def test_ctor_w_micros_keyword_and_nanos():
+ with pytest.raises(TypeError):
+ datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, microsecond=123456, nanosecond=123456789
+ )
+
+ @staticmethod
+ def test_rfc3339_wo_nanos():
+ stamp = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, 123456
+ )
+ assert stamp.rfc3339() == "2016-12-20T21:13:47.123456Z"
+
+ @staticmethod
+ def test_rfc3339_wo_nanos_w_leading_zero():
+ stamp = datetime_helpers.DatetimeWithNanoseconds(2016, 12, 20, 21, 13, 47, 1234)
+ assert stamp.rfc3339() == "2016-12-20T21:13:47.001234Z"
+
+ @staticmethod
+ def test_rfc3339_w_nanos():
+ stamp = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, nanosecond=123456789
+ )
+ assert stamp.rfc3339() == "2016-12-20T21:13:47.123456789Z"
+
+ @staticmethod
+ def test_rfc3339_w_nanos_w_leading_zero():
+ stamp = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, nanosecond=1234567
+ )
+ assert stamp.rfc3339() == "2016-12-20T21:13:47.001234567Z"
+
+ @staticmethod
+ def test_rfc3339_w_nanos_no_trailing_zeroes():
+ stamp = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, nanosecond=100000000
+ )
+ assert stamp.rfc3339() == "2016-12-20T21:13:47.1Z"
+
+ @staticmethod
+ def test_rfc3339_w_nanos_w_leading_zero_and_no_trailing_zeros():
+ stamp = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, nanosecond=1234500
+ )
+ assert stamp.rfc3339() == "2016-12-20T21:13:47.0012345Z"
+
+ @staticmethod
+ def test_from_rfc3339_w_invalid():
+ stamp = "2016-12-20T21:13:47"
+ with pytest.raises(ValueError):
+ datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(stamp)
+
+ @staticmethod
+ def test_from_rfc3339_wo_fraction():
+ timestamp = "2016-12-20T21:13:47Z"
+ expected = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, tzinfo=datetime.timezone.utc
+ )
+ stamp = datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(timestamp)
+ assert stamp == expected
+
+ @staticmethod
+ def test_from_rfc3339_w_partial_precision():
+ timestamp = "2016-12-20T21:13:47.1Z"
+ expected = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, microsecond=100000, tzinfo=datetime.timezone.utc
+ )
+ stamp = datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(timestamp)
+ assert stamp == expected
+
+ @staticmethod
+ def test_from_rfc3339_w_full_precision():
+ timestamp = "2016-12-20T21:13:47.123456789Z"
+ expected = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=datetime.timezone.utc
+ )
+ stamp = datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(timestamp)
+ assert stamp == expected
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "fractional, nanos",
+ [
+ ("12345678", 123456780),
+ ("1234567", 123456700),
+ ("123456", 123456000),
+ ("12345", 123450000),
+ ("1234", 123400000),
+ ("123", 123000000),
+ ("12", 120000000),
+ ("1", 100000000),
+ ],
+ )
+ def test_from_rfc3339_test_nanoseconds(fractional, nanos):
+ value = "2009-12-17T12:44:32.{}Z".format(fractional)
+ assert (
+ datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(value).nanosecond
+ == nanos
+ )
+
+ @staticmethod
+ def test_timestamp_pb_wo_nanos_naive():
+ stamp = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, 123456
+ )
+ delta = (
+ stamp.replace(tzinfo=datetime.timezone.utc) - datetime_helpers._UTC_EPOCH
+ )
+ seconds = int(delta.total_seconds())
+ nanos = 123456000
+ timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
+ assert stamp.timestamp_pb() == timestamp
+
+ @staticmethod
+ def test_timestamp_pb_w_nanos():
+ stamp = datetime_helpers.DatetimeWithNanoseconds(
+ 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=datetime.timezone.utc
+ )
+ delta = stamp - datetime_helpers._UTC_EPOCH
+ timestamp = timestamp_pb2.Timestamp(
+ seconds=int(delta.total_seconds()), nanos=123456789
+ )
+ assert stamp.timestamp_pb() == timestamp
+
+ @staticmethod
+ def test_from_timestamp_pb_wo_nanos():
+ when = datetime.datetime(
+ 2016, 12, 20, 21, 13, 47, 123456, tzinfo=datetime.timezone.utc
+ )
+ delta = when - datetime_helpers._UTC_EPOCH
+ seconds = int(delta.total_seconds())
+ timestamp = timestamp_pb2.Timestamp(seconds=seconds)
+
+ stamp = datetime_helpers.DatetimeWithNanoseconds.from_timestamp_pb(timestamp)
+
+ assert _to_seconds(when) == _to_seconds(stamp)
+ assert stamp.microsecond == 0
+ assert stamp.nanosecond == 0
+ assert stamp.tzinfo == datetime.timezone.utc
+
+ @staticmethod
+ def test_from_timestamp_pb_w_nanos():
+ when = datetime.datetime(
+ 2016, 12, 20, 21, 13, 47, 123456, tzinfo=datetime.timezone.utc
+ )
+ delta = when - datetime_helpers._UTC_EPOCH
+ seconds = int(delta.total_seconds())
+ timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=123456789)
+
+ stamp = datetime_helpers.DatetimeWithNanoseconds.from_timestamp_pb(timestamp)
+
+ assert _to_seconds(when) == _to_seconds(stamp)
+ assert stamp.microsecond == 123456
+ assert stamp.nanosecond == 123456789
+ assert stamp.tzinfo == datetime.timezone.utc
+
+
+def _to_seconds(value):
+ """Convert a datetime to seconds since the unix epoch.
+
+ Args:
+ value (datetime.datetime): The datetime to covert.
+
+ Returns:
+ int: Microseconds since the unix epoch.
+ """
+ assert value.tzinfo is datetime.timezone.utc
+ return calendar.timegm(value.timetuple())
diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py
new file mode 100644
index 0000000..622f58a
--- /dev/null
+++ b/tests/unit/test_exceptions.py
@@ -0,0 +1,353 @@
+# Copyright 2014 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import http.client
+import json
+
+import mock
+import pytest
+import requests
+
+try:
+ import grpc
+ from grpc_status import rpc_status
+except ImportError:
+ grpc = rpc_status = None
+
+from google.api_core import exceptions
+from google.protobuf import any_pb2, json_format
+from google.rpc import error_details_pb2, status_pb2
+
+
+def test_create_google_cloud_error():
+ exception = exceptions.GoogleAPICallError("Testing")
+ exception.code = 600
+ assert str(exception) == "600 Testing"
+ assert exception.message == "Testing"
+ assert exception.errors == []
+ assert exception.response is None
+
+
+def test_create_google_cloud_error_with_args():
+ error = {
+ "code": 600,
+ "message": "Testing",
+ }
+ response = mock.sentinel.response
+ exception = exceptions.GoogleAPICallError("Testing", [error], response=response)
+ exception.code = 600
+ assert str(exception) == "600 Testing"
+ assert exception.message == "Testing"
+ assert exception.errors == [error]
+ assert exception.response == response
+
+
+def test_from_http_status():
+ message = "message"
+ exception = exceptions.from_http_status(http.client.NOT_FOUND, message)
+ assert exception.code == http.client.NOT_FOUND
+ assert exception.message == message
+ assert exception.errors == []
+
+
+def test_from_http_status_with_errors_and_response():
+ message = "message"
+ errors = ["1", "2"]
+ response = mock.sentinel.response
+ exception = exceptions.from_http_status(
+ http.client.NOT_FOUND, message, errors=errors, response=response
+ )
+
+ assert isinstance(exception, exceptions.NotFound)
+ assert exception.code == http.client.NOT_FOUND
+ assert exception.message == message
+ assert exception.errors == errors
+ assert exception.response == response
+
+
+def test_from_http_status_unknown_code():
+ message = "message"
+ status_code = 156
+ exception = exceptions.from_http_status(status_code, message)
+ assert exception.code == status_code
+ assert exception.message == message
+
+
+def make_response(content):
+ response = requests.Response()
+ response._content = content
+ response.status_code = http.client.NOT_FOUND
+ response.request = requests.Request(
+ method="POST", url="https://example.com"
+ ).prepare()
+ return response
+
+
+def test_from_http_response_no_content():
+ response = make_response(None)
+
+ exception = exceptions.from_http_response(response)
+
+ assert isinstance(exception, exceptions.NotFound)
+ assert exception.code == http.client.NOT_FOUND
+ assert exception.message == "POST https://example.com/: unknown error"
+ assert exception.response == response
+
+
+def test_from_http_response_text_content():
+ response = make_response(b"message")
+ response.encoding = "UTF8" # suppress charset_normalizer warning
+
+ exception = exceptions.from_http_response(response)
+
+ assert isinstance(exception, exceptions.NotFound)
+ assert exception.code == http.client.NOT_FOUND
+ assert exception.message == "POST https://example.com/: message"
+
+
+def test_from_http_response_json_content():
+ response = make_response(
+ json.dumps({"error": {"message": "json message", "errors": ["1", "2"]}}).encode(
+ "utf-8"
+ )
+ )
+
+ exception = exceptions.from_http_response(response)
+
+ assert isinstance(exception, exceptions.NotFound)
+ assert exception.code == http.client.NOT_FOUND
+ assert exception.message == "POST https://example.com/: json message"
+ assert exception.errors == ["1", "2"]
+
+
+def test_from_http_response_bad_json_content():
+ response = make_response(json.dumps({"meep": "moop"}).encode("utf-8"))
+
+ exception = exceptions.from_http_response(response)
+
+ assert isinstance(exception, exceptions.NotFound)
+ assert exception.code == http.client.NOT_FOUND
+ assert exception.message == "POST https://example.com/: unknown error"
+
+
+def test_from_http_response_json_unicode_content():
+ response = make_response(
+ json.dumps(
+ {"error": {"message": "\u2019 message", "errors": ["1", "2"]}}
+ ).encode("utf-8")
+ )
+
+ exception = exceptions.from_http_response(response)
+
+ assert isinstance(exception, exceptions.NotFound)
+ assert exception.code == http.client.NOT_FOUND
+ assert exception.message == "POST https://example.com/: \u2019 message"
+ assert exception.errors == ["1", "2"]
+
+
+@pytest.mark.skipif(grpc is None, reason="No grpc")
+def test_from_grpc_status():
+ message = "message"
+ exception = exceptions.from_grpc_status(grpc.StatusCode.OUT_OF_RANGE, message)
+ assert isinstance(exception, exceptions.BadRequest)
+ assert isinstance(exception, exceptions.OutOfRange)
+ assert exception.code == http.client.BAD_REQUEST
+ assert exception.grpc_status_code == grpc.StatusCode.OUT_OF_RANGE
+ assert exception.message == message
+ assert exception.errors == []
+
+
+@pytest.mark.skipif(grpc is None, reason="No grpc")
+def test_from_grpc_status_as_int():
+ message = "message"
+ exception = exceptions.from_grpc_status(11, message)
+ assert isinstance(exception, exceptions.BadRequest)
+ assert isinstance(exception, exceptions.OutOfRange)
+ assert exception.code == http.client.BAD_REQUEST
+ assert exception.grpc_status_code == grpc.StatusCode.OUT_OF_RANGE
+ assert exception.message == message
+ assert exception.errors == []
+
+
+@pytest.mark.skipif(grpc is None, reason="No grpc")
+def test_from_grpc_status_with_errors_and_response():
+ message = "message"
+ response = mock.sentinel.response
+ errors = ["1", "2"]
+ exception = exceptions.from_grpc_status(
+ grpc.StatusCode.OUT_OF_RANGE, message, errors=errors, response=response
+ )
+
+ assert isinstance(exception, exceptions.OutOfRange)
+ assert exception.message == message
+ assert exception.errors == errors
+ assert exception.response == response
+
+
+@pytest.mark.skipif(grpc is None, reason="No grpc")
+def test_from_grpc_status_unknown_code():
+ message = "message"
+ exception = exceptions.from_grpc_status(grpc.StatusCode.OK, message)
+ assert exception.grpc_status_code == grpc.StatusCode.OK
+ assert exception.message == message
+
+
+@pytest.mark.skipif(grpc is None, reason="No grpc")
+def test_from_grpc_error():
+ message = "message"
+ error = mock.create_autospec(grpc.Call, instance=True)
+ error.code.return_value = grpc.StatusCode.INVALID_ARGUMENT
+ error.details.return_value = message
+
+ exception = exceptions.from_grpc_error(error)
+
+ assert isinstance(exception, exceptions.BadRequest)
+ assert isinstance(exception, exceptions.InvalidArgument)
+ assert exception.code == http.client.BAD_REQUEST
+ assert exception.grpc_status_code == grpc.StatusCode.INVALID_ARGUMENT
+ assert exception.message == message
+ assert exception.errors == [error]
+ assert exception.response == error
+
+
+@pytest.mark.skipif(grpc is None, reason="No grpc")
+def test_from_grpc_error_non_call():
+ message = "message"
+ error = mock.create_autospec(grpc.RpcError, instance=True)
+ error.__str__.return_value = message
+
+ exception = exceptions.from_grpc_error(error)
+
+ assert isinstance(exception, exceptions.GoogleAPICallError)
+ assert exception.code is None
+ assert exception.grpc_status_code is None
+ assert exception.message == message
+ assert exception.errors == [error]
+ assert exception.response == error
+
+
+@pytest.mark.skipif(grpc is None, reason="No grpc")
+def test_from_grpc_error_bare_call():
+ message = "Testing"
+
+ class TestingError(grpc.Call, grpc.RpcError):
+ def __init__(self, exception):
+ self.exception = exception
+
+ def code(self):
+ return self.exception.grpc_status_code
+
+ def details(self):
+ return message
+
+ nested_message = "message"
+ error = TestingError(exceptions.GoogleAPICallError(nested_message))
+
+ exception = exceptions.from_grpc_error(error)
+
+ assert isinstance(exception, exceptions.GoogleAPICallError)
+ assert exception.code is None
+ assert exception.grpc_status_code is None
+ assert exception.message == message
+ assert exception.errors == [error]
+ assert exception.response == error
+ assert exception.details == []
+
+
+def create_bad_request_details():
+ bad_request_details = error_details_pb2.BadRequest()
+ field_violation = bad_request_details.field_violations.add()
+ field_violation.field = "document.content"
+ field_violation.description = "Must have some text content to annotate."
+ status_detail = any_pb2.Any()
+ status_detail.Pack(bad_request_details)
+ return status_detail
+
+
+def test_error_details_from_rest_response():
+ bad_request_detail = create_bad_request_details()
+ status = status_pb2.Status()
+ status.code = 3
+ status.message = (
+ "3 INVALID_ARGUMENT: One of content, or gcs_content_uri must be set."
+ )
+ status.details.append(bad_request_detail)
+
+ # See JSON schema in https://cloud.google.com/apis/design/errors#http_mapping
+ http_response = make_response(
+ json.dumps({"error": json.loads(json_format.MessageToJson(status))}).encode(
+ "utf-8"
+ )
+ )
+ exception = exceptions.from_http_response(http_response)
+ want_error_details = [json.loads(json_format.MessageToJson(bad_request_detail))]
+ assert want_error_details == exception.details
+ # 404 POST comes from make_response.
+ assert str(exception) == (
+ "404 POST https://example.com/: 3 INVALID_ARGUMENT:"
+ " One of content, or gcs_content_uri must be set."
+ " [{'@type': 'type.googleapis.com/google.rpc.BadRequest',"
+ " 'fieldViolations': [{'field': 'document.content',"
+ " 'description': 'Must have some text content to annotate.'}]}]"
+ )
+
+
+def test_error_details_from_v1_rest_response():
+ response = make_response(
+ json.dumps(
+ {"error": {"message": "\u2019 message", "errors": ["1", "2"]}}
+ ).encode("utf-8")
+ )
+ exception = exceptions.from_http_response(response)
+ assert exception.details == []
+
+
+@pytest.mark.skipif(grpc is None, reason="gRPC not importable")
+def test_error_details_from_grpc_response():
+ status = rpc_status.status_pb2.Status()
+ status.code = 3
+ status.message = (
+ "3 INVALID_ARGUMENT: One of content, or gcs_content_uri must be set."
+ )
+ status_detail = create_bad_request_details()
+ status.details.append(status_detail)
+
+ # Actualy error doesn't matter as long as its grpc.Call,
+ # because from_call is mocked.
+ error = mock.create_autospec(grpc.Call, instance=True)
+ with mock.patch("grpc_status.rpc_status.from_call") as m:
+ m.return_value = status
+ exception = exceptions.from_grpc_error(error)
+
+ bad_request_detail = error_details_pb2.BadRequest()
+ status_detail.Unpack(bad_request_detail)
+ assert exception.details == [bad_request_detail]
+
+
+@pytest.mark.skipif(grpc is None, reason="gRPC not importable")
+def test_error_details_from_grpc_response_unknown_error():
+ status_detail = any_pb2.Any()
+
+ status = rpc_status.status_pb2.Status()
+ status.code = 3
+ status.message = (
+ "3 INVALID_ARGUMENT: One of content, or gcs_content_uri must be set."
+ )
+ status.details.append(status_detail)
+
+ error = mock.create_autospec(grpc.Call, instance=True)
+ with mock.patch("grpc_status.rpc_status.from_call") as m:
+ m.return_value = status
+ exception = exceptions.from_grpc_error(error)
+ assert exception.details == [status_detail]
diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py
new file mode 100644
index 0000000..ca969e4
--- /dev/null
+++ b/tests/unit/test_grpc_helpers.py
@@ -0,0 +1,860 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import mock
+import pytest
+
+try:
+ import grpc
+except ImportError:
+ pytest.skip("No GRPC", allow_module_level=True)
+
+from google.api_core import exceptions
+from google.api_core import grpc_helpers
+import google.auth.credentials
+from google.longrunning import operations_pb2
+
+
+def test__patch_callable_name():
+ callable = mock.Mock(spec=["__class__"])
+ callable.__class__ = mock.Mock(spec=["__name__"])
+ callable.__class__.__name__ = "TestCallable"
+
+ grpc_helpers._patch_callable_name(callable)
+
+ assert callable.__name__ == "TestCallable"
+
+
+def test__patch_callable_name_no_op():
+ callable = mock.Mock(spec=["__name__"])
+ callable.__name__ = "test_callable"
+
+ grpc_helpers._patch_callable_name(callable)
+
+ assert callable.__name__ == "test_callable"
+
+
+class RpcErrorImpl(grpc.RpcError, grpc.Call):
+ def __init__(self, code):
+ super(RpcErrorImpl, self).__init__()
+ self._code = code
+
+ def code(self):
+ return self._code
+
+ def details(self):
+ return None
+
+ def trailing_metadata(self):
+ return None
+
+
+def test_wrap_unary_errors():
+ grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT)
+ callable_ = mock.Mock(spec=["__call__"], side_effect=grpc_error)
+
+ wrapped_callable = grpc_helpers._wrap_unary_errors(callable_)
+
+ with pytest.raises(exceptions.InvalidArgument) as exc_info:
+ wrapped_callable(1, 2, three="four")
+
+ callable_.assert_called_once_with(1, 2, three="four")
+ assert exc_info.value.response == grpc_error
+
+
+class Test_StreamingResponseIterator:
+ @staticmethod
+ def _make_wrapped(*items):
+ return iter(items)
+
+ @staticmethod
+ def _make_one(wrapped, **kw):
+ return grpc_helpers._StreamingResponseIterator(wrapped, **kw)
+
+ def test_ctor_defaults(self):
+ wrapped = self._make_wrapped("a", "b", "c")
+ iterator = self._make_one(wrapped)
+ assert iterator._stored_first_result == "a"
+ assert list(wrapped) == ["b", "c"]
+
+ def test_ctor_explicit(self):
+ wrapped = self._make_wrapped("a", "b", "c")
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+ assert getattr(iterator, "_stored_first_result", self) is self
+ assert list(wrapped) == ["a", "b", "c"]
+
+ def test_ctor_w_rpc_error_on_prefetch(self):
+ wrapped = mock.MagicMock()
+ wrapped.__next__.side_effect = grpc.RpcError()
+
+ with pytest.raises(grpc.RpcError):
+ self._make_one(wrapped)
+
+ def test___iter__(self):
+ wrapped = self._make_wrapped("a", "b", "c")
+ iterator = self._make_one(wrapped)
+ assert iter(iterator) is iterator
+
+ def test___next___w_cached_first_result(self):
+ wrapped = self._make_wrapped("a", "b", "c")
+ iterator = self._make_one(wrapped)
+ assert next(iterator) == "a"
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+ assert next(iterator) == "b"
+ assert next(iterator) == "c"
+
+ def test___next___wo_cached_first_result(self):
+ wrapped = self._make_wrapped("a", "b", "c")
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+ assert next(iterator) == "a"
+ assert next(iterator) == "b"
+ assert next(iterator) == "c"
+
+ def test___next___w_rpc_error(self):
+ wrapped = mock.MagicMock()
+ wrapped.__next__.side_effect = grpc.RpcError()
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+
+ with pytest.raises(exceptions.GoogleAPICallError):
+ next(iterator)
+
+ def test_add_callback(self):
+ wrapped = mock.MagicMock()
+ callback = mock.Mock(spec={})
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+
+ assert iterator.add_callback(callback) is wrapped.add_callback.return_value
+
+ wrapped.add_callback.assert_called_once_with(callback)
+
+ def test_cancel(self):
+ wrapped = mock.MagicMock()
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+
+ assert iterator.cancel() is wrapped.cancel.return_value
+
+ wrapped.cancel.assert_called_once_with()
+
+ def test_code(self):
+ wrapped = mock.MagicMock()
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+
+ assert iterator.code() is wrapped.code.return_value
+
+ wrapped.code.assert_called_once_with()
+
+ def test_details(self):
+ wrapped = mock.MagicMock()
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+
+ assert iterator.details() is wrapped.details.return_value
+
+ wrapped.details.assert_called_once_with()
+
+ def test_initial_metadata(self):
+ wrapped = mock.MagicMock()
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+
+ assert iterator.initial_metadata() is wrapped.initial_metadata.return_value
+
+ wrapped.initial_metadata.assert_called_once_with()
+
+ def test_is_active(self):
+ wrapped = mock.MagicMock()
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+
+ assert iterator.is_active() is wrapped.is_active.return_value
+
+ wrapped.is_active.assert_called_once_with()
+
+ def test_time_remaining(self):
+ wrapped = mock.MagicMock()
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+
+ assert iterator.time_remaining() is wrapped.time_remaining.return_value
+
+ wrapped.time_remaining.assert_called_once_with()
+
+ def test_trailing_metadata(self):
+ wrapped = mock.MagicMock()
+ iterator = self._make_one(wrapped, prefetch_first_result=False)
+
+ assert iterator.trailing_metadata() is wrapped.trailing_metadata.return_value
+
+ wrapped.trailing_metadata.assert_called_once_with()
+
+
+def test_wrap_stream_okay():
+ expected_responses = [1, 2, 3]
+ callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses))
+
+ wrapped_callable = grpc_helpers._wrap_stream_errors(callable_)
+
+ got_iterator = wrapped_callable(1, 2, three="four")
+
+ responses = list(got_iterator)
+
+ callable_.assert_called_once_with(1, 2, three="four")
+ assert responses == expected_responses
+
+
+def test_wrap_stream_prefetch_disabled():
+ responses = [1, 2, 3]
+ iter_responses = iter(responses)
+ callable_ = mock.Mock(spec=["__call__"], return_value=iter_responses)
+ callable_._prefetch_first_result_ = False
+
+ wrapped_callable = grpc_helpers._wrap_stream_errors(callable_)
+ wrapped_callable(1, 2, three="four")
+
+ assert list(iter_responses) == responses # no items should have been pre-fetched
+ callable_.assert_called_once_with(1, 2, three="four")
+
+
+def test_wrap_stream_iterable_iterface():
+ response_iter = mock.create_autospec(grpc.Call, instance=True)
+ callable_ = mock.Mock(spec=["__call__"], return_value=response_iter)
+
+ wrapped_callable = grpc_helpers._wrap_stream_errors(callable_)
+
+ got_iterator = wrapped_callable()
+
+ callable_.assert_called_once_with()
+
+ # Check each aliased method in the grpc.Call interface
+ got_iterator.add_callback(mock.sentinel.callback)
+ response_iter.add_callback.assert_called_once_with(mock.sentinel.callback)
+
+ got_iterator.cancel()
+ response_iter.cancel.assert_called_once_with()
+
+ got_iterator.code()
+ response_iter.code.assert_called_once_with()
+
+ got_iterator.details()
+ response_iter.details.assert_called_once_with()
+
+ got_iterator.initial_metadata()
+ response_iter.initial_metadata.assert_called_once_with()
+
+ got_iterator.is_active()
+ response_iter.is_active.assert_called_once_with()
+
+ got_iterator.time_remaining()
+ response_iter.time_remaining.assert_called_once_with()
+
+ got_iterator.trailing_metadata()
+ response_iter.trailing_metadata.assert_called_once_with()
+
+
+def test_wrap_stream_errors_invocation():
+ grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT)
+ callable_ = mock.Mock(spec=["__call__"], side_effect=grpc_error)
+
+ wrapped_callable = grpc_helpers._wrap_stream_errors(callable_)
+
+ with pytest.raises(exceptions.InvalidArgument) as exc_info:
+ wrapped_callable(1, 2, three="four")
+
+ callable_.assert_called_once_with(1, 2, three="four")
+ assert exc_info.value.response == grpc_error
+
+
+def test_wrap_stream_empty_iterator():
+ expected_responses = []
+ callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses))
+
+ wrapped_callable = grpc_helpers._wrap_stream_errors(callable_)
+
+ got_iterator = wrapped_callable()
+
+ responses = list(got_iterator)
+
+ callable_.assert_called_once_with()
+ assert responses == expected_responses
+
+
+class RpcResponseIteratorImpl(object):
+ def __init__(self, iterable):
+ self._iterable = iter(iterable)
+
+ def next(self):
+ next_item = next(self._iterable)
+ if isinstance(next_item, RpcErrorImpl):
+ raise next_item
+ return next_item
+
+ __next__ = next
+
+
+def test_wrap_stream_errors_iterator_initialization():
+ grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE)
+ response_iter = RpcResponseIteratorImpl([grpc_error])
+ callable_ = mock.Mock(spec=["__call__"], return_value=response_iter)
+
+ wrapped_callable = grpc_helpers._wrap_stream_errors(callable_)
+
+ with pytest.raises(exceptions.ServiceUnavailable) as exc_info:
+ wrapped_callable(1, 2, three="four")
+
+ callable_.assert_called_once_with(1, 2, three="four")
+ assert exc_info.value.response == grpc_error
+
+
+def test_wrap_stream_errors_during_iteration():
+ grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE)
+ response_iter = RpcResponseIteratorImpl([1, grpc_error])
+ callable_ = mock.Mock(spec=["__call__"], return_value=response_iter)
+
+ wrapped_callable = grpc_helpers._wrap_stream_errors(callable_)
+ got_iterator = wrapped_callable(1, 2, three="four")
+ next(got_iterator)
+
+ with pytest.raises(exceptions.ServiceUnavailable) as exc_info:
+ next(got_iterator)
+
+ callable_.assert_called_once_with(1, 2, three="four")
+ assert exc_info.value.response == grpc_error
+
+
+@mock.patch("google.api_core.grpc_helpers._wrap_unary_errors")
+def test_wrap_errors_non_streaming(wrap_unary_errors):
+ callable_ = mock.create_autospec(grpc.UnaryUnaryMultiCallable)
+
+ result = grpc_helpers.wrap_errors(callable_)
+
+ assert result == wrap_unary_errors.return_value
+ wrap_unary_errors.assert_called_once_with(callable_)
+
+
+@mock.patch("google.api_core.grpc_helpers._wrap_stream_errors")
+def test_wrap_errors_streaming(wrap_stream_errors):
+ callable_ = mock.create_autospec(grpc.UnaryStreamMultiCallable)
+
+ result = grpc_helpers.wrap_errors(callable_)
+
+ assert result == wrap_stream_errors.return_value
+ wrap_stream_errors.assert_called_once_with(callable_)
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch(
+ "google.auth.default",
+ autospec=True,
+ return_value=(mock.sentinel.credentials, mock.sentinel.projet),
+)
+@mock.patch("grpc.secure_channel")
+def test_create_channel_implicit(grpc_secure_channel, default, composite_creds_call):
+ target = "example.com:443"
+ composite_creds = composite_creds_call.return_value
+
+ channel = grpc_helpers.create_channel(target)
+
+ assert channel is grpc_secure_channel.return_value
+
+ default.assert_called_once_with(scopes=None, default_scopes=None)
+
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@mock.patch("google.auth.transport.grpc.AuthMetadataPlugin", autospec=True)
+@mock.patch(
+ "google.auth.transport.requests.Request",
+ autospec=True,
+ return_value=mock.sentinel.Request,
+)
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch(
+ "google.auth.default",
+ autospec=True,
+ return_value=(mock.sentinel.credentials, mock.sentinel.project),
+)
+@mock.patch("grpc.secure_channel")
+def test_create_channel_implicit_with_default_host(
+ grpc_secure_channel, default, composite_creds_call, request, auth_metadata_plugin
+):
+ target = "example.com:443"
+ default_host = "example.com"
+ composite_creds = composite_creds_call.return_value
+
+ channel = grpc_helpers.create_channel(target, default_host=default_host)
+
+ assert channel is grpc_secure_channel.return_value
+
+ default.assert_called_once_with(scopes=None, default_scopes=None)
+ auth_metadata_plugin.assert_called_once_with(
+ mock.sentinel.credentials, mock.sentinel.Request, default_host=default_host
+ )
+
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch(
+ "google.auth.default",
+ autospec=True,
+ return_value=(mock.sentinel.credentials, mock.sentinel.projet),
+)
+@mock.patch("grpc.secure_channel")
+def test_create_channel_implicit_with_ssl_creds(
+ grpc_secure_channel, default, composite_creds_call
+):
+ target = "example.com:443"
+
+ ssl_creds = grpc.ssl_channel_credentials()
+
+ grpc_helpers.create_channel(target, ssl_credentials=ssl_creds)
+
+ default.assert_called_once_with(scopes=None, default_scopes=None)
+
+ composite_creds_call.assert_called_once_with(ssl_creds, mock.ANY)
+ composite_creds = composite_creds_call.return_value
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch(
+ "google.auth.default",
+ autospec=True,
+ return_value=(mock.sentinel.credentials, mock.sentinel.projet),
+)
+@mock.patch("grpc.secure_channel")
+def test_create_channel_implicit_with_scopes(
+ grpc_secure_channel, default, composite_creds_call
+):
+ target = "example.com:443"
+ composite_creds = composite_creds_call.return_value
+
+ channel = grpc_helpers.create_channel(target, scopes=["one", "two"])
+
+ assert channel is grpc_secure_channel.return_value
+
+ default.assert_called_once_with(scopes=["one", "two"], default_scopes=None)
+
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch(
+ "google.auth.default",
+ autospec=True,
+ return_value=(mock.sentinel.credentials, mock.sentinel.projet),
+)
+@mock.patch("grpc.secure_channel")
+def test_create_channel_implicit_with_default_scopes(
+ grpc_secure_channel, default, composite_creds_call
+):
+ target = "example.com:443"
+ composite_creds = composite_creds_call.return_value
+
+ channel = grpc_helpers.create_channel(target, default_scopes=["three", "four"])
+
+ assert channel is grpc_secure_channel.return_value
+
+ default.assert_called_once_with(scopes=None, default_scopes=["three", "four"])
+
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+def test_create_channel_explicit_with_duplicate_credentials():
+ target = "example.com:443"
+
+ with pytest.raises(exceptions.DuplicateCredentialArgs):
+ grpc_helpers.create_channel(
+ target,
+ credentials_file="credentials.json",
+ credentials=mock.sentinel.credentials,
+ )
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch("google.auth.credentials.with_scopes_if_required", autospec=True)
+@mock.patch("grpc.secure_channel")
+def test_create_channel_explicit(grpc_secure_channel, auth_creds, composite_creds_call):
+ target = "example.com:443"
+ composite_creds = composite_creds_call.return_value
+
+ channel = grpc_helpers.create_channel(target, credentials=mock.sentinel.credentials)
+
+ auth_creds.assert_called_once_with(
+ mock.sentinel.credentials, scopes=None, default_scopes=None
+ )
+
+ assert channel is grpc_secure_channel.return_value
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch("grpc.secure_channel")
+def test_create_channel_explicit_scoped(grpc_secure_channel, composite_creds_call):
+ target = "example.com:443"
+ scopes = ["1", "2"]
+ composite_creds = composite_creds_call.return_value
+
+ credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True)
+ credentials.requires_scopes = True
+
+ channel = grpc_helpers.create_channel(
+ target, credentials=credentials, scopes=scopes
+ )
+
+ credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None)
+
+ assert channel is grpc_secure_channel.return_value
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch("grpc.secure_channel")
+def test_create_channel_explicit_default_scopes(
+ grpc_secure_channel, composite_creds_call
+):
+ target = "example.com:443"
+ default_scopes = ["3", "4"]
+ composite_creds = composite_creds_call.return_value
+
+ credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True)
+ credentials.requires_scopes = True
+
+ channel = grpc_helpers.create_channel(
+ target, credentials=credentials, default_scopes=default_scopes
+ )
+
+ credentials.with_scopes.assert_called_once_with(
+ scopes=None, default_scopes=default_scopes
+ )
+
+ assert channel is grpc_secure_channel.return_value
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch("grpc.secure_channel")
+def test_create_channel_explicit_with_quota_project(
+ grpc_secure_channel, composite_creds_call
+):
+ target = "example.com:443"
+ composite_creds = composite_creds_call.return_value
+
+ credentials = mock.create_autospec(
+ google.auth.credentials.CredentialsWithQuotaProject, instance=True
+ )
+
+ channel = grpc_helpers.create_channel(
+ target, credentials=credentials, quota_project_id="project-foo"
+ )
+
+ credentials.with_quota_project.assert_called_once_with("project-foo")
+
+ assert channel is grpc_secure_channel.return_value
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch("grpc.secure_channel")
+@mock.patch(
+ "google.auth.load_credentials_from_file",
+ autospec=True,
+ return_value=(mock.sentinel.credentials, mock.sentinel.project),
+)
+def test_create_channel_with_credentials_file(
+ load_credentials_from_file, grpc_secure_channel, composite_creds_call
+):
+ target = "example.com:443"
+
+ credentials_file = "/path/to/credentials/file.json"
+ composite_creds = composite_creds_call.return_value
+
+ channel = grpc_helpers.create_channel(target, credentials_file=credentials_file)
+
+ google.auth.load_credentials_from_file.assert_called_once_with(
+ credentials_file, scopes=None, default_scopes=None
+ )
+
+ assert channel is grpc_secure_channel.return_value
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch("grpc.secure_channel")
+@mock.patch(
+ "google.auth.load_credentials_from_file",
+ autospec=True,
+ return_value=(mock.sentinel.credentials, mock.sentinel.project),
+)
+def test_create_channel_with_credentials_file_and_scopes(
+ load_credentials_from_file, grpc_secure_channel, composite_creds_call
+):
+ target = "example.com:443"
+ scopes = ["1", "2"]
+
+ credentials_file = "/path/to/credentials/file.json"
+ composite_creds = composite_creds_call.return_value
+
+ channel = grpc_helpers.create_channel(
+ target, credentials_file=credentials_file, scopes=scopes
+ )
+
+ google.auth.load_credentials_from_file.assert_called_once_with(
+ credentials_file, scopes=scopes, default_scopes=None
+ )
+
+ assert channel is grpc_secure_channel.return_value
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@mock.patch("grpc.composite_channel_credentials")
+@mock.patch("grpc.secure_channel")
+@mock.patch(
+ "google.auth.load_credentials_from_file",
+ autospec=True,
+ return_value=(mock.sentinel.credentials, mock.sentinel.project),
+)
+def test_create_channel_with_credentials_file_and_default_scopes(
+ load_credentials_from_file, grpc_secure_channel, composite_creds_call
+):
+ target = "example.com:443"
+ default_scopes = ["3", "4"]
+
+ credentials_file = "/path/to/credentials/file.json"
+ composite_creds = composite_creds_call.return_value
+
+ channel = grpc_helpers.create_channel(
+ target, credentials_file=credentials_file, default_scopes=default_scopes
+ )
+
+ load_credentials_from_file.assert_called_once_with(
+ credentials_file, scopes=None, default_scopes=default_scopes
+ )
+
+ assert channel is grpc_secure_channel.return_value
+ if grpc_helpers.HAS_GRPC_GCP:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds, None)
+ else:
+ grpc_secure_channel.assert_called_once_with(target, composite_creds)
+
+
+@pytest.mark.skipif(
+ not grpc_helpers.HAS_GRPC_GCP, reason="grpc_gcp module not available"
+)
+@mock.patch("grpc_gcp.secure_channel")
+def test_create_channel_with_grpc_gcp(grpc_gcp_secure_channel):
+ target = "example.com:443"
+ scopes = ["test_scope"]
+
+ credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True)
+ credentials.requires_scopes = True
+
+ grpc_helpers.create_channel(target, credentials=credentials, scopes=scopes)
+ grpc_gcp_secure_channel.assert_called()
+
+ credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None)
+
+
+@pytest.mark.skipif(grpc_helpers.HAS_GRPC_GCP, reason="grpc_gcp module not available")
+@mock.patch("grpc.secure_channel")
+def test_create_channel_without_grpc_gcp(grpc_secure_channel):
+ target = "example.com:443"
+ scopes = ["test_scope"]
+
+ credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True)
+ credentials.requires_scopes = True
+
+ grpc_helpers.create_channel(target, credentials=credentials, scopes=scopes)
+ grpc_secure_channel.assert_called()
+
+ credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None)
+
+
+class TestChannelStub(object):
+ def test_single_response(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name="meep")
+ expected_response = operations_pb2.Operation(name="moop")
+
+ channel.GetOperation.response = expected_response
+
+ response = stub.GetOperation(expected_request)
+
+ assert response == expected_response
+ assert channel.requests == [("GetOperation", expected_request)]
+ assert channel.GetOperation.requests == [expected_request]
+
+ def test_no_response(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name="meep")
+
+ with pytest.raises(ValueError) as exc_info:
+ stub.GetOperation(expected_request)
+
+ assert exc_info.match("GetOperation")
+
+ def test_missing_method(self):
+ channel = grpc_helpers.ChannelStub()
+
+ with pytest.raises(AttributeError):
+ channel.DoesNotExist.response
+
+ def test_exception_response(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name="meep")
+
+ channel.GetOperation.response = RuntimeError()
+
+ with pytest.raises(RuntimeError):
+ stub.GetOperation(expected_request)
+
+ def test_callable_response(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name="meep")
+ expected_response = operations_pb2.Operation(name="moop")
+
+ on_get_operation = mock.Mock(spec=("__call__",), return_value=expected_response)
+
+ channel.GetOperation.response = on_get_operation
+
+ response = stub.GetOperation(expected_request)
+
+ assert response == expected_response
+ on_get_operation.assert_called_once_with(expected_request)
+
+ def test_multiple_responses(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name="meep")
+ expected_responses = [
+ operations_pb2.Operation(name="foo"),
+ operations_pb2.Operation(name="bar"),
+ operations_pb2.Operation(name="baz"),
+ ]
+
+ channel.GetOperation.responses = iter(expected_responses)
+
+ response1 = stub.GetOperation(expected_request)
+ response2 = stub.GetOperation(expected_request)
+ response3 = stub.GetOperation(expected_request)
+
+ assert response1 == expected_responses[0]
+ assert response2 == expected_responses[1]
+ assert response3 == expected_responses[2]
+ assert channel.requests == [("GetOperation", expected_request)] * 3
+ assert channel.GetOperation.requests == [expected_request] * 3
+
+ with pytest.raises(StopIteration):
+ stub.GetOperation(expected_request)
+
+ def test_multiple_responses_and_single_response_error(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ channel.GetOperation.responses = []
+ channel.GetOperation.response = mock.sentinel.response
+
+ with pytest.raises(ValueError):
+ stub.GetOperation(operations_pb2.GetOperationRequest())
+
+ def test_call_info(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name="meep")
+ expected_response = operations_pb2.Operation(name="moop")
+ expected_metadata = [("red", "blue"), ("two", "shoe")]
+ expected_credentials = mock.sentinel.credentials
+ channel.GetOperation.response = expected_response
+
+ response = stub.GetOperation(
+ expected_request,
+ timeout=42,
+ metadata=expected_metadata,
+ credentials=expected_credentials,
+ )
+
+ assert response == expected_response
+ assert channel.requests == [("GetOperation", expected_request)]
+ assert channel.GetOperation.calls == [
+ (expected_request, 42, expected_metadata, expected_credentials)
+ ]
+
+ def test_unary_unary(self):
+ channel = grpc_helpers.ChannelStub()
+ method_name = "GetOperation"
+ callable_stub = channel.unary_unary(method_name)
+ assert callable_stub._method == method_name
+ assert callable_stub._channel == channel
+
+ def test_unary_stream(self):
+ channel = grpc_helpers.ChannelStub()
+ method_name = "GetOperation"
+ callable_stub = channel.unary_stream(method_name)
+ assert callable_stub._method == method_name
+ assert callable_stub._channel == channel
+
+ def test_stream_unary(self):
+ channel = grpc_helpers.ChannelStub()
+ method_name = "GetOperation"
+ callable_stub = channel.stream_unary(method_name)
+ assert callable_stub._method == method_name
+ assert callable_stub._channel == channel
+
+ def test_stream_stream(self):
+ channel = grpc_helpers.ChannelStub()
+ method_name = "GetOperation"
+ callable_stub = channel.stream_stream(method_name)
+ assert callable_stub._method == method_name
+ assert callable_stub._channel == channel
+
+ def test_subscribe_unsubscribe(self):
+ channel = grpc_helpers.ChannelStub()
+ assert channel.subscribe(None) is None
+ assert channel.unsubscribe(None) is None
+
+ def test_close(self):
+ channel = grpc_helpers.ChannelStub()
+ assert channel.close() is None
diff --git a/tests/unit/test_iam.py b/tests/unit/test_iam.py
new file mode 100644
index 0000000..fbd242e
--- /dev/null
+++ b/tests/unit/test_iam.py
@@ -0,0 +1,382 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from google.api_core.iam import _DICT_ACCESS_MSG, InvalidOperationException
+
+
+class TestPolicy:
+ @staticmethod
+ def _get_target_class():
+ from google.api_core.iam import Policy
+
+ return Policy
+
+ def _make_one(self, *args, **kw):
+ return self._get_target_class()(*args, **kw)
+
+ def test_ctor_defaults(self):
+ empty = frozenset()
+ policy = self._make_one()
+ assert policy.etag is None
+ assert policy.version is None
+ assert policy.owners == empty
+ assert policy.editors == empty
+ assert policy.viewers == empty
+ assert len(policy) == 0
+ assert dict(policy) == {}
+
+ def test_ctor_explicit(self):
+ VERSION = 1
+ ETAG = "ETAG"
+ empty = frozenset()
+ policy = self._make_one(ETAG, VERSION)
+ assert policy.etag == ETAG
+ assert policy.version == VERSION
+ assert policy.owners == empty
+ assert policy.editors == empty
+ assert policy.viewers == empty
+ assert len(policy) == 0
+ assert dict(policy) == {}
+
+ def test___getitem___miss(self):
+ policy = self._make_one()
+ assert policy["nonesuch"] == set()
+
+ def test__getitem___and_set(self):
+ from google.api_core.iam import OWNER_ROLE
+
+ policy = self._make_one()
+
+ # get the policy using the getter and then modify it
+ policy[OWNER_ROLE].add("user:phred@example.com")
+ assert dict(policy) == {OWNER_ROLE: {"user:phred@example.com"}}
+
+ def test___getitem___version3(self):
+ policy = self._make_one("DEADBEEF", 3)
+ with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG):
+ policy["role"]
+
+ def test___getitem___with_conditions(self):
+ USER = "user:phred@example.com"
+ CONDITION = {"expression": "2 > 1"}
+ policy = self._make_one("DEADBEEF", 1)
+ policy.bindings = [
+ {"role": "role/reader", "members": [USER], "condition": CONDITION}
+ ]
+ with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG):
+ policy["role/reader"]
+
+ def test___setitem__(self):
+ USER = "user:phred@example.com"
+ PRINCIPALS = set([USER])
+ policy = self._make_one()
+ policy["rolename"] = [USER]
+ assert policy["rolename"] == PRINCIPALS
+ assert len(policy) == 1
+ assert dict(policy) == {"rolename": PRINCIPALS}
+
+ def test__set_item__overwrite(self):
+ GROUP = "group:test@group.com"
+ USER = "user:phred@example.com"
+ ALL_USERS = "allUsers"
+ MEMBERS = set([ALL_USERS])
+ GROUPS = set([GROUP])
+ policy = self._make_one()
+ policy["first"] = [GROUP]
+ policy["second"] = [USER]
+ policy["second"] = [ALL_USERS]
+ assert policy["second"] == MEMBERS
+ assert len(policy) == 2
+ assert dict(policy) == {"first": GROUPS, "second": MEMBERS}
+
+ def test___setitem___version3(self):
+ policy = self._make_one("DEADBEEF", 3)
+ with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG):
+ policy["role/reader"] = ["user:phred@example.com"]
+
+ def test___setitem___with_conditions(self):
+ USER = "user:phred@example.com"
+ CONDITION = {"expression": "2 > 1"}
+ policy = self._make_one("DEADBEEF", 1)
+ policy.bindings = [
+ {"role": "role/reader", "members": set([USER]), "condition": CONDITION}
+ ]
+ with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG):
+ policy["role/reader"] = ["user:phred@example.com"]
+
+ def test___delitem___hit(self):
+ policy = self._make_one()
+ policy.bindings = [
+ {"role": "to/keep", "members": set(["phred@example.com"])},
+ {"role": "to/remove", "members": set(["phred@example.com"])},
+ ]
+ del policy["to/remove"]
+ assert len(policy) == 1
+ assert dict(policy) == {"to/keep": set(["phred@example.com"])}
+
+ def test___delitem___miss(self):
+ policy = self._make_one()
+ with pytest.raises(KeyError):
+ del policy["nonesuch"]
+
+ def test___delitem___version3(self):
+ policy = self._make_one("DEADBEEF", 3)
+ with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG):
+ del policy["role/reader"]
+
+ def test___delitem___with_conditions(self):
+ USER = "user:phred@example.com"
+ CONDITION = {"expression": "2 > 1"}
+ policy = self._make_one("DEADBEEF", 1)
+ policy.bindings = [
+ {"role": "role/reader", "members": set([USER]), "condition": CONDITION}
+ ]
+ with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG):
+ del policy["role/reader"]
+
+ def test_bindings_property(self):
+ USER = "user:phred@example.com"
+ CONDITION = {"expression": "2 > 1"}
+ policy = self._make_one()
+ BINDINGS = [
+ {"role": "role/reader", "members": set([USER]), "condition": CONDITION}
+ ]
+ policy.bindings = BINDINGS
+ assert policy.bindings == BINDINGS
+
+ def test_owners_getter(self):
+ from google.api_core.iam import OWNER_ROLE
+
+ MEMBER = "user:phred@example.com"
+ expected = frozenset([MEMBER])
+ policy = self._make_one()
+ policy[OWNER_ROLE] = [MEMBER]
+ assert policy.owners == expected
+
+ def test_owners_setter(self):
+ import warnings
+ from google.api_core.iam import OWNER_ROLE
+
+ MEMBER = "user:phred@example.com"
+ expected = set([MEMBER])
+ policy = self._make_one()
+
+ with warnings.catch_warnings(record=True) as warned:
+ policy.owners = [MEMBER]
+
+ (warning,) = warned
+ assert warning.category is DeprecationWarning
+ assert policy[OWNER_ROLE] == expected
+
+ def test_editors_getter(self):
+ from google.api_core.iam import EDITOR_ROLE
+
+ MEMBER = "user:phred@example.com"
+ expected = frozenset([MEMBER])
+ policy = self._make_one()
+ policy[EDITOR_ROLE] = [MEMBER]
+ assert policy.editors == expected
+
+ def test_editors_setter(self):
+ import warnings
+ from google.api_core.iam import EDITOR_ROLE
+
+ MEMBER = "user:phred@example.com"
+ expected = set([MEMBER])
+ policy = self._make_one()
+
+ with warnings.catch_warnings(record=True) as warned:
+ policy.editors = [MEMBER]
+
+ (warning,) = warned
+ assert warning.category is DeprecationWarning
+ assert policy[EDITOR_ROLE] == expected
+
+ def test_viewers_getter(self):
+ from google.api_core.iam import VIEWER_ROLE
+
+ MEMBER = "user:phred@example.com"
+ expected = frozenset([MEMBER])
+ policy = self._make_one()
+ policy[VIEWER_ROLE] = [MEMBER]
+ assert policy.viewers == expected
+
+ def test_viewers_setter(self):
+ import warnings
+ from google.api_core.iam import VIEWER_ROLE
+
+ MEMBER = "user:phred@example.com"
+ expected = set([MEMBER])
+ policy = self._make_one()
+
+ with warnings.catch_warnings(record=True) as warned:
+ policy.viewers = [MEMBER]
+
+ (warning,) = warned
+ assert warning.category is DeprecationWarning
+ assert policy[VIEWER_ROLE] == expected
+
+ def test_user(self):
+ EMAIL = "phred@example.com"
+ MEMBER = "user:%s" % (EMAIL,)
+ policy = self._make_one()
+ assert policy.user(EMAIL) == MEMBER
+
+ def test_service_account(self):
+ EMAIL = "phred@example.com"
+ MEMBER = "serviceAccount:%s" % (EMAIL,)
+ policy = self._make_one()
+ assert policy.service_account(EMAIL) == MEMBER
+
+ def test_group(self):
+ EMAIL = "phred@example.com"
+ MEMBER = "group:%s" % (EMAIL,)
+ policy = self._make_one()
+ assert policy.group(EMAIL) == MEMBER
+
+ def test_domain(self):
+ DOMAIN = "example.com"
+ MEMBER = "domain:%s" % (DOMAIN,)
+ policy = self._make_one()
+ assert policy.domain(DOMAIN) == MEMBER
+
+ def test_all_users(self):
+ policy = self._make_one()
+ assert policy.all_users() == "allUsers"
+
+ def test_authenticated_users(self):
+ policy = self._make_one()
+ assert policy.authenticated_users() == "allAuthenticatedUsers"
+
+ def test_from_api_repr_only_etag(self):
+ empty = frozenset()
+ RESOURCE = {"etag": "ACAB"}
+ klass = self._get_target_class()
+ policy = klass.from_api_repr(RESOURCE)
+ assert policy.etag == "ACAB"
+ assert policy.version is None
+ assert policy.owners == empty
+ assert policy.editors == empty
+ assert policy.viewers == empty
+ assert dict(policy) == {}
+
+ def test_from_api_repr_complete(self):
+ from google.api_core.iam import OWNER_ROLE, EDITOR_ROLE, VIEWER_ROLE
+
+ OWNER1 = "group:cloud-logs@google.com"
+ OWNER2 = "user:phred@example.com"
+ EDITOR1 = "domain:google.com"
+ EDITOR2 = "user:phred@example.com"
+ VIEWER1 = "serviceAccount:1234-abcdef@service.example.com"
+ VIEWER2 = "user:phred@example.com"
+ RESOURCE = {
+ "etag": "DEADBEEF",
+ "version": 1,
+ "bindings": [
+ {"role": OWNER_ROLE, "members": [OWNER1, OWNER2]},
+ {"role": EDITOR_ROLE, "members": [EDITOR1, EDITOR2]},
+ {"role": VIEWER_ROLE, "members": [VIEWER1, VIEWER2]},
+ ],
+ }
+ klass = self._get_target_class()
+ policy = klass.from_api_repr(RESOURCE)
+ assert policy.etag == "DEADBEEF"
+ assert policy.version == 1
+ assert policy.owners, frozenset([OWNER1 == OWNER2])
+ assert policy.editors, frozenset([EDITOR1 == EDITOR2])
+ assert policy.viewers, frozenset([VIEWER1 == VIEWER2])
+ assert dict(policy) == {
+ OWNER_ROLE: set([OWNER1, OWNER2]),
+ EDITOR_ROLE: set([EDITOR1, EDITOR2]),
+ VIEWER_ROLE: set([VIEWER1, VIEWER2]),
+ }
+ assert policy.bindings == [
+ {"role": OWNER_ROLE, "members": set([OWNER1, OWNER2])},
+ {"role": EDITOR_ROLE, "members": set([EDITOR1, EDITOR2])},
+ {"role": VIEWER_ROLE, "members": set([VIEWER1, VIEWER2])},
+ ]
+
+ def test_from_api_repr_unknown_role(self):
+ USER = "user:phred@example.com"
+ GROUP = "group:cloud-logs@google.com"
+ RESOURCE = {
+ "etag": "DEADBEEF",
+ "version": 1,
+ "bindings": [{"role": "unknown", "members": [USER, GROUP]}],
+ }
+ klass = self._get_target_class()
+ policy = klass.from_api_repr(RESOURCE)
+ assert policy.etag == "DEADBEEF"
+ assert policy.version == 1
+ assert dict(policy), {"unknown": set([GROUP == USER])}
+
+ def test_to_api_repr_defaults(self):
+ policy = self._make_one()
+ assert policy.to_api_repr() == {}
+
+ def test_to_api_repr_only_etag(self):
+ policy = self._make_one("DEADBEEF")
+ assert policy.to_api_repr() == {"etag": "DEADBEEF"}
+
+ def test_to_api_repr_binding_wo_members(self):
+ policy = self._make_one()
+ policy["empty"] = []
+ assert policy.to_api_repr() == {}
+
+ def test_to_api_repr_binding_w_duplicates(self):
+ import warnings
+ from google.api_core.iam import OWNER_ROLE
+
+ OWNER = "group:cloud-logs@google.com"
+ policy = self._make_one()
+ with warnings.catch_warnings(record=True):
+ policy.owners = [OWNER, OWNER]
+ assert policy.to_api_repr() == {
+ "bindings": [{"role": OWNER_ROLE, "members": [OWNER]}]
+ }
+
+ def test_to_api_repr_full(self):
+ import operator
+ from google.api_core.iam import OWNER_ROLE, EDITOR_ROLE, VIEWER_ROLE
+
+ OWNER1 = "group:cloud-logs@google.com"
+ OWNER2 = "user:phred@example.com"
+ EDITOR1 = "domain:google.com"
+ EDITOR2 = "user:phred@example.com"
+ VIEWER1 = "serviceAccount:1234-abcdef@service.example.com"
+ VIEWER2 = "user:phred@example.com"
+ CONDITION = {
+ "title": "title",
+ "description": "description",
+ "expression": "true",
+ }
+ BINDINGS = [
+ {"role": OWNER_ROLE, "members": [OWNER1, OWNER2]},
+ {"role": EDITOR_ROLE, "members": [EDITOR1, EDITOR2]},
+ {"role": VIEWER_ROLE, "members": [VIEWER1, VIEWER2]},
+ {
+ "role": VIEWER_ROLE,
+ "members": [VIEWER1, VIEWER2],
+ "condition": CONDITION,
+ },
+ ]
+ policy = self._make_one("DEADBEEF", 1)
+ policy.bindings = BINDINGS
+ resource = policy.to_api_repr()
+ assert resource["etag"] == "DEADBEEF"
+ assert resource["version"] == 1
+ key = operator.itemgetter("role")
+ assert sorted(resource["bindings"], key=key) == sorted(BINDINGS, key=key)
diff --git a/tests/unit/test_operation.py b/tests/unit/test_operation.py
new file mode 100644
index 0000000..22e23bc
--- /dev/null
+++ b/tests/unit/test_operation.py
@@ -0,0 +1,326 @@
+# Copyright 2017, Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import mock
+import pytest
+
+try:
+ import grpc # noqa: F401
+except ImportError:
+ pytest.skip("No GRPC", allow_module_level=True)
+
+from google.api_core import exceptions
+from google.api_core import operation
+from google.api_core import operations_v1
+from google.api_core import retry
+from google.longrunning import operations_pb2
+from google.protobuf import struct_pb2
+from google.rpc import code_pb2
+from google.rpc import status_pb2
+
+TEST_OPERATION_NAME = "test/operation"
+
+
+def make_operation_proto(
+ name=TEST_OPERATION_NAME, metadata=None, response=None, error=None, **kwargs
+):
+ operation_proto = operations_pb2.Operation(name=name, **kwargs)
+
+ if metadata is not None:
+ operation_proto.metadata.Pack(metadata)
+
+ if response is not None:
+ operation_proto.response.Pack(response)
+
+ if error is not None:
+ operation_proto.error.CopyFrom(error)
+
+ return operation_proto
+
+
+def make_operation_future(client_operations_responses=None):
+ if client_operations_responses is None:
+ client_operations_responses = [make_operation_proto()]
+
+ refresh = mock.Mock(spec=["__call__"], side_effect=client_operations_responses)
+ refresh.responses = client_operations_responses
+ cancel = mock.Mock(spec=["__call__"])
+ operation_future = operation.Operation(
+ client_operations_responses[0],
+ refresh,
+ cancel,
+ result_type=struct_pb2.Struct,
+ metadata_type=struct_pb2.Struct,
+ )
+
+ return operation_future, refresh, cancel
+
+
+def test_constructor():
+ future, refresh, _ = make_operation_future()
+
+ assert future.operation == refresh.responses[0]
+ assert future.operation.done is False
+ assert future.operation.name == TEST_OPERATION_NAME
+ assert future.metadata is None
+ assert future.running()
+
+
+def test_metadata():
+ expected_metadata = struct_pb2.Struct()
+ future, _, _ = make_operation_future(
+ [make_operation_proto(metadata=expected_metadata)]
+ )
+
+ assert future.metadata == expected_metadata
+
+
+def test_cancellation():
+ responses = [
+ make_operation_proto(),
+ # Second response indicates that the operation was cancelled.
+ make_operation_proto(
+ done=True, error=status_pb2.Status(code=code_pb2.CANCELLED)
+ ),
+ ]
+ future, _, cancel = make_operation_future(responses)
+
+ assert future.cancel()
+ assert future.cancelled()
+ cancel.assert_called_once_with()
+
+ # Cancelling twice should have no effect.
+ assert not future.cancel()
+ cancel.assert_called_once_with()
+
+
+def test_result():
+ expected_result = struct_pb2.Struct()
+ responses = [
+ make_operation_proto(),
+ # Second operation response includes the result.
+ make_operation_proto(done=True, response=expected_result),
+ ]
+ future, _, _ = make_operation_future(responses)
+
+ result = future.result()
+
+ assert result == expected_result
+ assert future.done()
+
+
+def test_done_w_retry():
+ RETRY_PREDICATE = retry.if_exception_type(exceptions.TooManyRequests)
+ test_retry = retry.Retry(predicate=RETRY_PREDICATE)
+
+ expected_result = struct_pb2.Struct()
+ responses = [
+ make_operation_proto(),
+ # Second operation response includes the result.
+ make_operation_proto(done=True, response=expected_result),
+ ]
+ future, _, _ = make_operation_future(responses)
+ future._refresh = mock.Mock()
+
+ future.done(retry=test_retry)
+ future._refresh.assert_called_once_with(retry=test_retry)
+
+
+def test_exception():
+ expected_exception = status_pb2.Status(message="meep")
+ responses = [
+ make_operation_proto(),
+ # Second operation response includes the error.
+ make_operation_proto(done=True, error=expected_exception),
+ ]
+ future, _, _ = make_operation_future(responses)
+
+ exception = future.exception()
+
+ assert expected_exception.message in "{!r}".format(exception)
+
+
+def test_exception_with_error_code():
+ expected_exception = status_pb2.Status(message="meep", code=5)
+ responses = [
+ make_operation_proto(),
+ # Second operation response includes the error.
+ make_operation_proto(done=True, error=expected_exception),
+ ]
+ future, _, _ = make_operation_future(responses)
+
+ exception = future.exception()
+
+ assert expected_exception.message in "{!r}".format(exception)
+ # Status Code 5 maps to Not Found
+ # https://developers.google.com/maps-booking/reference/grpc-api/status_codes
+ assert isinstance(exception, exceptions.NotFound)
+
+
+def test_unexpected_result():
+ responses = [
+ make_operation_proto(),
+ # Second operation response is done, but has not error or response.
+ make_operation_proto(done=True),
+ ]
+ future, _, _ = make_operation_future(responses)
+
+ exception = future.exception()
+
+ assert "Unexpected state" in "{!r}".format(exception)
+
+
+def test__refresh_http():
+ json_response = {"name": TEST_OPERATION_NAME, "done": True}
+ api_request = mock.Mock(return_value=json_response)
+
+ result = operation._refresh_http(api_request, TEST_OPERATION_NAME)
+
+ assert isinstance(result, operations_pb2.Operation)
+ assert result.name == TEST_OPERATION_NAME
+ assert result.done is True
+
+ api_request.assert_called_once_with(
+ method="GET", path="operations/{}".format(TEST_OPERATION_NAME)
+ )
+
+
+def test__refresh_http_w_retry():
+ json_response = {"name": TEST_OPERATION_NAME, "done": True}
+ api_request = mock.Mock()
+ retry = mock.Mock()
+ retry.return_value.return_value = json_response
+
+ result = operation._refresh_http(api_request, TEST_OPERATION_NAME, retry=retry)
+
+ assert isinstance(result, operations_pb2.Operation)
+ assert result.name == TEST_OPERATION_NAME
+ assert result.done is True
+
+ api_request.assert_not_called()
+ retry.assert_called_once_with(api_request)
+ retry.return_value.assert_called_once_with(
+ method="GET", path="operations/{}".format(TEST_OPERATION_NAME)
+ )
+
+
+def test__cancel_http():
+ api_request = mock.Mock()
+
+ operation._cancel_http(api_request, TEST_OPERATION_NAME)
+
+ api_request.assert_called_once_with(
+ method="POST", path="operations/{}:cancel".format(TEST_OPERATION_NAME)
+ )
+
+
+def test_from_http_json():
+ operation_json = {"name": TEST_OPERATION_NAME, "done": True}
+ api_request = mock.sentinel.api_request
+
+ future = operation.from_http_json(
+ operation_json, api_request, struct_pb2.Struct, metadata_type=struct_pb2.Struct
+ )
+
+ assert future._result_type == struct_pb2.Struct
+ assert future._metadata_type == struct_pb2.Struct
+ assert future.operation.name == TEST_OPERATION_NAME
+ assert future.done
+
+
+def test__refresh_grpc():
+ operations_stub = mock.Mock(spec=["GetOperation"])
+ expected_result = make_operation_proto(done=True)
+ operations_stub.GetOperation.return_value = expected_result
+
+ result = operation._refresh_grpc(operations_stub, TEST_OPERATION_NAME)
+
+ assert result == expected_result
+ expected_request = operations_pb2.GetOperationRequest(name=TEST_OPERATION_NAME)
+ operations_stub.GetOperation.assert_called_once_with(expected_request)
+
+
+def test__refresh_grpc_w_retry():
+ operations_stub = mock.Mock(spec=["GetOperation"])
+ expected_result = make_operation_proto(done=True)
+ retry = mock.Mock()
+ retry.return_value.return_value = expected_result
+
+ result = operation._refresh_grpc(operations_stub, TEST_OPERATION_NAME, retry=retry)
+
+ assert result == expected_result
+ expected_request = operations_pb2.GetOperationRequest(name=TEST_OPERATION_NAME)
+ operations_stub.GetOperation.assert_not_called()
+ retry.assert_called_once_with(operations_stub.GetOperation)
+ retry.return_value.assert_called_once_with(expected_request)
+
+
+def test__cancel_grpc():
+ operations_stub = mock.Mock(spec=["CancelOperation"])
+
+ operation._cancel_grpc(operations_stub, TEST_OPERATION_NAME)
+
+ expected_request = operations_pb2.CancelOperationRequest(name=TEST_OPERATION_NAME)
+ operations_stub.CancelOperation.assert_called_once_with(expected_request)
+
+
+def test_from_grpc():
+ operation_proto = make_operation_proto(done=True)
+ operations_stub = mock.sentinel.operations_stub
+
+ future = operation.from_grpc(
+ operation_proto,
+ operations_stub,
+ struct_pb2.Struct,
+ metadata_type=struct_pb2.Struct,
+ grpc_metadata=[("x-goog-request-params", "foo")],
+ )
+
+ assert future._result_type == struct_pb2.Struct
+ assert future._metadata_type == struct_pb2.Struct
+ assert future.operation.name == TEST_OPERATION_NAME
+ assert future.done
+ assert future._refresh.keywords["metadata"] == [("x-goog-request-params", "foo")]
+ assert future._cancel.keywords["metadata"] == [("x-goog-request-params", "foo")]
+
+
+def test_from_gapic():
+ operation_proto = make_operation_proto(done=True)
+ operations_client = mock.create_autospec(
+ operations_v1.OperationsClient, instance=True
+ )
+
+ future = operation.from_gapic(
+ operation_proto,
+ operations_client,
+ struct_pb2.Struct,
+ metadata_type=struct_pb2.Struct,
+ grpc_metadata=[("x-goog-request-params", "foo")],
+ )
+
+ assert future._result_type == struct_pb2.Struct
+ assert future._metadata_type == struct_pb2.Struct
+ assert future.operation.name == TEST_OPERATION_NAME
+ assert future.done
+ assert future._refresh.keywords["metadata"] == [("x-goog-request-params", "foo")]
+ assert future._cancel.keywords["metadata"] == [("x-goog-request-params", "foo")]
+
+
+def test_deserialize():
+ op = make_operation_proto(name="foobarbaz")
+ serialized = op.SerializeToString()
+ deserialized_op = operation.Operation.deserialize(serialized)
+ assert op.name == deserialized_op.name
+ assert type(op) is type(deserialized_op)
diff --git a/tests/unit/test_page_iterator.py b/tests/unit/test_page_iterator.py
new file mode 100644
index 0000000..a44e998
--- /dev/null
+++ b/tests/unit/test_page_iterator.py
@@ -0,0 +1,665 @@
+# Copyright 2015 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import types
+
+import mock
+import pytest
+
+from google.api_core import page_iterator
+
+
+def test__do_nothing_page_start():
+ assert page_iterator._do_nothing_page_start(None, None, None) is None
+
+
+class TestPage(object):
+ def test_constructor(self):
+ parent = mock.sentinel.parent
+ item_to_value = mock.sentinel.item_to_value
+
+ page = page_iterator.Page(parent, (1, 2, 3), item_to_value)
+
+ assert page.num_items == 3
+ assert page.remaining == 3
+ assert page._parent is parent
+ assert page._item_to_value is item_to_value
+ assert page.raw_page is None
+
+ def test___iter__(self):
+ page = page_iterator.Page(None, (), None, None)
+ assert iter(page) is page
+
+ def test_iterator_calls_parent_item_to_value(self):
+ parent = mock.sentinel.parent
+
+ item_to_value = mock.Mock(
+ side_effect=lambda iterator, value: value, spec=["__call__"]
+ )
+
+ page = page_iterator.Page(parent, (10, 11, 12), item_to_value)
+ page._remaining = 100
+
+ assert item_to_value.call_count == 0
+ assert page.remaining == 100
+
+ assert next(page) == 10
+ assert item_to_value.call_count == 1
+ item_to_value.assert_called_with(parent, 10)
+ assert page.remaining == 99
+
+ assert next(page) == 11
+ assert item_to_value.call_count == 2
+ item_to_value.assert_called_with(parent, 11)
+ assert page.remaining == 98
+
+ assert next(page) == 12
+ assert item_to_value.call_count == 3
+ item_to_value.assert_called_with(parent, 12)
+ assert page.remaining == 97
+
+ def test_raw_page(self):
+ parent = mock.sentinel.parent
+ item_to_value = mock.sentinel.item_to_value
+
+ raw_page = mock.sentinel.raw_page
+
+ page = page_iterator.Page(parent, (1, 2, 3), item_to_value, raw_page=raw_page)
+ assert page.raw_page is raw_page
+
+ with pytest.raises(AttributeError):
+ page.raw_page = None
+
+
+class PageIteratorImpl(page_iterator.Iterator):
+ def _next_page(self):
+ return mock.create_autospec(page_iterator.Page, instance=True)
+
+
+class TestIterator(object):
+ def test_constructor(self):
+ client = mock.sentinel.client
+ item_to_value = mock.sentinel.item_to_value
+ token = "ab13nceor03"
+ max_results = 1337
+
+ iterator = PageIteratorImpl(
+ client, item_to_value, page_token=token, max_results=max_results
+ )
+
+ assert not iterator._started
+ assert iterator.client is client
+ assert iterator.item_to_value == item_to_value
+ assert iterator.max_results == max_results
+ # Changing attributes.
+ assert iterator.page_number == 0
+ assert iterator.next_page_token == token
+ assert iterator.num_results == 0
+
+ def test_next(self):
+ iterator = PageIteratorImpl(None, None)
+ page_1 = page_iterator.Page(
+ iterator, ("item 1.1", "item 1.2"), page_iterator._item_to_value_identity
+ )
+ page_2 = page_iterator.Page(
+ iterator, ("item 2.1",), page_iterator._item_to_value_identity
+ )
+ iterator._next_page = mock.Mock(side_effect=[page_1, page_2, None])
+
+ result = next(iterator)
+ assert result == "item 1.1"
+ result = next(iterator)
+ assert result == "item 1.2"
+ result = next(iterator)
+ assert result == "item 2.1"
+
+ with pytest.raises(StopIteration):
+ next(iterator)
+
+ def test_pages_property_starts(self):
+ iterator = PageIteratorImpl(None, None)
+
+ assert not iterator._started
+
+ assert isinstance(iterator.pages, types.GeneratorType)
+
+ assert iterator._started
+
+ def test_pages_property_restart(self):
+ iterator = PageIteratorImpl(None, None)
+
+ assert iterator.pages
+
+ # Make sure we cannot restart.
+ with pytest.raises(ValueError):
+ assert iterator.pages
+
+ def test__page_iter_increment(self):
+ iterator = PageIteratorImpl(None, None)
+ page = page_iterator.Page(
+ iterator, ("item",), page_iterator._item_to_value_identity
+ )
+ iterator._next_page = mock.Mock(side_effect=[page, None])
+
+ assert iterator.num_results == 0
+
+ page_iter = iterator._page_iter(increment=True)
+ next(page_iter)
+
+ assert iterator.num_results == 1
+
+ def test__page_iter_no_increment(self):
+ iterator = PageIteratorImpl(None, None)
+
+ assert iterator.num_results == 0
+
+ page_iter = iterator._page_iter(increment=False)
+ next(page_iter)
+
+ # results should still be 0 after fetching a page.
+ assert iterator.num_results == 0
+
+ def test__items_iter(self):
+ # Items to be returned.
+ item1 = 17
+ item2 = 100
+ item3 = 211
+
+ # Make pages from mock responses
+ parent = mock.sentinel.parent
+ page1 = page_iterator.Page(
+ parent, (item1, item2), page_iterator._item_to_value_identity
+ )
+ page2 = page_iterator.Page(
+ parent, (item3,), page_iterator._item_to_value_identity
+ )
+
+ iterator = PageIteratorImpl(None, None)
+ iterator._next_page = mock.Mock(side_effect=[page1, page2, None])
+
+ items_iter = iterator._items_iter()
+
+ assert isinstance(items_iter, types.GeneratorType)
+
+ # Consume items and check the state of the iterator.
+ assert iterator.num_results == 0
+
+ assert next(items_iter) == item1
+ assert iterator.num_results == 1
+
+ assert next(items_iter) == item2
+ assert iterator.num_results == 2
+
+ assert next(items_iter) == item3
+ assert iterator.num_results == 3
+
+ with pytest.raises(StopIteration):
+ next(items_iter)
+
+ def test___iter__(self):
+ iterator = PageIteratorImpl(None, None)
+ iterator._next_page = mock.Mock(side_effect=[(1, 2), (3,), None])
+
+ assert not iterator._started
+
+ result = list(iterator)
+
+ assert result == [1, 2, 3]
+ assert iterator._started
+
+ def test___iter__restart(self):
+ iterator = PageIteratorImpl(None, None)
+
+ iter(iterator)
+
+ # Make sure we cannot restart.
+ with pytest.raises(ValueError):
+ iter(iterator)
+
+ def test___iter___restart_after_page(self):
+ iterator = PageIteratorImpl(None, None)
+
+ assert iterator.pages
+
+ # Make sure we cannot restart after starting the page iterator
+ with pytest.raises(ValueError):
+ iter(iterator)
+
+
+class TestHTTPIterator(object):
+ def test_constructor(self):
+ client = mock.sentinel.client
+ path = "/foo"
+ iterator = page_iterator.HTTPIterator(
+ client, mock.sentinel.api_request, path, mock.sentinel.item_to_value
+ )
+
+ assert not iterator._started
+ assert iterator.client is client
+ assert iterator.path == path
+ assert iterator.item_to_value is mock.sentinel.item_to_value
+ assert iterator._items_key == "items"
+ assert iterator.max_results is None
+ assert iterator.extra_params == {}
+ assert iterator._page_start == page_iterator._do_nothing_page_start
+ # Changing attributes.
+ assert iterator.page_number == 0
+ assert iterator.next_page_token is None
+ assert iterator.num_results == 0
+ assert iterator._page_size is None
+
+ def test_constructor_w_extra_param_collision(self):
+ extra_params = {"pageToken": "val"}
+
+ with pytest.raises(ValueError):
+ page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ extra_params=extra_params,
+ )
+
+ def test_iterate(self):
+ path = "/foo"
+ item1 = {"name": "1"}
+ item2 = {"name": "2"}
+ api_request = mock.Mock(return_value={"items": [item1, item2]})
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ api_request,
+ path=path,
+ item_to_value=page_iterator._item_to_value_identity,
+ )
+
+ assert iterator.num_results == 0
+
+ items_iter = iter(iterator)
+
+ val1 = next(items_iter)
+ assert val1 == item1
+ assert iterator.num_results == 1
+
+ val2 = next(items_iter)
+ assert val2 == item2
+ assert iterator.num_results == 2
+
+ with pytest.raises(StopIteration):
+ next(items_iter)
+
+ api_request.assert_called_once_with(method="GET", path=path, query_params={})
+
+ def test__has_next_page_new(self):
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ )
+
+ # The iterator should *always* indicate that it has a next page
+ # when created so that it can fetch the initial page.
+ assert iterator._has_next_page()
+
+ def test__has_next_page_without_token(self):
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ )
+
+ iterator.page_number = 1
+
+ # The iterator should not indicate that it has a new page if the
+ # initial page has been requested and there's no page token.
+ assert not iterator._has_next_page()
+
+ def test__has_next_page_w_number_w_token(self):
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ )
+
+ iterator.page_number = 1
+ iterator.next_page_token = mock.sentinel.token
+
+ # The iterator should indicate that it has a new page if the
+ # initial page has been requested and there's is a page token.
+ assert iterator._has_next_page()
+
+ def test__has_next_page_w_max_results_not_done(self):
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ max_results=3,
+ page_token=mock.sentinel.token,
+ )
+
+ iterator.page_number = 1
+
+ # The iterator should indicate that it has a new page if there
+ # is a page token and it has not consumed more than max_results.
+ assert iterator.num_results < iterator.max_results
+ assert iterator._has_next_page()
+
+ def test__has_next_page_w_max_results_done(self):
+
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ max_results=3,
+ page_token=mock.sentinel.token,
+ )
+
+ iterator.page_number = 1
+ iterator.num_results = 3
+
+ # The iterator should not indicate that it has a new page if there
+ # if it has consumed more than max_results.
+ assert iterator.num_results == iterator.max_results
+ assert not iterator._has_next_page()
+
+ def test__get_query_params_no_token(self):
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ )
+
+ assert iterator._get_query_params() == {}
+
+ def test__get_query_params_w_token(self):
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ )
+ iterator.next_page_token = "token"
+
+ assert iterator._get_query_params() == {"pageToken": iterator.next_page_token}
+
+ def test__get_query_params_w_max_results(self):
+ max_results = 3
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ max_results=max_results,
+ )
+
+ iterator.num_results = 1
+ local_max = max_results - iterator.num_results
+
+ assert iterator._get_query_params() == {"maxResults": local_max}
+
+ def test__get_query_params_extra_params(self):
+ extra_params = {"key": "val"}
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ extra_params=extra_params,
+ )
+
+ assert iterator._get_query_params() == extra_params
+
+ def test__get_next_page_response_with_post(self):
+ path = "/foo"
+ page_response = {"items": ["one", "two"]}
+ api_request = mock.Mock(return_value=page_response)
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ api_request,
+ path=path,
+ item_to_value=page_iterator._item_to_value_identity,
+ )
+ iterator._HTTP_METHOD = "POST"
+
+ response = iterator._get_next_page_response()
+
+ assert response == page_response
+
+ api_request.assert_called_once_with(method="POST", path=path, data={})
+
+ def test__get_next_page_bad_http_method(self):
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ mock.sentinel.api_request,
+ mock.sentinel.path,
+ mock.sentinel.item_to_value,
+ )
+ iterator._HTTP_METHOD = "NOT-A-VERB"
+
+ with pytest.raises(ValueError):
+ iterator._get_next_page_response()
+
+ @pytest.mark.parametrize(
+ "page_size,max_results,pages",
+ [(3, None, False), (3, 8, False), (3, None, True), (3, 8, True)],
+ )
+ def test_page_size_items(self, page_size, max_results, pages):
+ path = "/foo"
+ NITEMS = 10
+
+ n = [0] # blast you python 2!
+
+ def api_request(*args, **kw):
+ assert not args
+ query_params = dict(
+ maxResults=(
+ page_size
+ if max_results is None
+ else min(page_size, max_results - n[0])
+ )
+ )
+ if n[0]:
+ query_params.update(pageToken="test")
+ assert kw == {"method": "GET", "path": "/foo", "query_params": query_params}
+ n_items = min(kw["query_params"]["maxResults"], NITEMS - n[0])
+ items = [dict(name=str(i + n[0])) for i in range(n_items)]
+ n[0] += n_items
+ result = dict(items=items)
+ if n[0] < NITEMS:
+ result.update(nextPageToken="test")
+ return result
+
+ iterator = page_iterator.HTTPIterator(
+ mock.sentinel.client,
+ api_request,
+ path=path,
+ item_to_value=page_iterator._item_to_value_identity,
+ page_size=page_size,
+ max_results=max_results,
+ )
+
+ assert iterator.num_results == 0
+
+ n_results = max_results if max_results is not None else NITEMS
+ if pages:
+ items_iter = iter(iterator.pages)
+ npages = int(math.ceil(float(n_results) / page_size))
+ for ipage in range(npages):
+ assert list(next(items_iter)) == [
+ dict(name=str(i))
+ for i in range(
+ ipage * page_size, min((ipage + 1) * page_size, n_results),
+ )
+ ]
+ else:
+ items_iter = iter(iterator)
+ for i in range(n_results):
+ assert next(items_iter) == dict(name=str(i))
+ assert iterator.num_results == i + 1
+
+ with pytest.raises(StopIteration):
+ next(items_iter)
+
+
+class TestGRPCIterator(object):
+ def test_constructor(self):
+ client = mock.sentinel.client
+ items_field = "items"
+ iterator = page_iterator.GRPCIterator(
+ client, mock.sentinel.method, mock.sentinel.request, items_field
+ )
+
+ assert not iterator._started
+ assert iterator.client is client
+ assert iterator.max_results is None
+ assert iterator.item_to_value is page_iterator._item_to_value_identity
+ assert iterator._method == mock.sentinel.method
+ assert iterator._request == mock.sentinel.request
+ assert iterator._items_field == items_field
+ assert (
+ iterator._request_token_field
+ == page_iterator.GRPCIterator._DEFAULT_REQUEST_TOKEN_FIELD
+ )
+ assert (
+ iterator._response_token_field
+ == page_iterator.GRPCIterator._DEFAULT_RESPONSE_TOKEN_FIELD
+ )
+ # Changing attributes.
+ assert iterator.page_number == 0
+ assert iterator.next_page_token is None
+ assert iterator.num_results == 0
+
+ def test_constructor_options(self):
+ client = mock.sentinel.client
+ items_field = "items"
+ request_field = "request"
+ response_field = "response"
+ iterator = page_iterator.GRPCIterator(
+ client,
+ mock.sentinel.method,
+ mock.sentinel.request,
+ items_field,
+ item_to_value=mock.sentinel.item_to_value,
+ request_token_field=request_field,
+ response_token_field=response_field,
+ max_results=42,
+ )
+
+ assert iterator.client is client
+ assert iterator.max_results == 42
+ assert iterator.item_to_value is mock.sentinel.item_to_value
+ assert iterator._method == mock.sentinel.method
+ assert iterator._request == mock.sentinel.request
+ assert iterator._items_field == items_field
+ assert iterator._request_token_field == request_field
+ assert iterator._response_token_field == response_field
+
+ def test_iterate(self):
+ request = mock.Mock(spec=["page_token"], page_token=None)
+ response1 = mock.Mock(items=["a", "b"], next_page_token="1")
+ response2 = mock.Mock(items=["c"], next_page_token="2")
+ response3 = mock.Mock(items=["d"], next_page_token="")
+ method = mock.Mock(side_effect=[response1, response2, response3])
+ iterator = page_iterator.GRPCIterator(
+ mock.sentinel.client, method, request, "items"
+ )
+
+ assert iterator.num_results == 0
+
+ items = list(iterator)
+ assert items == ["a", "b", "c", "d"]
+
+ method.assert_called_with(request)
+ assert method.call_count == 3
+ assert request.page_token == "2"
+
+ def test_iterate_with_max_results(self):
+ request = mock.Mock(spec=["page_token"], page_token=None)
+ response1 = mock.Mock(items=["a", "b"], next_page_token="1")
+ response2 = mock.Mock(items=["c"], next_page_token="2")
+ response3 = mock.Mock(items=["d"], next_page_token="")
+ method = mock.Mock(side_effect=[response1, response2, response3])
+ iterator = page_iterator.GRPCIterator(
+ mock.sentinel.client, method, request, "items", max_results=3
+ )
+
+ assert iterator.num_results == 0
+
+ items = list(iterator)
+
+ assert items == ["a", "b", "c"]
+ assert iterator.num_results == 3
+
+ method.assert_called_with(request)
+ assert method.call_count == 2
+ assert request.page_token == "1"
+
+
+class GAXPageIterator(object):
+ """Fake object that matches gax.PageIterator"""
+
+ def __init__(self, pages, page_token=None):
+ self._pages = iter(pages)
+ self.page_token = page_token
+
+ def next(self):
+ return next(self._pages)
+
+ __next__ = next
+
+
+class TestGAXIterator(object):
+ def test_constructor(self):
+ client = mock.sentinel.client
+ token = "zzzyy78kl"
+ page_iter = GAXPageIterator((), page_token=token)
+ item_to_value = page_iterator._item_to_value_identity
+ max_results = 1337
+ iterator = page_iterator._GAXIterator(
+ client, page_iter, item_to_value, max_results=max_results
+ )
+
+ assert not iterator._started
+ assert iterator.client is client
+ assert iterator.item_to_value is item_to_value
+ assert iterator.max_results == max_results
+ assert iterator._gax_page_iter is page_iter
+ # Changing attributes.
+ assert iterator.page_number == 0
+ assert iterator.next_page_token == token
+ assert iterator.num_results == 0
+
+ def test__next_page(self):
+ page_items = (29, 31)
+ page_token = "2sde98ds2s0hh"
+ page_iter = GAXPageIterator([page_items], page_token=page_token)
+ iterator = page_iterator._GAXIterator(
+ mock.sentinel.client, page_iter, page_iterator._item_to_value_identity
+ )
+
+ page = iterator._next_page()
+
+ assert iterator.next_page_token == page_token
+ assert isinstance(page, page_iterator.Page)
+ assert list(page) == list(page_items)
+
+ next_page = iterator._next_page()
+
+ assert next_page is None
diff --git a/tests/unit/test_path_template.py b/tests/unit/test_path_template.py
new file mode 100644
index 0000000..2c5216e
--- /dev/null
+++ b/tests/unit/test_path_template.py
@@ -0,0 +1,389 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import unicode_literals
+
+import mock
+import pytest
+
+from google.api_core import path_template
+
+
+@pytest.mark.parametrize(
+ "tmpl, args, kwargs, expected_result",
+ [
+ # Basic positional params
+ ["/v1/*", ["a"], {}, "/v1/a"],
+ ["/v1/**", ["a/b"], {}, "/v1/a/b"],
+ ["/v1/*/*", ["a", "b"], {}, "/v1/a/b"],
+ ["/v1/*/*/**", ["a", "b", "c/d"], {}, "/v1/a/b/c/d"],
+ # Basic named params
+ ["/v1/{name}", [], {"name": "parent"}, "/v1/parent"],
+ ["/v1/{name=**}", [], {"name": "parent/child"}, "/v1/parent/child"],
+ # Named params with a sub-template
+ ["/v1/{name=parent/*}", [], {"name": "parent/child"}, "/v1/parent/child"],
+ [
+ "/v1/{name=parent/**}",
+ [],
+ {"name": "parent/child/object"},
+ "/v1/parent/child/object",
+ ],
+ # Combining positional and named params
+ ["/v1/*/{name}", ["a"], {"name": "parent"}, "/v1/a/parent"],
+ ["/v1/{name}/*", ["a"], {"name": "parent"}, "/v1/parent/a"],
+ [
+ "/v1/{parent}/*/{child}/*",
+ ["a", "b"],
+ {"parent": "thor", "child": "thorson"},
+ "/v1/thor/a/thorson/b",
+ ],
+ ["/v1/{name}/**", ["a/b"], {"name": "parent"}, "/v1/parent/a/b"],
+ # Combining positional and named params with sub-templates.
+ [
+ "/v1/{name=parent/*}/*",
+ ["a"],
+ {"name": "parent/child"},
+ "/v1/parent/child/a",
+ ],
+ [
+ "/v1/*/{name=parent/**}",
+ ["a"],
+ {"name": "parent/child/object"},
+ "/v1/a/parent/child/object",
+ ],
+ ],
+)
+def test_expand_success(tmpl, args, kwargs, expected_result):
+ result = path_template.expand(tmpl, *args, **kwargs)
+ assert result == expected_result
+ assert path_template.validate(tmpl, result)
+
+
+@pytest.mark.parametrize(
+ "tmpl, args, kwargs, exc_match",
+ [
+ # Missing positional arg.
+ ["v1/*", [], {}, "Positional"],
+ # Missing named arg.
+ ["v1/{name}", [], {}, "Named"],
+ ],
+)
+def test_expanded_failure(tmpl, args, kwargs, exc_match):
+ with pytest.raises(ValueError, match=exc_match):
+ path_template.expand(tmpl, *args, **kwargs)
+
+
+@pytest.mark.parametrize(
+ "request_obj, field, expected_result",
+ [
+ [{"field": "stringValue"}, "field", "stringValue"],
+ [{"field": "stringValue"}, "nosuchfield", None],
+ [{"field": "stringValue"}, "field.subfield", None],
+ [{"field": {"subfield": "stringValue"}}, "field", None],
+ [{"field": {"subfield": "stringValue"}}, "field.subfield", "stringValue"],
+ [{"field": {"subfield": [1, 2, 3]}}, "field.subfield", [1, 2, 3]],
+ [{"field": {"subfield": "stringValue"}}, "field", None],
+ [{"field": {"subfield": "stringValue"}}, "field.nosuchfield", None],
+ [
+ {"field": {"subfield": {"subsubfield": "stringValue"}}},
+ "field.subfield.subsubfield",
+ "stringValue",
+ ],
+ ["string", "field", None],
+ ],
+)
+def test_get_field(request_obj, field, expected_result):
+ result = path_template.get_field(request_obj, field)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "request_obj, field, expected_result",
+ [
+ [{"field": "stringValue"}, "field", {}],
+ [{"field": "stringValue"}, "nosuchfield", {"field": "stringValue"}],
+ [{"field": "stringValue"}, "field.subfield", {"field": "stringValue"}],
+ [{"field": {"subfield": "stringValue"}}, "field.subfield", {"field": {}}],
+ [
+ {"field": {"subfield": "stringValue", "q": "w"}, "e": "f"},
+ "field.subfield",
+ {"field": {"q": "w"}, "e": "f"},
+ ],
+ [
+ {"field": {"subfield": "stringValue"}},
+ "field.nosuchfield",
+ {"field": {"subfield": "stringValue"}},
+ ],
+ [
+ {"field": {"subfield": {"subsubfield": "stringValue", "q": "w"}}},
+ "field.subfield.subsubfield",
+ {"field": {"subfield": {"q": "w"}}},
+ ],
+ ["string", "field", "string"],
+ ["string", "field.subfield", "string"],
+ ],
+)
+def test_delete_field(request_obj, field, expected_result):
+ path_template.delete_field(request_obj, field)
+ assert request_obj == expected_result
+
+
+@pytest.mark.parametrize(
+ "tmpl, path",
+ [
+ # Single segment template, but multi segment value
+ ["v1/*", "v1/a/b"],
+ ["v1/*/*", "v1/a/b/c"],
+ # Single segement named template, but multi segment value
+ ["v1/{name}", "v1/a/b"],
+ ["v1/{name}/{value}", "v1/a/b/c"],
+ # Named value with a sub-template but invalid value
+ ["v1/{name=parent/*}", "v1/grandparent/child"],
+ ],
+)
+def test_validate_failure(tmpl, path):
+ assert not path_template.validate(tmpl, path)
+
+
+def test__expand_variable_match_unexpected():
+ match = mock.Mock(spec=["group"])
+ match.group.return_value = None
+ with pytest.raises(ValueError, match="Unknown"):
+ path_template._expand_variable_match([], {}, match)
+
+
+def test__replace_variable_with_pattern():
+ match = mock.Mock(spec=["group"])
+ match.group.return_value = None
+ with pytest.raises(ValueError, match="Unknown"):
+ path_template._replace_variable_with_pattern(match)
+
+
+@pytest.mark.parametrize(
+ "http_options, request_kwargs, expected_result",
+ [
+ [
+ [["get", "/v1/no/template", ""]],
+ {"foo": "bar"},
+ ["get", "/v1/no/template", {}, {"foo": "bar"}],
+ ],
+ # Single templates
+ [
+ [["get", "/v1/{field}", ""]],
+ {"field": "parent"},
+ ["get", "/v1/parent", {}, {}],
+ ],
+ [
+ [["get", "/v1/{field.sub}", ""]],
+ {"field": {"sub": "parent"}, "foo": "bar"},
+ ["get", "/v1/parent", {}, {"field": {}, "foo": "bar"}],
+ ],
+ ],
+)
+def test_transcode_base_case(http_options, request_kwargs, expected_result):
+ http_options, expected_result = helper_test_transcode(http_options, expected_result)
+ result = path_template.transcode(http_options, **request_kwargs)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "http_options, request_kwargs, expected_result",
+ [
+ [
+ [["get", "/v1/{field.subfield}", ""]],
+ {"field": {"subfield": "parent"}, "foo": "bar"},
+ ["get", "/v1/parent", {}, {"field": {}, "foo": "bar"}],
+ ],
+ [
+ [["get", "/v1/{field.subfield.subsubfield}", ""]],
+ {"field": {"subfield": {"subsubfield": "parent"}}, "foo": "bar"},
+ ["get", "/v1/parent", {}, {"field": {"subfield": {}}, "foo": "bar"}],
+ ],
+ [
+ [["get", "/v1/{field.subfield1}/{field.subfield2}", ""]],
+ {"field": {"subfield1": "parent", "subfield2": "child"}, "foo": "bar"},
+ ["get", "/v1/parent/child", {}, {"field": {}, "foo": "bar"}],
+ ],
+ ],
+)
+def test_transcode_subfields(http_options, request_kwargs, expected_result):
+ http_options, expected_result = helper_test_transcode(http_options, expected_result)
+ result = path_template.transcode(http_options, **request_kwargs)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "http_options, request_kwargs, expected_result",
+ [
+ # Single segment wildcard
+ [
+ [["get", "/v1/{field=*}", ""]],
+ {"field": "parent"},
+ ["get", "/v1/parent", {}, {}],
+ ],
+ [
+ [["get", "/v1/{field=a/*/b/*}", ""]],
+ {"field": "a/parent/b/child", "foo": "bar"},
+ ["get", "/v1/a/parent/b/child", {}, {"foo": "bar"}],
+ ],
+ # Double segment wildcard
+ [
+ [["get", "/v1/{field=**}", ""]],
+ {"field": "parent/p1"},
+ ["get", "/v1/parent/p1", {}, {}],
+ ],
+ [
+ [["get", "/v1/{field=a/**/b/**}", ""]],
+ {"field": "a/parent/p1/b/child/c1", "foo": "bar"},
+ ["get", "/v1/a/parent/p1/b/child/c1", {}, {"foo": "bar"}],
+ ],
+ # Combined single and double segment wildcard
+ [
+ [["get", "/v1/{field=a/*/b/**}", ""]],
+ {"field": "a/parent/b/child/c1"},
+ ["get", "/v1/a/parent/b/child/c1", {}, {}],
+ ],
+ [
+ [["get", "/v1/{field=a/**/b/*}/v2/{name}", ""]],
+ {"field": "a/parent/p1/b/child", "name": "first", "foo": "bar"},
+ ["get", "/v1/a/parent/p1/b/child/v2/first", {}, {"foo": "bar"}],
+ ],
+ ],
+)
+def test_transcode_with_wildcard(http_options, request_kwargs, expected_result):
+ http_options, expected_result = helper_test_transcode(http_options, expected_result)
+ result = path_template.transcode(http_options, **request_kwargs)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "http_options, request_kwargs, expected_result",
+ [
+ # Single field body
+ [
+ [["post", "/v1/no/template", "data"]],
+ {"data": {"id": 1, "info": "some info"}, "foo": "bar"},
+ ["post", "/v1/no/template", {"id": 1, "info": "some info"}, {"foo": "bar"}],
+ ],
+ [
+ [["post", "/v1/{field=a/*}/b/{name=**}", "data"]],
+ {
+ "field": "a/parent",
+ "name": "first/last",
+ "data": {"id": 1, "info": "some info"},
+ "foo": "bar",
+ },
+ [
+ "post",
+ "/v1/a/parent/b/first/last",
+ {"id": 1, "info": "some info"},
+ {"foo": "bar"},
+ ],
+ ],
+ # Wildcard body
+ [
+ [["post", "/v1/{field=a/*}/b/{name=**}", "*"]],
+ {
+ "field": "a/parent",
+ "name": "first/last",
+ "data": {"id": 1, "info": "some info"},
+ "foo": "bar",
+ },
+ [
+ "post",
+ "/v1/a/parent/b/first/last",
+ {"data": {"id": 1, "info": "some info"}, "foo": "bar"},
+ {},
+ ],
+ ],
+ ],
+)
+def test_transcode_with_body(http_options, request_kwargs, expected_result):
+ http_options, expected_result = helper_test_transcode(http_options, expected_result)
+ result = path_template.transcode(http_options, **request_kwargs)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "http_options, request_kwargs, expected_result",
+ [
+ # Additional bindings
+ [
+ [
+ ["post", "/v1/{field=a/*}/b/{name=**}", "extra_data"],
+ ["post", "/v1/{field=a/*}/b/{name=**}", "*"],
+ ],
+ {
+ "field": "a/parent",
+ "name": "first/last",
+ "data": {"id": 1, "info": "some info"},
+ "foo": "bar",
+ },
+ [
+ "post",
+ "/v1/a/parent/b/first/last",
+ {"data": {"id": 1, "info": "some info"}, "foo": "bar"},
+ {},
+ ],
+ ],
+ [
+ [
+ ["get", "/v1/{field=a/*}/b/{name=**}", ""],
+ ["get", "/v1/{field=a/*}/b/first/last", ""],
+ ],
+ {"field": "a/parent", "foo": "bar"},
+ ["get", "/v1/a/parent/b/first/last", {}, {"foo": "bar"}],
+ ],
+ ],
+)
+def test_transcode_with_additional_bindings(
+ http_options, request_kwargs, expected_result
+):
+ http_options, expected_result = helper_test_transcode(http_options, expected_result)
+ result = path_template.transcode(http_options, **request_kwargs)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "http_options, request_kwargs",
+ [
+ [[["get", "/v1/{name}", ""]], {"foo": "bar"}],
+ [[["get", "/v1/{name}", ""]], {"name": "first/last"}],
+ [[["get", "/v1/{name=mr/*/*}", ""]], {"name": "first/last"}],
+ [[["post", "/v1/{name}", "data"]], {"name": "first/last"}],
+ ],
+)
+def test_transcode_fails(http_options, request_kwargs):
+ http_options, _ = helper_test_transcode(http_options, range(4))
+ with pytest.raises(ValueError):
+ path_template.transcode(http_options, **request_kwargs)
+
+
+def helper_test_transcode(http_options_list, expected_result_list):
+ http_options = []
+ for opt_list in http_options_list:
+ http_option = {"method": opt_list[0], "uri": opt_list[1]}
+ if opt_list[2]:
+ http_option["body"] = opt_list[2]
+ http_options.append(http_option)
+
+ expected_result = {
+ "method": expected_result_list[0],
+ "uri": expected_result_list[1],
+ "query_params": expected_result_list[3],
+ }
+ if expected_result_list[2]:
+ expected_result["body"] = expected_result_list[2]
+
+ return (http_options, expected_result)
diff --git a/tests/unit/test_protobuf_helpers.py b/tests/unit/test_protobuf_helpers.py
new file mode 100644
index 0000000..3df45df
--- /dev/null
+++ b/tests/unit/test_protobuf_helpers.py
@@ -0,0 +1,518 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+import pytest
+
+from google.api import http_pb2
+from google.api_core import protobuf_helpers
+from google.longrunning import operations_pb2
+from google.protobuf import any_pb2
+from google.protobuf import message
+from google.protobuf import source_context_pb2
+from google.protobuf import struct_pb2
+from google.protobuf import timestamp_pb2
+from google.protobuf import type_pb2
+from google.protobuf import wrappers_pb2
+from google.type import color_pb2
+from google.type import date_pb2
+from google.type import timeofday_pb2
+
+
+def test_from_any_pb_success():
+ in_message = date_pb2.Date(year=1990)
+ in_message_any = any_pb2.Any()
+ in_message_any.Pack(in_message)
+ out_message = protobuf_helpers.from_any_pb(date_pb2.Date, in_message_any)
+
+ assert in_message == out_message
+
+
+def test_from_any_pb_wrapped_success():
+ # Declare a message class conforming to wrapped messages.
+ class WrappedDate(object):
+ def __init__(self, **kwargs):
+ self._pb = date_pb2.Date(**kwargs)
+
+ def __eq__(self, other):
+ return self._pb == other
+
+ @classmethod
+ def pb(cls, msg):
+ return msg._pb
+
+ # Run the same test as `test_from_any_pb_success`, but using the
+ # wrapped class.
+ in_message = date_pb2.Date(year=1990)
+ in_message_any = any_pb2.Any()
+ in_message_any.Pack(in_message)
+ out_message = protobuf_helpers.from_any_pb(WrappedDate, in_message_any)
+
+ assert out_message == in_message
+
+
+def test_from_any_pb_failure():
+ in_message = any_pb2.Any()
+ in_message.Pack(date_pb2.Date(year=1990))
+
+ with pytest.raises(TypeError):
+ protobuf_helpers.from_any_pb(timeofday_pb2.TimeOfDay, in_message)
+
+
+def test_check_protobuf_helpers_ok():
+ assert protobuf_helpers.check_oneof() is None
+ assert protobuf_helpers.check_oneof(foo="bar") is None
+ assert protobuf_helpers.check_oneof(foo="bar", baz=None) is None
+ assert protobuf_helpers.check_oneof(foo=None, baz="bacon") is None
+ assert protobuf_helpers.check_oneof(foo="bar", spam=None, eggs=None) is None
+
+
+def test_check_protobuf_helpers_failures():
+ with pytest.raises(ValueError):
+ protobuf_helpers.check_oneof(foo="bar", spam="eggs")
+ with pytest.raises(ValueError):
+ protobuf_helpers.check_oneof(foo="bar", baz="bacon", spam="eggs")
+ with pytest.raises(ValueError):
+ protobuf_helpers.check_oneof(foo="bar", spam=0, eggs=None)
+
+
+def test_get_messages():
+ answer = protobuf_helpers.get_messages(date_pb2)
+
+ # Ensure that Date was exported properly.
+ assert answer["Date"] is date_pb2.Date
+
+ # Ensure that no non-Message objects were exported.
+ for value in answer.values():
+ assert issubclass(value, message.Message)
+
+
+def test_get_dict_absent():
+ with pytest.raises(KeyError):
+ assert protobuf_helpers.get({}, "foo")
+
+
+def test_get_dict_present():
+ assert protobuf_helpers.get({"foo": "bar"}, "foo") == "bar"
+
+
+def test_get_dict_default():
+ assert protobuf_helpers.get({}, "foo", default="bar") == "bar"
+
+
+def test_get_dict_nested():
+ assert protobuf_helpers.get({"foo": {"bar": "baz"}}, "foo.bar") == "baz"
+
+
+def test_get_dict_nested_default():
+ assert protobuf_helpers.get({}, "foo.baz", default="bacon") == "bacon"
+ assert protobuf_helpers.get({"foo": {}}, "foo.baz", default="bacon") == "bacon"
+
+
+def test_get_msg_sentinel():
+ msg = timestamp_pb2.Timestamp()
+ with pytest.raises(KeyError):
+ assert protobuf_helpers.get(msg, "foo")
+
+
+def test_get_msg_present():
+ msg = timestamp_pb2.Timestamp(seconds=42)
+ assert protobuf_helpers.get(msg, "seconds") == 42
+
+
+def test_get_msg_default():
+ msg = timestamp_pb2.Timestamp()
+ assert protobuf_helpers.get(msg, "foo", default="bar") == "bar"
+
+
+def test_invalid_object():
+ with pytest.raises(TypeError):
+ protobuf_helpers.get(object(), "foo", "bar")
+
+
+def test_set_dict():
+ mapping = {}
+ protobuf_helpers.set(mapping, "foo", "bar")
+ assert mapping == {"foo": "bar"}
+
+
+def test_set_msg():
+ msg = timestamp_pb2.Timestamp()
+ protobuf_helpers.set(msg, "seconds", 42)
+ assert msg.seconds == 42
+
+
+def test_set_dict_nested():
+ mapping = {}
+ protobuf_helpers.set(mapping, "foo.bar", "baz")
+ assert mapping == {"foo": {"bar": "baz"}}
+
+
+def test_set_invalid_object():
+ with pytest.raises(TypeError):
+ protobuf_helpers.set(object(), "foo", "bar")
+
+
+def test_set_list():
+ list_ops_response = operations_pb2.ListOperationsResponse()
+
+ protobuf_helpers.set(
+ list_ops_response,
+ "operations",
+ [{"name": "foo"}, operations_pb2.Operation(name="bar")],
+ )
+
+ assert len(list_ops_response.operations) == 2
+
+ for operation in list_ops_response.operations:
+ assert isinstance(operation, operations_pb2.Operation)
+
+ assert list_ops_response.operations[0].name == "foo"
+ assert list_ops_response.operations[1].name == "bar"
+
+
+def test_set_list_clear_existing():
+ list_ops_response = operations_pb2.ListOperationsResponse(
+ operations=[{"name": "baz"}]
+ )
+
+ protobuf_helpers.set(
+ list_ops_response,
+ "operations",
+ [{"name": "foo"}, operations_pb2.Operation(name="bar")],
+ )
+
+ assert len(list_ops_response.operations) == 2
+ for operation in list_ops_response.operations:
+ assert isinstance(operation, operations_pb2.Operation)
+ assert list_ops_response.operations[0].name == "foo"
+ assert list_ops_response.operations[1].name == "bar"
+
+
+def test_set_msg_with_msg_field():
+ rule = http_pb2.HttpRule()
+ pattern = http_pb2.CustomHttpPattern(kind="foo", path="bar")
+
+ protobuf_helpers.set(rule, "custom", pattern)
+
+ assert rule.custom.kind == "foo"
+ assert rule.custom.path == "bar"
+
+
+def test_set_msg_with_dict_field():
+ rule = http_pb2.HttpRule()
+ pattern = {"kind": "foo", "path": "bar"}
+
+ protobuf_helpers.set(rule, "custom", pattern)
+
+ assert rule.custom.kind == "foo"
+ assert rule.custom.path == "bar"
+
+
+def test_set_msg_nested_key():
+ rule = http_pb2.HttpRule(custom=http_pb2.CustomHttpPattern(kind="foo", path="bar"))
+
+ protobuf_helpers.set(rule, "custom.kind", "baz")
+
+ assert rule.custom.kind == "baz"
+ assert rule.custom.path == "bar"
+
+
+def test_setdefault_dict_unset():
+ mapping = {}
+ protobuf_helpers.setdefault(mapping, "foo", "bar")
+ assert mapping == {"foo": "bar"}
+
+
+def test_setdefault_dict_falsy():
+ mapping = {"foo": None}
+ protobuf_helpers.setdefault(mapping, "foo", "bar")
+ assert mapping == {"foo": "bar"}
+
+
+def test_setdefault_dict_truthy():
+ mapping = {"foo": "bar"}
+ protobuf_helpers.setdefault(mapping, "foo", "baz")
+ assert mapping == {"foo": "bar"}
+
+
+def test_setdefault_pb2_falsy():
+ operation = operations_pb2.Operation()
+ protobuf_helpers.setdefault(operation, "name", "foo")
+ assert operation.name == "foo"
+
+
+def test_setdefault_pb2_truthy():
+ operation = operations_pb2.Operation(name="bar")
+ protobuf_helpers.setdefault(operation, "name", "foo")
+ assert operation.name == "bar"
+
+
+def test_field_mask_invalid_args():
+ with pytest.raises(ValueError):
+ protobuf_helpers.field_mask("foo", any_pb2.Any())
+ with pytest.raises(ValueError):
+ protobuf_helpers.field_mask(any_pb2.Any(), "bar")
+ with pytest.raises(ValueError):
+ protobuf_helpers.field_mask(any_pb2.Any(), operations_pb2.Operation())
+
+
+def test_field_mask_equal_values():
+ assert protobuf_helpers.field_mask(None, None).paths == []
+
+ original = struct_pb2.Value(number_value=1.0)
+ modified = struct_pb2.Value(number_value=1.0)
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0)])
+ modified = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0)])
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = struct_pb2.Struct(fields={"bar": struct_pb2.Value(number_value=1.0)})
+ modified = struct_pb2.Struct(fields={"bar": struct_pb2.Value(number_value=1.0)})
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+
+def test_field_mask_zero_values():
+ # Singular Values
+ original = color_pb2.Color(red=0.0)
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = None
+ modified = color_pb2.Color(red=0.0)
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ # Repeated Values
+ original = struct_pb2.ListValue(values=[])
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = None
+ modified = struct_pb2.ListValue(values=[])
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ # Maps
+ original = struct_pb2.Struct(fields={})
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = None
+ modified = struct_pb2.Struct(fields={})
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ # Oneofs
+ original = struct_pb2.Value(number_value=0.0)
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+ original = None
+ modified = struct_pb2.Value(number_value=0.0)
+ assert protobuf_helpers.field_mask(original, modified).paths == []
+
+
+def test_field_mask_singular_field_diffs():
+ original = type_pb2.Type(name="name")
+ modified = type_pb2.Type()
+ assert protobuf_helpers.field_mask(original, modified).paths == ["name"]
+
+ original = type_pb2.Type(name="name")
+ modified = type_pb2.Type()
+ assert protobuf_helpers.field_mask(original, modified).paths == ["name"]
+
+ original = None
+ modified = type_pb2.Type(name="name")
+ assert protobuf_helpers.field_mask(original, modified).paths == ["name"]
+
+ original = type_pb2.Type(name="name")
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == ["name"]
+
+
+def test_field_mask_message_diffs():
+ original = type_pb2.Type()
+ modified = type_pb2.Type(
+ source_context=source_context_pb2.SourceContext(file_name="name")
+ )
+ assert protobuf_helpers.field_mask(original, modified).paths == [
+ "source_context.file_name"
+ ]
+
+ original = type_pb2.Type(
+ source_context=source_context_pb2.SourceContext(file_name="name")
+ )
+ modified = type_pb2.Type()
+ assert protobuf_helpers.field_mask(original, modified).paths == ["source_context"]
+
+ original = type_pb2.Type(
+ source_context=source_context_pb2.SourceContext(file_name="name")
+ )
+ modified = type_pb2.Type(
+ source_context=source_context_pb2.SourceContext(file_name="other_name")
+ )
+ assert protobuf_helpers.field_mask(original, modified).paths == [
+ "source_context.file_name"
+ ]
+
+ original = None
+ modified = type_pb2.Type(
+ source_context=source_context_pb2.SourceContext(file_name="name")
+ )
+ assert protobuf_helpers.field_mask(original, modified).paths == [
+ "source_context.file_name"
+ ]
+
+ original = type_pb2.Type(
+ source_context=source_context_pb2.SourceContext(file_name="name")
+ )
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == ["source_context"]
+
+
+def test_field_mask_wrapper_type_diffs():
+ original = color_pb2.Color()
+ modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"]
+
+ original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ modified = color_pb2.Color()
+ assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"]
+
+ original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0))
+ assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"]
+
+ original = None
+ modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0))
+ assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"]
+
+ original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"]
+
+
+def test_field_mask_repeated_diffs():
+ original = struct_pb2.ListValue()
+ modified = struct_pb2.ListValue(
+ values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)]
+ )
+ assert protobuf_helpers.field_mask(original, modified).paths == ["values"]
+
+ original = struct_pb2.ListValue(
+ values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)]
+ )
+ modified = struct_pb2.ListValue()
+ assert protobuf_helpers.field_mask(original, modified).paths == ["values"]
+
+ original = None
+ modified = struct_pb2.ListValue(
+ values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)]
+ )
+ assert protobuf_helpers.field_mask(original, modified).paths == ["values"]
+
+ original = struct_pb2.ListValue(
+ values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)]
+ )
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == ["values"]
+
+ original = struct_pb2.ListValue(
+ values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)]
+ )
+ modified = struct_pb2.ListValue(
+ values=[struct_pb2.Value(number_value=2.0), struct_pb2.Value(number_value=1.0)]
+ )
+ assert protobuf_helpers.field_mask(original, modified).paths == ["values"]
+
+
+def test_field_mask_map_diffs():
+ original = struct_pb2.Struct()
+ modified = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
+ assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]
+
+ original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
+ modified = struct_pb2.Struct()
+ assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]
+
+ original = None
+ modified = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
+ assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]
+
+ original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
+ modified = None
+ assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]
+
+ original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
+ modified = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=2.0)})
+ assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]
+
+ original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
+ modified = struct_pb2.Struct(fields={"bar": struct_pb2.Value(number_value=1.0)})
+ assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]
+
+
+def test_field_mask_different_level_diffs():
+ original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
+ modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0), red=1.0)
+ assert sorted(protobuf_helpers.field_mask(original, modified).paths) == [
+ "alpha",
+ "red",
+ ]
+
+
+@pytest.mark.skipif(
+ sys.version_info.major == 2,
+ reason="Field names with trailing underscores can only be created"
+ "through proto-plus, which is Python 3 only.",
+)
+def test_field_mask_ignore_trailing_underscore():
+ import proto
+
+ class Foo(proto.Message):
+ type_ = proto.Field(proto.STRING, number=1)
+ input_config = proto.Field(proto.STRING, number=2)
+
+ modified = Foo(type_="bar", input_config="baz")
+
+ assert sorted(protobuf_helpers.field_mask(None, Foo.pb(modified)).paths) == [
+ "input_config",
+ "type",
+ ]
+
+
+@pytest.mark.skipif(
+ sys.version_info.major == 2,
+ reason="Field names with trailing underscores can only be created"
+ "through proto-plus, which is Python 3 only.",
+)
+def test_field_mask_ignore_trailing_underscore_with_nesting():
+ import proto
+
+ class Bar(proto.Message):
+ class Baz(proto.Message):
+ input_config = proto.Field(proto.STRING, number=1)
+
+ type_ = proto.Field(Baz, number=1)
+
+ modified = Bar()
+ modified.type_.input_config = "foo"
+
+ assert sorted(protobuf_helpers.field_mask(None, Bar.pb(modified)).paths) == [
+ "type.input_config",
+ ]
diff --git a/tests/unit/test_rest_helpers.py b/tests/unit/test_rest_helpers.py
new file mode 100644
index 0000000..5932fa5
--- /dev/null
+++ b/tests/unit/test_rest_helpers.py
@@ -0,0 +1,77 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from google.api_core import rest_helpers
+
+
+def test_flatten_simple_value():
+ with pytest.raises(TypeError):
+ rest_helpers.flatten_query_params("abc")
+
+
+def test_flatten_list():
+ with pytest.raises(TypeError):
+ rest_helpers.flatten_query_params(["abc", "def"])
+
+
+def test_flatten_none():
+ assert rest_helpers.flatten_query_params(None) == []
+
+
+def test_flatten_empty_dict():
+ assert rest_helpers.flatten_query_params({}) == []
+
+
+def test_flatten_simple_dict():
+ assert rest_helpers.flatten_query_params({"a": "abc", "b": "def"}) == [
+ ("a", "abc"),
+ ("b", "def"),
+ ]
+
+
+def test_flatten_repeated_field():
+ assert rest_helpers.flatten_query_params({"a": ["x", "y", "z", None]}) == [
+ ("a", "x"),
+ ("a", "y"),
+ ("a", "z"),
+ ]
+
+
+def test_flatten_nested_dict():
+ obj = {"a": {"b": {"c": ["x", "y", "z"]}}, "d": {"e": "uvw"}}
+ expected_result = [("a.b.c", "x"), ("a.b.c", "y"), ("a.b.c", "z"), ("d.e", "uvw")]
+
+ assert rest_helpers.flatten_query_params(obj) == expected_result
+
+
+def test_flatten_repeated_dict():
+ obj = {
+ "a": {"b": {"c": [{"v": 1}, {"v": 2}]}},
+ "d": "uvw",
+ }
+
+ with pytest.raises(ValueError):
+ rest_helpers.flatten_query_params(obj)
+
+
+def test_flatten_repeated_list():
+ obj = {
+ "a": {"b": {"c": [["e", "f"], ["g", "h"]]}},
+ "d": "uvw",
+ }
+
+ with pytest.raises(ValueError):
+ rest_helpers.flatten_query_params(obj)
diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py
new file mode 100644
index 0000000..199ca55
--- /dev/null
+++ b/tests/unit/test_retry.py
@@ -0,0 +1,458 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import itertools
+import re
+
+import mock
+import pytest
+import requests.exceptions
+
+from google.api_core import exceptions
+from google.api_core import retry
+from google.auth import exceptions as auth_exceptions
+
+
+def test_if_exception_type():
+ predicate = retry.if_exception_type(ValueError)
+
+ assert predicate(ValueError())
+ assert not predicate(TypeError())
+
+
+def test_if_exception_type_multiple():
+ predicate = retry.if_exception_type(ValueError, TypeError)
+
+ assert predicate(ValueError())
+ assert predicate(TypeError())
+ assert not predicate(RuntimeError())
+
+
+def test_if_transient_error():
+ assert retry.if_transient_error(exceptions.InternalServerError(""))
+ assert retry.if_transient_error(exceptions.TooManyRequests(""))
+ assert retry.if_transient_error(exceptions.ServiceUnavailable(""))
+ assert retry.if_transient_error(requests.exceptions.ConnectionError(""))
+ assert retry.if_transient_error(requests.exceptions.ChunkedEncodingError(""))
+ assert retry.if_transient_error(auth_exceptions.TransportError(""))
+ assert not retry.if_transient_error(exceptions.InvalidArgument(""))
+
+
+# Make uniform return half of its maximum, which will be the calculated
+# sleep time.
+@mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0)
+def test_exponential_sleep_generator_base_2(uniform):
+ gen = retry.exponential_sleep_generator(1, 60, multiplier=2)
+
+ result = list(itertools.islice(gen, 8))
+ assert result == [1, 2, 4, 8, 16, 32, 60, 60]
+
+
+@mock.patch("time.sleep", autospec=True)
+@mock.patch(
+ "google.api_core.datetime_helpers.utcnow",
+ return_value=datetime.datetime.min,
+ autospec=True,
+)
+def test_retry_target_success(utcnow, sleep):
+ predicate = retry.if_exception_type(ValueError)
+ call_count = [0]
+
+ def target():
+ call_count[0] += 1
+ if call_count[0] < 3:
+ raise ValueError()
+ return 42
+
+ result = retry.retry_target(target, predicate, range(10), None)
+
+ assert result == 42
+ assert call_count[0] == 3
+ sleep.assert_has_calls([mock.call(0), mock.call(1)])
+
+
+@mock.patch("time.sleep", autospec=True)
+@mock.patch(
+ "google.api_core.datetime_helpers.utcnow",
+ return_value=datetime.datetime.min,
+ autospec=True,
+)
+def test_retry_target_w_on_error(utcnow, sleep):
+ predicate = retry.if_exception_type(ValueError)
+ call_count = {"target": 0}
+ to_raise = ValueError()
+
+ def target():
+ call_count["target"] += 1
+ if call_count["target"] < 3:
+ raise to_raise
+ return 42
+
+ on_error = mock.Mock()
+
+ result = retry.retry_target(target, predicate, range(10), None, on_error=on_error)
+
+ assert result == 42
+ assert call_count["target"] == 3
+
+ on_error.assert_has_calls([mock.call(to_raise), mock.call(to_raise)])
+ sleep.assert_has_calls([mock.call(0), mock.call(1)])
+
+
+@mock.patch("time.sleep", autospec=True)
+@mock.patch(
+ "google.api_core.datetime_helpers.utcnow",
+ return_value=datetime.datetime.min,
+ autospec=True,
+)
+def test_retry_target_non_retryable_error(utcnow, sleep):
+ predicate = retry.if_exception_type(ValueError)
+ exception = TypeError()
+ target = mock.Mock(side_effect=exception)
+
+ with pytest.raises(TypeError) as exc_info:
+ retry.retry_target(target, predicate, range(10), None)
+
+ assert exc_info.value == exception
+ sleep.assert_not_called()
+
+
+@mock.patch("time.sleep", autospec=True)
+@mock.patch("google.api_core.datetime_helpers.utcnow", autospec=True)
+def test_retry_target_deadline_exceeded(utcnow, sleep):
+ predicate = retry.if_exception_type(ValueError)
+ exception = ValueError("meep")
+ target = mock.Mock(side_effect=exception)
+ # Setup the timeline so that the first call takes 5 seconds but the second
+ # call takes 6, which puts the retry over the deadline.
+ utcnow.side_effect = [
+ # The first call to utcnow establishes the start of the timeline.
+ datetime.datetime.min,
+ datetime.datetime.min + datetime.timedelta(seconds=5),
+ datetime.datetime.min + datetime.timedelta(seconds=11),
+ ]
+
+ with pytest.raises(exceptions.RetryError) as exc_info:
+ retry.retry_target(target, predicate, range(10), deadline=10)
+
+ assert exc_info.value.cause == exception
+ assert exc_info.match("Deadline of 10.0s exceeded")
+ assert exc_info.match("last exception: meep")
+ assert target.call_count == 2
+
+
+def test_retry_target_bad_sleep_generator():
+ with pytest.raises(ValueError, match="Sleep generator"):
+ retry.retry_target(mock.sentinel.target, mock.sentinel.predicate, [], None)
+
+
+class TestRetry(object):
+ def test_constructor_defaults(self):
+ retry_ = retry.Retry()
+ assert retry_._predicate == retry.if_transient_error
+ assert retry_._initial == 1
+ assert retry_._maximum == 60
+ assert retry_._multiplier == 2
+ assert retry_._deadline == 120
+ assert retry_._on_error is None
+ assert retry_.deadline == 120
+
+ def test_constructor_options(self):
+ _some_function = mock.Mock()
+
+ retry_ = retry.Retry(
+ predicate=mock.sentinel.predicate,
+ initial=1,
+ maximum=2,
+ multiplier=3,
+ deadline=4,
+ on_error=_some_function,
+ )
+ assert retry_._predicate == mock.sentinel.predicate
+ assert retry_._initial == 1
+ assert retry_._maximum == 2
+ assert retry_._multiplier == 3
+ assert retry_._deadline == 4
+ assert retry_._on_error is _some_function
+
+ def test_with_deadline(self):
+ retry_ = retry.Retry(
+ predicate=mock.sentinel.predicate,
+ initial=1,
+ maximum=2,
+ multiplier=3,
+ deadline=4,
+ on_error=mock.sentinel.on_error,
+ )
+ new_retry = retry_.with_deadline(42)
+ assert retry_ is not new_retry
+ assert new_retry._deadline == 42
+
+ # the rest of the attributes should remain the same
+ assert new_retry._predicate is retry_._predicate
+ assert new_retry._initial == retry_._initial
+ assert new_retry._maximum == retry_._maximum
+ assert new_retry._multiplier == retry_._multiplier
+ assert new_retry._on_error is retry_._on_error
+
+ def test_with_predicate(self):
+ retry_ = retry.Retry(
+ predicate=mock.sentinel.predicate,
+ initial=1,
+ maximum=2,
+ multiplier=3,
+ deadline=4,
+ on_error=mock.sentinel.on_error,
+ )
+ new_retry = retry_.with_predicate(mock.sentinel.predicate)
+ assert retry_ is not new_retry
+ assert new_retry._predicate == mock.sentinel.predicate
+
+ # the rest of the attributes should remain the same
+ assert new_retry._deadline == retry_._deadline
+ assert new_retry._initial == retry_._initial
+ assert new_retry._maximum == retry_._maximum
+ assert new_retry._multiplier == retry_._multiplier
+ assert new_retry._on_error is retry_._on_error
+
+ def test_with_delay_noop(self):
+ retry_ = retry.Retry(
+ predicate=mock.sentinel.predicate,
+ initial=1,
+ maximum=2,
+ multiplier=3,
+ deadline=4,
+ on_error=mock.sentinel.on_error,
+ )
+ new_retry = retry_.with_delay()
+ assert retry_ is not new_retry
+ assert new_retry._initial == retry_._initial
+ assert new_retry._maximum == retry_._maximum
+ assert new_retry._multiplier == retry_._multiplier
+
+ def test_with_delay(self):
+ retry_ = retry.Retry(
+ predicate=mock.sentinel.predicate,
+ initial=1,
+ maximum=2,
+ multiplier=3,
+ deadline=4,
+ on_error=mock.sentinel.on_error,
+ )
+ new_retry = retry_.with_delay(initial=5, maximum=6, multiplier=7)
+ assert retry_ is not new_retry
+ assert new_retry._initial == 5
+ assert new_retry._maximum == 6
+ assert new_retry._multiplier == 7
+
+ # the rest of the attributes should remain the same
+ assert new_retry._deadline == retry_._deadline
+ assert new_retry._predicate is retry_._predicate
+ assert new_retry._on_error is retry_._on_error
+
+ def test_with_delay_partial_options(self):
+ retry_ = retry.Retry(
+ predicate=mock.sentinel.predicate,
+ initial=1,
+ maximum=2,
+ multiplier=3,
+ deadline=4,
+ on_error=mock.sentinel.on_error,
+ )
+ new_retry = retry_.with_delay(initial=4)
+ assert retry_ is not new_retry
+ assert new_retry._initial == 4
+ assert new_retry._maximum == 2
+ assert new_retry._multiplier == 3
+
+ new_retry = retry_.with_delay(maximum=4)
+ assert retry_ is not new_retry
+ assert new_retry._initial == 1
+ assert new_retry._maximum == 4
+ assert new_retry._multiplier == 3
+
+ new_retry = retry_.with_delay(multiplier=4)
+ assert retry_ is not new_retry
+ assert new_retry._initial == 1
+ assert new_retry._maximum == 2
+ assert new_retry._multiplier == 4
+
+ # the rest of the attributes should remain the same
+ assert new_retry._deadline == retry_._deadline
+ assert new_retry._predicate is retry_._predicate
+ assert new_retry._on_error is retry_._on_error
+
+ def test___str__(self):
+ def if_exception_type(exc):
+ return bool(exc) # pragma: NO COVER
+
+ # Explicitly set all attributes as changed Retry defaults should not
+ # cause this test to start failing.
+ retry_ = retry.Retry(
+ predicate=if_exception_type,
+ initial=1.0,
+ maximum=60.0,
+ multiplier=2.0,
+ deadline=120.0,
+ on_error=None,
+ )
+ assert re.match(
+ (
+ r"<Retry predicate=<function.*?if_exception_type.*?>, "
+ r"initial=1.0, maximum=60.0, multiplier=2.0, deadline=120.0, "
+ r"on_error=None>"
+ ),
+ str(retry_),
+ )
+
+ @mock.patch("time.sleep", autospec=True)
+ def test___call___and_execute_success(self, sleep):
+ retry_ = retry.Retry()
+ target = mock.Mock(spec=["__call__"], return_value=42)
+ # __name__ is needed by functools.partial.
+ target.__name__ = "target"
+
+ decorated = retry_(target)
+ target.assert_not_called()
+
+ result = decorated("meep")
+
+ assert result == 42
+ target.assert_called_once_with("meep")
+ sleep.assert_not_called()
+
+ # Make uniform return half of its maximum, which is the calculated sleep time.
+ @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0)
+ @mock.patch("time.sleep", autospec=True)
+ def test___call___and_execute_retry(self, sleep, uniform):
+
+ on_error = mock.Mock(spec=["__call__"], side_effect=[None])
+ retry_ = retry.Retry(predicate=retry.if_exception_type(ValueError))
+
+ target = mock.Mock(spec=["__call__"], side_effect=[ValueError(), 42])
+ # __name__ is needed by functools.partial.
+ target.__name__ = "target"
+
+ decorated = retry_(target, on_error=on_error)
+ target.assert_not_called()
+
+ result = decorated("meep")
+
+ assert result == 42
+ assert target.call_count == 2
+ target.assert_has_calls([mock.call("meep"), mock.call("meep")])
+ sleep.assert_called_once_with(retry_._initial)
+ assert on_error.call_count == 1
+
+ # Make uniform return half of its maximum, which is the calculated sleep time.
+ @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0)
+ @mock.patch("time.sleep", autospec=True)
+ def test___call___and_execute_retry_hitting_deadline(self, sleep, uniform):
+
+ on_error = mock.Mock(spec=["__call__"], side_effect=[None] * 10)
+ retry_ = retry.Retry(
+ predicate=retry.if_exception_type(ValueError),
+ initial=1.0,
+ maximum=1024.0,
+ multiplier=2.0,
+ deadline=9.9,
+ )
+
+ utcnow = datetime.datetime.utcnow()
+ utcnow_patcher = mock.patch(
+ "google.api_core.datetime_helpers.utcnow", return_value=utcnow
+ )
+
+ target = mock.Mock(spec=["__call__"], side_effect=[ValueError()] * 10)
+ # __name__ is needed by functools.partial.
+ target.__name__ = "target"
+
+ decorated = retry_(target, on_error=on_error)
+ target.assert_not_called()
+
+ with utcnow_patcher as patched_utcnow:
+ # Make sure that calls to fake time.sleep() also advance the mocked
+ # time clock.
+ def increase_time(sleep_delay):
+ patched_utcnow.return_value += datetime.timedelta(seconds=sleep_delay)
+
+ sleep.side_effect = increase_time
+
+ with pytest.raises(exceptions.RetryError):
+ decorated("meep")
+
+ assert target.call_count == 5
+ target.assert_has_calls([mock.call("meep")] * 5)
+ assert on_error.call_count == 5
+
+ # check the delays
+ assert sleep.call_count == 4 # once between each successive target calls
+ last_wait = sleep.call_args.args[0]
+ total_wait = sum(call_args.args[0] for call_args in sleep.call_args_list)
+
+ assert last_wait == 2.9 # and not 8.0, because the last delay was shortened
+ assert total_wait == 9.9 # the same as the deadline
+
+ @mock.patch("time.sleep", autospec=True)
+ def test___init___without_retry_executed(self, sleep):
+ _some_function = mock.Mock()
+
+ retry_ = retry.Retry(
+ predicate=retry.if_exception_type(ValueError), on_error=_some_function
+ )
+ # check the proper creation of the class
+ assert retry_._on_error is _some_function
+
+ target = mock.Mock(spec=["__call__"], side_effect=[42])
+ # __name__ is needed by functools.partial.
+ target.__name__ = "target"
+
+ wrapped = retry_(target)
+
+ result = wrapped("meep")
+
+ assert result == 42
+ target.assert_called_once_with("meep")
+ sleep.assert_not_called()
+ _some_function.assert_not_called()
+
+ # Make uniform return half of its maximum, which is the calculated sleep time.
+ @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0)
+ @mock.patch("time.sleep", autospec=True)
+ def test___init___when_retry_is_executed(self, sleep, uniform):
+ _some_function = mock.Mock()
+
+ retry_ = retry.Retry(
+ predicate=retry.if_exception_type(ValueError), on_error=_some_function
+ )
+ # check the proper creation of the class
+ assert retry_._on_error is _some_function
+
+ target = mock.Mock(
+ spec=["__call__"], side_effect=[ValueError(), ValueError(), 42]
+ )
+ # __name__ is needed by functools.partial.
+ target.__name__ = "target"
+
+ wrapped = retry_(target)
+ target.assert_not_called()
+
+ result = wrapped("meep")
+
+ assert result == 42
+ assert target.call_count == 3
+ assert _some_function.call_count == 2
+ target.assert_has_calls([mock.call("meep"), mock.call("meep")])
+ sleep.assert_any_call(retry_._initial)
diff --git a/tests/unit/test_timeout.py b/tests/unit/test_timeout.py
new file mode 100644
index 0000000..30d624e
--- /dev/null
+++ b/tests/unit/test_timeout.py
@@ -0,0 +1,129 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import itertools
+
+import mock
+
+from google.api_core import timeout
+
+
+def test__exponential_timeout_generator_base_2():
+ gen = timeout._exponential_timeout_generator(1.0, 60.0, 2.0, deadline=None)
+
+ result = list(itertools.islice(gen, 8))
+ assert result == [1, 2, 4, 8, 16, 32, 60, 60]
+
+
+@mock.patch("google.api_core.datetime_helpers.utcnow", autospec=True)
+def test__exponential_timeout_generator_base_deadline(utcnow):
+ # Make each successive call to utcnow() advance one second.
+ utcnow.side_effect = [
+ datetime.datetime.min + datetime.timedelta(seconds=n) for n in range(15)
+ ]
+
+ gen = timeout._exponential_timeout_generator(1.0, 60.0, 2.0, deadline=30.0)
+
+ result = list(itertools.islice(gen, 14))
+ # Should grow until the cumulative time is > 30s, then start decreasing as
+ # the cumulative time approaches 60s.
+ assert result == [1, 2, 4, 8, 16, 24, 23, 22, 21, 20, 19, 18, 17, 16]
+
+
+class TestConstantTimeout(object):
+ def test_constructor(self):
+ timeout_ = timeout.ConstantTimeout()
+ assert timeout_._timeout is None
+
+ def test_constructor_args(self):
+ timeout_ = timeout.ConstantTimeout(42.0)
+ assert timeout_._timeout == 42.0
+
+ def test___str__(self):
+ timeout_ = timeout.ConstantTimeout(1)
+ assert str(timeout_) == "<ConstantTimeout timeout=1.0>"
+
+ def test_apply(self):
+ target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
+ timeout_ = timeout.ConstantTimeout(42.0)
+ wrapped = timeout_(target)
+
+ wrapped()
+
+ target.assert_called_once_with(timeout=42.0)
+
+ def test_apply_passthrough(self):
+ target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
+ timeout_ = timeout.ConstantTimeout(42.0)
+ wrapped = timeout_(target)
+
+ wrapped(1, 2, meep="moop")
+
+ target.assert_called_once_with(1, 2, meep="moop", timeout=42.0)
+
+
+class TestExponentialTimeout(object):
+ def test_constructor(self):
+ timeout_ = timeout.ExponentialTimeout()
+ assert timeout_._initial == timeout._DEFAULT_INITIAL_TIMEOUT
+ assert timeout_._maximum == timeout._DEFAULT_MAXIMUM_TIMEOUT
+ assert timeout_._multiplier == timeout._DEFAULT_TIMEOUT_MULTIPLIER
+ assert timeout_._deadline == timeout._DEFAULT_DEADLINE
+
+ def test_constructor_args(self):
+ timeout_ = timeout.ExponentialTimeout(1, 2, 3, 4)
+ assert timeout_._initial == 1
+ assert timeout_._maximum == 2
+ assert timeout_._multiplier == 3
+ assert timeout_._deadline == 4
+
+ def test_with_timeout(self):
+ original_timeout = timeout.ExponentialTimeout()
+ timeout_ = original_timeout.with_deadline(42)
+ assert original_timeout is not timeout_
+ assert timeout_._initial == timeout._DEFAULT_INITIAL_TIMEOUT
+ assert timeout_._maximum == timeout._DEFAULT_MAXIMUM_TIMEOUT
+ assert timeout_._multiplier == timeout._DEFAULT_TIMEOUT_MULTIPLIER
+ assert timeout_._deadline == 42
+
+ def test___str__(self):
+ timeout_ = timeout.ExponentialTimeout(1, 2, 3, 4)
+ assert str(timeout_) == (
+ "<ExponentialTimeout initial=1.0, maximum=2.0, multiplier=3.0, "
+ "deadline=4.0>"
+ )
+
+ def test_apply(self):
+ target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
+ timeout_ = timeout.ExponentialTimeout(1, 10, 2)
+ wrapped = timeout_(target)
+
+ wrapped()
+ target.assert_called_with(timeout=1)
+
+ wrapped()
+ target.assert_called_with(timeout=2)
+
+ wrapped()
+ target.assert_called_with(timeout=4)
+
+ def test_apply_passthrough(self):
+ target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
+ timeout_ = timeout.ExponentialTimeout(42.0, 100, 2)
+ wrapped = timeout_(target)
+
+ wrapped(1, 2, meep="moop")
+
+ target.assert_called_once_with(1, 2, meep="moop", timeout=42.0)