diff options
Diffstat (limited to 'tests/unit')
25 files changed, 7368 insertions, 0 deletions
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/unit/__init__.py diff --git a/tests/unit/future/__init__.py b/tests/unit/future/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/unit/future/__init__.py diff --git a/tests/unit/future/test__helpers.py b/tests/unit/future/test__helpers.py new file mode 100644 index 0000000..98afc59 --- /dev/null +++ b/tests/unit/future/test__helpers.py @@ -0,0 +1,37 @@ +# Copyright 2017, Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock + +from google.api_core.future import _helpers + + +@mock.patch("threading.Thread", autospec=True) +def test_start_deamon_thread(unused_thread): + deamon_thread = _helpers.start_daemon_thread(target=mock.sentinel.target) + assert deamon_thread.daemon is True + + +def test_safe_invoke_callback(): + callback = mock.Mock(spec=["__call__"], return_value=42) + result = _helpers.safe_invoke_callback(callback, "a", b="c") + assert result == 42 + callback.assert_called_once_with("a", b="c") + + +def test_safe_invoke_callback_exception(): + callback = mock.Mock(spec=["__call__"], side_effect=ValueError()) + result = _helpers.safe_invoke_callback(callback, "a", b="c") + assert result is None + callback.assert_called_once_with("a", b="c") diff --git a/tests/unit/future/test_polling.py b/tests/unit/future/test_polling.py new file mode 100644 index 0000000..2381d03 --- /dev/null +++ b/tests/unit/future/test_polling.py @@ -0,0 +1,242 @@ +# Copyright 2017, Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import concurrent.futures +import threading +import time + +import mock +import pytest + +from google.api_core import exceptions, retry +from google.api_core.future import polling + + +class PollingFutureImpl(polling.PollingFuture): + def done(self): + return False + + def cancel(self): + return True + + def cancelled(self): + return False + + def running(self): + return True + + +def test_polling_future_constructor(): + future = PollingFutureImpl() + assert not future.done() + assert not future.cancelled() + assert future.running() + assert future.cancel() + with mock.patch.object(future, "done", return_value=True): + future.result() + + +def test_set_result(): + future = PollingFutureImpl() + callback = mock.Mock() + + future.set_result(1) + + assert future.result() == 1 + future.add_done_callback(callback) + callback.assert_called_once_with(future) + + +def test_set_exception(): + future = PollingFutureImpl() + exception = ValueError("meep") + + future.set_exception(exception) + + assert future.exception() == exception + with pytest.raises(ValueError): + future.result() + + callback = mock.Mock() + future.add_done_callback(callback) + callback.assert_called_once_with(future) + + +def test_invoke_callback_exception(): + future = PollingFutureImplWithPoll() + future.set_result(42) + + # This should not raise, despite the callback causing an exception. + callback = mock.Mock(side_effect=ValueError) + future.add_done_callback(callback) + callback.assert_called_once_with(future) + + +class PollingFutureImplWithPoll(PollingFutureImpl): + def __init__(self): + super(PollingFutureImplWithPoll, self).__init__() + self.poll_count = 0 + self.event = threading.Event() + + def done(self, retry=polling.DEFAULT_RETRY): + self.poll_count += 1 + self.event.wait() + self.set_result(42) + return True + + +def test_result_with_polling(): + future = PollingFutureImplWithPoll() + + future.event.set() + result = future.result() + + assert result == 42 + assert future.poll_count == 1 + # Repeated calls should not cause additional polling + assert future.result() == result + assert future.poll_count == 1 + + +class PollingFutureImplTimeout(PollingFutureImplWithPoll): + def done(self, retry=polling.DEFAULT_RETRY): + time.sleep(1) + return False + + +def test_result_timeout(): + future = PollingFutureImplTimeout() + with pytest.raises(concurrent.futures.TimeoutError): + future.result(timeout=1) + + +def test_exception_timeout(): + future = PollingFutureImplTimeout() + with pytest.raises(concurrent.futures.TimeoutError): + future.exception(timeout=1) + + +class PollingFutureImplTransient(PollingFutureImplWithPoll): + def __init__(self, errors): + super(PollingFutureImplTransient, self).__init__() + self._errors = errors + + def done(self, retry=polling.DEFAULT_RETRY): + if self._errors: + error, self._errors = self._errors[0], self._errors[1:] + raise error("testing") + self.poll_count += 1 + self.set_result(42) + return True + + +def test_result_transient_error(): + future = PollingFutureImplTransient( + ( + exceptions.TooManyRequests, + exceptions.InternalServerError, + exceptions.BadGateway, + ) + ) + result = future.result() + assert result == 42 + assert future.poll_count == 1 + # Repeated calls should not cause additional polling + assert future.result() == result + assert future.poll_count == 1 + + +def test_callback_background_thread(): + future = PollingFutureImplWithPoll() + callback = mock.Mock() + + future.add_done_callback(callback) + + assert future._polling_thread is not None + + # Give the thread a second to poll + time.sleep(1) + assert future.poll_count == 1 + + future.event.set() + future._polling_thread.join() + + callback.assert_called_once_with(future) + + +def test_double_callback_background_thread(): + future = PollingFutureImplWithPoll() + callback = mock.Mock() + callback2 = mock.Mock() + + future.add_done_callback(callback) + current_thread = future._polling_thread + assert current_thread is not None + + # only one polling thread should be created. + future.add_done_callback(callback2) + assert future._polling_thread is current_thread + + future.event.set() + future._polling_thread.join() + + assert future.poll_count == 1 + callback.assert_called_once_with(future) + callback2.assert_called_once_with(future) + + +class PollingFutureImplWithoutRetry(PollingFutureImpl): + def done(self): + return True + + def result(self): + return super(PollingFutureImplWithoutRetry, self).result() + + def _blocking_poll(self, timeout): + return super(PollingFutureImplWithoutRetry, self)._blocking_poll( + timeout=timeout + ) + + +class PollingFutureImplWith_done_or_raise(PollingFutureImpl): + def done(self): + return True + + def _done_or_raise(self): + return super(PollingFutureImplWith_done_or_raise, self)._done_or_raise() + + +def test_polling_future_without_retry(): + custom_retry = retry.Retry( + predicate=retry.if_exception_type(exceptions.TooManyRequests) + ) + future = PollingFutureImplWithoutRetry() + assert future.done() + assert future.running() + assert future.result() is None + + with mock.patch.object(future, "done") as done_mock: + future._done_or_raise() + done_mock.assert_called_once_with() + + with mock.patch.object(future, "done") as done_mock: + future._done_or_raise(retry=custom_retry) + done_mock.assert_called_once_with(retry=custom_retry) + + +def test_polling_future_with__done_or_raise(): + future = PollingFutureImplWith_done_or_raise() + assert future.done() + assert future.running() + assert future.result() is None diff --git a/tests/unit/gapic/test_client_info.py b/tests/unit/gapic/test_client_info.py new file mode 100644 index 0000000..2ca5c40 --- /dev/null +++ b/tests/unit/gapic/test_client_info.py @@ -0,0 +1,31 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + + +from google.api_core.gapic_v1 import client_info + + +def test_to_grpc_metadata(): + info = client_info.ClientInfo() + + metadata = info.to_grpc_metadata() + + assert metadata == (client_info.METRICS_METADATA_KEY, info.to_user_agent()) diff --git a/tests/unit/gapic/test_config.py b/tests/unit/gapic/test_config.py new file mode 100644 index 0000000..5e42fde --- /dev/null +++ b/tests/unit/gapic/test_config.py @@ -0,0 +1,94 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + +from google.api_core import exceptions +from google.api_core.gapic_v1 import config + + +INTERFACE_CONFIG = { + "retry_codes": { + "idempotent": ["DEADLINE_EXCEEDED", "UNAVAILABLE"], + "other": ["FAILED_PRECONDITION"], + "non_idempotent": [], + }, + "retry_params": { + "default": { + "initial_retry_delay_millis": 1000, + "retry_delay_multiplier": 2.5, + "max_retry_delay_millis": 120000, + "initial_rpc_timeout_millis": 120000, + "rpc_timeout_multiplier": 1.0, + "max_rpc_timeout_millis": 120000, + "total_timeout_millis": 600000, + }, + "other": { + "initial_retry_delay_millis": 1000, + "retry_delay_multiplier": 1, + "max_retry_delay_millis": 1000, + "initial_rpc_timeout_millis": 1000, + "rpc_timeout_multiplier": 1, + "max_rpc_timeout_millis": 1000, + "total_timeout_millis": 1000, + }, + }, + "methods": { + "AnnotateVideo": { + "timeout_millis": 60000, + "retry_codes_name": "idempotent", + "retry_params_name": "default", + }, + "Other": { + "timeout_millis": 60000, + "retry_codes_name": "other", + "retry_params_name": "other", + }, + "Plain": {"timeout_millis": 30000}, + }, +} + + +def test_create_method_configs(): + method_configs = config.parse_method_configs(INTERFACE_CONFIG) + + retry, timeout = method_configs["AnnotateVideo"] + assert retry._predicate(exceptions.DeadlineExceeded(None)) + assert retry._predicate(exceptions.ServiceUnavailable(None)) + assert retry._initial == 1.0 + assert retry._multiplier == 2.5 + assert retry._maximum == 120.0 + assert retry._deadline == 600.0 + assert timeout._initial == 120.0 + assert timeout._multiplier == 1.0 + assert timeout._maximum == 120.0 + + retry, timeout = method_configs["Other"] + assert retry._predicate(exceptions.FailedPrecondition(None)) + assert retry._initial == 1.0 + assert retry._multiplier == 1.0 + assert retry._maximum == 1.0 + assert retry._deadline == 1.0 + assert timeout._initial == 1.0 + assert timeout._multiplier == 1.0 + assert timeout._maximum == 1.0 + + retry, timeout = method_configs["Plain"] + assert retry is None + assert timeout._timeout == 30.0 diff --git a/tests/unit/gapic/test_method.py b/tests/unit/gapic/test_method.py new file mode 100644 index 0000000..9778d23 --- /dev/null +++ b/tests/unit/gapic/test_method.py @@ -0,0 +1,244 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import mock +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + + +from google.api_core import exceptions +from google.api_core import retry +from google.api_core import timeout +import google.api_core.gapic_v1.client_info +import google.api_core.gapic_v1.method +import google.api_core.page_iterator + + +def _utcnow_monotonic(): + curr_value = datetime.datetime.min + delta = datetime.timedelta(seconds=0.5) + while True: + yield curr_value + curr_value += delta + + +def test__determine_timeout(): + # Check _determine_timeout always returns a Timeout object. + timeout_type_timeout = timeout.ConstantTimeout(600.0) + returned_timeout = google.api_core.gapic_v1.method._determine_timeout( + 600.0, 600.0, None + ) + assert isinstance(returned_timeout, timeout.ConstantTimeout) + returned_timeout = google.api_core.gapic_v1.method._determine_timeout( + 600.0, timeout_type_timeout, None + ) + assert isinstance(returned_timeout, timeout.ConstantTimeout) + returned_timeout = google.api_core.gapic_v1.method._determine_timeout( + timeout_type_timeout, 600.0, None + ) + assert isinstance(returned_timeout, timeout.ConstantTimeout) + returned_timeout = google.api_core.gapic_v1.method._determine_timeout( + timeout_type_timeout, timeout_type_timeout, None + ) + assert isinstance(returned_timeout, timeout.ConstantTimeout) + + +def test_wrap_method_basic(): + method = mock.Mock(spec=["__call__"], return_value=42) + + wrapped_method = google.api_core.gapic_v1.method.wrap_method(method) + + result = wrapped_method(1, 2, meep="moop") + + assert result == 42 + method.assert_called_once_with(1, 2, meep="moop", metadata=mock.ANY) + + # Check that the default client info was specified in the metadata. + metadata = method.call_args[1]["metadata"] + assert len(metadata) == 1 + client_info = google.api_core.gapic_v1.client_info.DEFAULT_CLIENT_INFO + user_agent_metadata = client_info.to_grpc_metadata() + assert user_agent_metadata in metadata + + +def test_wrap_method_with_no_client_info(): + method = mock.Mock(spec=["__call__"]) + + wrapped_method = google.api_core.gapic_v1.method.wrap_method( + method, client_info=None + ) + + wrapped_method(1, 2, meep="moop") + + method.assert_called_once_with(1, 2, meep="moop") + + +def test_wrap_method_with_custom_client_info(): + client_info = google.api_core.gapic_v1.client_info.ClientInfo( + python_version=1, + grpc_version=2, + api_core_version=3, + gapic_version=4, + client_library_version=5, + ) + method = mock.Mock(spec=["__call__"]) + + wrapped_method = google.api_core.gapic_v1.method.wrap_method( + method, client_info=client_info + ) + + wrapped_method(1, 2, meep="moop") + + method.assert_called_once_with(1, 2, meep="moop", metadata=mock.ANY) + + # Check that the custom client info was specified in the metadata. + metadata = method.call_args[1]["metadata"] + assert client_info.to_grpc_metadata() in metadata + + +def test_invoke_wrapped_method_with_metadata(): + method = mock.Mock(spec=["__call__"]) + + wrapped_method = google.api_core.gapic_v1.method.wrap_method(method) + + wrapped_method(mock.sentinel.request, metadata=[("a", "b")]) + + method.assert_called_once_with(mock.sentinel.request, metadata=mock.ANY) + metadata = method.call_args[1]["metadata"] + # Metadata should have two items: the client info metadata and our custom + # metadata. + assert len(metadata) == 2 + assert ("a", "b") in metadata + + +def test_invoke_wrapped_method_with_metadata_as_none(): + method = mock.Mock(spec=["__call__"]) + + wrapped_method = google.api_core.gapic_v1.method.wrap_method(method) + + wrapped_method(mock.sentinel.request, metadata=None) + + method.assert_called_once_with(mock.sentinel.request, metadata=mock.ANY) + metadata = method.call_args[1]["metadata"] + # Metadata should have just one items: the client info metadata. + assert len(metadata) == 1 + + +@mock.patch("time.sleep") +def test_wrap_method_with_default_retry_and_timeout(unusued_sleep): + method = mock.Mock( + spec=["__call__"], side_effect=[exceptions.InternalServerError(None), 42] + ) + default_retry = retry.Retry() + default_timeout = timeout.ConstantTimeout(60) + wrapped_method = google.api_core.gapic_v1.method.wrap_method( + method, default_retry, default_timeout + ) + + result = wrapped_method() + + assert result == 42 + assert method.call_count == 2 + method.assert_called_with(timeout=60, metadata=mock.ANY) + + +@mock.patch("time.sleep") +def test_wrap_method_with_default_retry_and_timeout_using_sentinel(unusued_sleep): + method = mock.Mock( + spec=["__call__"], side_effect=[exceptions.InternalServerError(None), 42] + ) + default_retry = retry.Retry() + default_timeout = timeout.ConstantTimeout(60) + wrapped_method = google.api_core.gapic_v1.method.wrap_method( + method, default_retry, default_timeout + ) + + result = wrapped_method( + retry=google.api_core.gapic_v1.method.DEFAULT, + timeout=google.api_core.gapic_v1.method.DEFAULT, + ) + + assert result == 42 + assert method.call_count == 2 + method.assert_called_with(timeout=60, metadata=mock.ANY) + + +@mock.patch("time.sleep") +def test_wrap_method_with_overriding_retry_and_timeout(unusued_sleep): + method = mock.Mock(spec=["__call__"], side_effect=[exceptions.NotFound(None), 42]) + default_retry = retry.Retry() + default_timeout = timeout.ConstantTimeout(60) + wrapped_method = google.api_core.gapic_v1.method.wrap_method( + method, default_retry, default_timeout + ) + + result = wrapped_method( + retry=retry.Retry(retry.if_exception_type(exceptions.NotFound)), + timeout=timeout.ConstantTimeout(22), + ) + + assert result == 42 + assert method.call_count == 2 + method.assert_called_with(timeout=22, metadata=mock.ANY) + + +@mock.patch("time.sleep") +@mock.patch( + "google.api_core.datetime_helpers.utcnow", + side_effect=_utcnow_monotonic(), + autospec=True, +) +def test_wrap_method_with_overriding_retry_deadline(utcnow, unused_sleep): + method = mock.Mock( + spec=["__call__"], + side_effect=([exceptions.InternalServerError(None)] * 4) + [42], + ) + default_retry = retry.Retry() + default_timeout = timeout.ExponentialTimeout(deadline=60) + wrapped_method = google.api_core.gapic_v1.method.wrap_method( + method, default_retry, default_timeout + ) + + # Overriding only the retry's deadline should also override the timeout's + # deadline. + result = wrapped_method(retry=default_retry.with_deadline(30)) + + assert result == 42 + timeout_args = [call[1]["timeout"] for call in method.call_args_list] + assert timeout_args == [5.0, 10.0, 20.0, 26.0, 25.0] + assert utcnow.call_count == ( + 1 + + 5 # First to set the deadline. + + 5 # One for each min(timeout, maximum, (DEADLINE - NOW).seconds) + ) + + +def test_wrap_method_with_overriding_timeout_as_a_number(): + method = mock.Mock(spec=["__call__"], return_value=42) + default_retry = retry.Retry() + default_timeout = timeout.ConstantTimeout(60) + wrapped_method = google.api_core.gapic_v1.method.wrap_method( + method, default_retry, default_timeout + ) + + result = wrapped_method(timeout=22) + + assert result == 42 + method.assert_called_once_with(timeout=22, metadata=mock.ANY) diff --git a/tests/unit/gapic/test_routing_header.py b/tests/unit/gapic/test_routing_header.py new file mode 100644 index 0000000..3037867 --- /dev/null +++ b/tests/unit/gapic/test_routing_header.py @@ -0,0 +1,41 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + + +from google.api_core.gapic_v1 import routing_header + + +def test_to_routing_header(): + params = [("name", "meep"), ("book.read", "1")] + value = routing_header.to_routing_header(params) + assert value == "name=meep&book.read=1" + + +def test_to_routing_header_with_slashes(): + params = [("name", "me/ep"), ("book.read", "1&2")] + value = routing_header.to_routing_header(params) + assert value == "name=me/ep&book.read=1%262" + + +def test_to_grpc_metadata(): + params = [("name", "meep"), ("book.read", "1")] + metadata = routing_header.to_grpc_metadata(params) + assert metadata == (routing_header.ROUTING_METADATA_KEY, "name=meep&book.read=1") diff --git a/tests/unit/operations_v1/__init__.py b/tests/unit/operations_v1/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/unit/operations_v1/__init__.py diff --git a/tests/unit/operations_v1/test_operations_client.py b/tests/unit/operations_v1/test_operations_client.py new file mode 100644 index 0000000..187f0be --- /dev/null +++ b/tests/unit/operations_v1/test_operations_client.py @@ -0,0 +1,98 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + +from google.api_core import grpc_helpers +from google.api_core import operations_v1 +from google.api_core import page_iterator +from google.longrunning import operations_pb2 +from google.protobuf import empty_pb2 + + +def test_get_operation(): + channel = grpc_helpers.ChannelStub() + client = operations_v1.OperationsClient(channel) + channel.GetOperation.response = operations_pb2.Operation(name="meep") + + response = client.get_operation("name", metadata=[("header", "foo")]) + + assert ("header", "foo") in channel.GetOperation.calls[0].metadata + assert ("x-goog-request-params", "name=name") in channel.GetOperation.calls[ + 0 + ].metadata + assert len(channel.GetOperation.requests) == 1 + assert channel.GetOperation.requests[0].name == "name" + assert response == channel.GetOperation.response + + +def test_list_operations(): + channel = grpc_helpers.ChannelStub() + client = operations_v1.OperationsClient(channel) + operations = [ + operations_pb2.Operation(name="1"), + operations_pb2.Operation(name="2"), + ] + list_response = operations_pb2.ListOperationsResponse(operations=operations) + channel.ListOperations.response = list_response + + response = client.list_operations("name", "filter", metadata=[("header", "foo")]) + + assert isinstance(response, page_iterator.Iterator) + assert list(response) == operations + + assert ("header", "foo") in channel.ListOperations.calls[0].metadata + assert ("x-goog-request-params", "name=name") in channel.ListOperations.calls[ + 0 + ].metadata + assert len(channel.ListOperations.requests) == 1 + request = channel.ListOperations.requests[0] + assert isinstance(request, operations_pb2.ListOperationsRequest) + assert request.name == "name" + assert request.filter == "filter" + + +def test_delete_operation(): + channel = grpc_helpers.ChannelStub() + client = operations_v1.OperationsClient(channel) + channel.DeleteOperation.response = empty_pb2.Empty() + + client.delete_operation("name", metadata=[("header", "foo")]) + + assert ("header", "foo") in channel.DeleteOperation.calls[0].metadata + assert ("x-goog-request-params", "name=name") in channel.DeleteOperation.calls[ + 0 + ].metadata + assert len(channel.DeleteOperation.requests) == 1 + assert channel.DeleteOperation.requests[0].name == "name" + + +def test_cancel_operation(): + channel = grpc_helpers.ChannelStub() + client = operations_v1.OperationsClient(channel) + channel.CancelOperation.response = empty_pb2.Empty() + + client.cancel_operation("name", metadata=[("header", "foo")]) + + assert ("header", "foo") in channel.CancelOperation.calls[0].metadata + assert ("x-goog-request-params", "name=name") in channel.CancelOperation.calls[ + 0 + ].metadata + assert len(channel.CancelOperation.requests) == 1 + assert channel.CancelOperation.requests[0].name == "name" diff --git a/tests/unit/operations_v1/test_operations_rest_client.py b/tests/unit/operations_v1/test_operations_rest_client.py new file mode 100644 index 0000000..dddf6b7 --- /dev/null +++ b/tests/unit/operations_v1/test_operations_rest_client.py @@ -0,0 +1,944 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +import mock +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) +from requests import Response # noqa I201 +from requests.sessions import Session + +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core.operations_v1 import AbstractOperationsClient +from google.api_core.operations_v1 import pagers +from google.api_core.operations_v1 import transports +import google.auth +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import json_format # type: ignore +from google.rpc import status_pb2 # type: ignore + + +HTTP_OPTIONS = { + "google.longrunning.Operations.CancelOperation": [ + {"method": "post", "uri": "/v3/{name=operations/*}:cancel", "body": "*"}, + ], + "google.longrunning.Operations.DeleteOperation": [ + {"method": "delete", "uri": "/v3/{name=operations/*}"}, + ], + "google.longrunning.Operations.GetOperation": [ + {"method": "get", "uri": "/v3/{name=operations/*}"}, + ], + "google.longrunning.Operations.ListOperations": [ + {"method": "get", "uri": "/v3/{name=operations}"}, + ], +} + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +def _get_operations_client(http_options=HTTP_OPTIONS): + transport = transports.rest.OperationsRestTransport( + credentials=ga_credentials.AnonymousCredentials(), http_options=http_options + ) + + return AbstractOperationsClient(transport=transport) + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert AbstractOperationsClient._get_default_mtls_endpoint(None) is None + assert ( + AbstractOperationsClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + AbstractOperationsClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + AbstractOperationsClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + AbstractOperationsClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + AbstractOperationsClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +@pytest.mark.parametrize("client_class", [AbstractOperationsClient]) +def test_operations_client_from_service_account_info(client_class): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "longrunning.googleapis.com:443" + + +@pytest.mark.parametrize( + "transport_class,transport_name", [(transports.OperationsRestTransport, "rest")] +) +def test_operations_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize("client_class", [AbstractOperationsClient]) +def test_operations_client_from_service_account_file(client_class): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "longrunning.googleapis.com:443" + + +def test_operations_client_get_transport_class(): + transport = AbstractOperationsClient.get_transport_class() + available_transports = [ + transports.OperationsRestTransport, + ] + assert transport in available_transports + + transport = AbstractOperationsClient.get_transport_class("rest") + assert transport == transports.OperationsRestTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [(AbstractOperationsClient, transports.OperationsRestTransport, "rest")], +) +@mock.patch.object( + AbstractOperationsClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(AbstractOperationsClient), +) +def test_operations_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(AbstractOperationsClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(AbstractOperationsClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (AbstractOperationsClient, transports.OperationsRestTransport, "rest", "true"), + (AbstractOperationsClient, transports.OperationsRestTransport, "rest", "false"), + ], +) +@mock.patch.object( + AbstractOperationsClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(AbstractOperationsClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_operations_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + + def fake_init(client_cert_source_for_mtls=None, **kwargs): + """Invoke client_cert source if provided.""" + + if client_cert_source_for_mtls: + client_cert_source_for_mtls() + return None + + with mock.patch.object(transport_class, "__init__") as patched: + patched.side_effect = fake_init + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [(AbstractOperationsClient, transports.OperationsRestTransport, "rest")], +) +def test_operations_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [(AbstractOperationsClient, transports.OperationsRestTransport, "rest")], +) +def test_operations_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +def test_list_operations_rest( + transport: str = "rest", request_type=operations_pb2.ListOperationsRequest +): + client = _get_operations_client() + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_operations( + name="operations", filter_="my_filter", page_size=10, page_token="abc" + ) + + actual_args = req.call_args + assert actual_args.args[0] == "GET" + assert ( + actual_args.args[1] + == "https://longrunning.googleapis.com:443/v3/operations" + ) + assert actual_args.kwargs["params"] == [ + ("filter", "my_filter"), + ("pageSize", 10), + ("pageToken", "abc"), + ] + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListOperationsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_operations_rest_failure(): + client = _get_operations_client(http_options=None) + + with mock.patch.object(Session, "request") as req: + response_value = Response() + response_value.status_code = 400 + mock_request = mock.MagicMock() + mock_request.method = "GET" + mock_request.url = "https://longrunning.googleapis.com:443/v1/operations" + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + client.list_operations(name="operations") + + +def test_list_operations_rest_pager(): + client = AbstractOperationsClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + operations_pb2.ListOperationsResponse( + operations=[ + operations_pb2.Operation(), + operations_pb2.Operation(), + operations_pb2.Operation(), + ], + next_page_token="abc", + ), + operations_pb2.ListOperationsResponse( + operations=[], next_page_token="def", + ), + operations_pb2.ListOperationsResponse( + operations=[operations_pb2.Operation()], next_page_token="ghi", + ), + operations_pb2.ListOperationsResponse( + operations=[operations_pb2.Operation(), operations_pb2.Operation()], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(json_format.MessageToJson(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + pager = client.list_operations(name="operations") + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, operations_pb2.Operation) for i in results) + + pages = list(client.list_operations(name="operations").pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.next_page_token == token + + +def test_get_operation_rest( + transport: str = "rest", request_type=operations_pb2.GetOperationRequest +): + client = _get_operations_client() + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation( + name="operations/sample1", done=True, error=status_pb2.Status(code=411), + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_operation("operations/sample1") + + actual_args = req.call_args + assert actual_args.args[0] == "GET" + assert ( + actual_args.args[1] + == "https://longrunning.googleapis.com:443/v3/operations/sample1" + ) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + assert response.name == "operations/sample1" + assert response.done is True + + +def test_get_operation_rest_failure(): + client = _get_operations_client(http_options=None) + + with mock.patch.object(Session, "request") as req: + response_value = Response() + response_value.status_code = 400 + mock_request = mock.MagicMock() + mock_request.method = "GET" + mock_request.url = ( + "https://longrunning.googleapis.com:443/v1/operations/sample1" + ) + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + client.get_operation("operations/sample1") + + +def test_delete_operation_rest( + transport: str = "rest", request_type=operations_pb2.DeleteOperationRequest +): + client = _get_operations_client() + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + client.delete_operation(name="operations/sample1") + assert req.call_count == 1 + actual_args = req.call_args + assert actual_args.args[0] == "DELETE" + assert ( + actual_args.args[1] + == "https://longrunning.googleapis.com:443/v3/operations/sample1" + ) + + +def test_delete_operation_rest_failure(): + client = _get_operations_client(http_options=None) + + with mock.patch.object(Session, "request") as req: + response_value = Response() + response_value.status_code = 400 + mock_request = mock.MagicMock() + mock_request.method = "DELETE" + mock_request.url = ( + "https://longrunning.googleapis.com:443/v1/operations/sample1" + ) + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + client.delete_operation(name="operations/sample1") + + +def test_cancel_operation_rest(transport: str = "rest"): + client = _get_operations_client() + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + client.cancel_operation(name="operations/sample1") + assert req.call_count == 1 + actual_args = req.call_args + assert actual_args.args[0] == "POST" + assert ( + actual_args.args[1] + == "https://longrunning.googleapis.com:443/v3/operations/sample1:cancel" + ) + + +def test_cancel_operation_rest_failure(): + client = _get_operations_client(http_options=None) + + with mock.patch.object(Session, "request") as req: + response_value = Response() + response_value.status_code = 400 + mock_request = mock.MagicMock() + mock_request.method = "POST" + mock_request.url = ( + "https://longrunning.googleapis.com:443/v1/operations/sample1:cancel" + ) + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + client.cancel_operation(name="operations/sample1") + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.OperationsRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + AbstractOperationsClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.OperationsRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + AbstractOperationsClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.OperationsRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + AbstractOperationsClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.OperationsRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = AbstractOperationsClient(transport=transport) + assert client.transport is transport + + +@pytest.mark.parametrize("transport_class", [transports.OperationsRestTransport]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_operations_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transports.OperationsTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_operations_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.api_core.operations_v1.transports.OperationsTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.OperationsTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "list_operations", + "get_operation", + "delete_operation", + "cancel_operation", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + +def test_operations_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.api_core.operations_v1.transports.OperationsTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transports.OperationsTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_operations_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.api_core.operations_v1.transports.OperationsTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transports.OperationsTransport() + adc.assert_called_once() + + +def test_operations_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + AbstractOperationsClient() + adc.assert_called_once_with( + scopes=None, default_scopes=(), quota_project_id=None, + ) + + +def test_operations_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.OperationsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +def test_operations_host_no_port(): + client = AbstractOperationsClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="longrunning.googleapis.com" + ), + ) + assert client.transport._host == "longrunning.googleapis.com:443" + + +def test_operations_host_with_port(): + client = AbstractOperationsClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="longrunning.googleapis.com:8000" + ), + ) + assert client.transport._host == "longrunning.googleapis.com:8000" + + +def test_common_billing_account_path(): + billing_account = "squid" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = AbstractOperationsClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + } + path = AbstractOperationsClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = AbstractOperationsClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "whelk" + expected = "folders/{folder}".format(folder=folder,) + actual = AbstractOperationsClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + } + path = AbstractOperationsClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = AbstractOperationsClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "oyster" + expected = "organizations/{organization}".format(organization=organization,) + actual = AbstractOperationsClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + } + path = AbstractOperationsClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = AbstractOperationsClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + expected = "projects/{project}".format(project=project,) + actual = AbstractOperationsClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = AbstractOperationsClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = AbstractOperationsClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = AbstractOperationsClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = AbstractOperationsClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = AbstractOperationsClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.OperationsTransport, "_prep_wrapped_messages" + ) as prep: + AbstractOperationsClient( + credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.OperationsTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = AbstractOperationsClient.get_transport_class() + transport_class( + credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/test_bidi.py b/tests/unit/test_bidi.py new file mode 100644 index 0000000..7fb1620 --- /dev/null +++ b/tests/unit/test_bidi.py @@ -0,0 +1,869 @@ +# Copyright 2018, Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import logging +import queue +import threading + +import mock +import pytest + +try: + import grpc +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + +from google.api_core import bidi +from google.api_core import exceptions + + +class Test_RequestQueueGenerator(object): + def test_bounded_consume(self): + call = mock.create_autospec(grpc.Call, instance=True) + call.is_active.return_value = True + + def queue_generator(rpc): + yield mock.sentinel.A + yield queue.Empty() + yield mock.sentinel.B + rpc.is_active.return_value = False + yield mock.sentinel.C + + q = mock.create_autospec(queue.Queue, instance=True) + q.get.side_effect = queue_generator(call) + + generator = bidi._RequestQueueGenerator(q) + generator.call = call + + items = list(generator) + + assert items == [mock.sentinel.A, mock.sentinel.B] + + def test_yield_initial_and_exit(self): + q = mock.create_autospec(queue.Queue, instance=True) + q.get.side_effect = queue.Empty() + call = mock.create_autospec(grpc.Call, instance=True) + call.is_active.return_value = False + + generator = bidi._RequestQueueGenerator(q, initial_request=mock.sentinel.A) + generator.call = call + + items = list(generator) + + assert items == [mock.sentinel.A] + + def test_yield_initial_callable_and_exit(self): + q = mock.create_autospec(queue.Queue, instance=True) + q.get.side_effect = queue.Empty() + call = mock.create_autospec(grpc.Call, instance=True) + call.is_active.return_value = False + + generator = bidi._RequestQueueGenerator( + q, initial_request=lambda: mock.sentinel.A + ) + generator.call = call + + items = list(generator) + + assert items == [mock.sentinel.A] + + def test_exit_when_inactive_with_item(self): + q = mock.create_autospec(queue.Queue, instance=True) + q.get.side_effect = [mock.sentinel.A, queue.Empty()] + call = mock.create_autospec(grpc.Call, instance=True) + call.is_active.return_value = False + + generator = bidi._RequestQueueGenerator(q) + generator.call = call + + items = list(generator) + + assert items == [] + # Make sure it put the item back. + q.put.assert_called_once_with(mock.sentinel.A) + + def test_exit_when_inactive_empty(self): + q = mock.create_autospec(queue.Queue, instance=True) + q.get.side_effect = queue.Empty() + call = mock.create_autospec(grpc.Call, instance=True) + call.is_active.return_value = False + + generator = bidi._RequestQueueGenerator(q) + generator.call = call + + items = list(generator) + + assert items == [] + + def test_exit_with_stop(self): + q = mock.create_autospec(queue.Queue, instance=True) + q.get.side_effect = [None, queue.Empty()] + call = mock.create_autospec(grpc.Call, instance=True) + call.is_active.return_value = True + + generator = bidi._RequestQueueGenerator(q) + generator.call = call + + items = list(generator) + + assert items == [] + + +class Test_Throttle(object): + def test_repr(self): + delta = datetime.timedelta(seconds=4.5) + instance = bidi._Throttle(access_limit=42, time_window=delta) + assert repr(instance) == "_Throttle(access_limit=42, time_window={})".format( + repr(delta) + ) + + def test_raises_error_on_invalid_init_arguments(self): + with pytest.raises(ValueError) as exc_info: + bidi._Throttle(access_limit=10, time_window=datetime.timedelta(seconds=0.0)) + assert "time_window" in str(exc_info.value) + assert "must be a positive timedelta" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + bidi._Throttle(access_limit=0, time_window=datetime.timedelta(seconds=10)) + assert "access_limit" in str(exc_info.value) + assert "must be positive" in str(exc_info.value) + + def test_does_not_delay_entry_attempts_under_threshold(self): + throttle = bidi._Throttle( + access_limit=3, time_window=datetime.timedelta(seconds=1) + ) + entries = [] + + for _ in range(3): + with throttle as time_waited: + entry_info = { + "entered_at": datetime.datetime.now(), + "reported_wait": time_waited, + } + entries.append(entry_info) + + # check the reported wait times ... + assert all(entry["reported_wait"] == 0.0 for entry in entries) + + # .. and the actual wait times + delta = entries[1]["entered_at"] - entries[0]["entered_at"] + assert delta.total_seconds() < 0.1 + delta = entries[2]["entered_at"] - entries[1]["entered_at"] + assert delta.total_seconds() < 0.1 + + def test_delays_entry_attempts_above_threshold(self): + throttle = bidi._Throttle( + access_limit=3, time_window=datetime.timedelta(seconds=1) + ) + entries = [] + + for _ in range(6): + with throttle as time_waited: + entry_info = { + "entered_at": datetime.datetime.now(), + "reported_wait": time_waited, + } + entries.append(entry_info) + + # For each group of 4 consecutive entries the time difference between + # the first and the last entry must have been greater than time_window, + # because a maximum of 3 are allowed in each time_window. + for i, entry in enumerate(entries[3:], start=3): + first_entry = entries[i - 3] + delta = entry["entered_at"] - first_entry["entered_at"] + assert delta.total_seconds() > 1.0 + + # check the reported wait times + # (NOTE: not using assert all(...), b/c the coverage check would complain) + for i, entry in enumerate(entries): + if i != 3: + assert entry["reported_wait"] == 0.0 + + # The delayed entry is expected to have been delayed for a significant + # chunk of the full second, and the actual and reported delay times + # should reflect that. + assert entries[3]["reported_wait"] > 0.7 + delta = entries[3]["entered_at"] - entries[2]["entered_at"] + assert delta.total_seconds() > 0.7 + + +class _CallAndFuture(grpc.Call, grpc.Future): + pass + + +def make_rpc(): + """Makes a mock RPC used to test Bidi classes.""" + call = mock.create_autospec(_CallAndFuture, instance=True) + rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True) + + def rpc_side_effect(request, metadata=None): + call.is_active.return_value = True + call.request = request + call.metadata = metadata + return call + + rpc.side_effect = rpc_side_effect + + def cancel_side_effect(): + call.is_active.return_value = False + + call.cancel.side_effect = cancel_side_effect + + return rpc, call + + +class ClosedCall(object): + def __init__(self, exception): + self.exception = exception + + def __next__(self): + raise self.exception + + def is_active(self): + return False + + +class TestBidiRpc(object): + def test_initial_state(self): + bidi_rpc = bidi.BidiRpc(None) + + assert bidi_rpc.is_active is False + + def test_done_callbacks(self): + bidi_rpc = bidi.BidiRpc(None) + callback = mock.Mock(spec=["__call__"]) + + bidi_rpc.add_done_callback(callback) + bidi_rpc._on_call_done(mock.sentinel.future) + + callback.assert_called_once_with(mock.sentinel.future) + + def test_metadata(self): + rpc, call = make_rpc() + bidi_rpc = bidi.BidiRpc(rpc, metadata=mock.sentinel.A) + assert bidi_rpc._rpc_metadata == mock.sentinel.A + + bidi_rpc.open() + assert bidi_rpc.call == call + assert bidi_rpc.call.metadata == mock.sentinel.A + + def test_open(self): + rpc, call = make_rpc() + bidi_rpc = bidi.BidiRpc(rpc) + + bidi_rpc.open() + + assert bidi_rpc.call == call + assert bidi_rpc.is_active + call.add_done_callback.assert_called_once_with(bidi_rpc._on_call_done) + + def test_open_error_already_open(self): + rpc, _ = make_rpc() + bidi_rpc = bidi.BidiRpc(rpc) + + bidi_rpc.open() + + with pytest.raises(ValueError): + bidi_rpc.open() + + def test_close(self): + rpc, call = make_rpc() + bidi_rpc = bidi.BidiRpc(rpc) + bidi_rpc.open() + + bidi_rpc.close() + + call.cancel.assert_called_once() + assert bidi_rpc.call == call + assert bidi_rpc.is_active is False + # ensure the request queue was signaled to stop. + assert bidi_rpc.pending_requests == 1 + assert bidi_rpc._request_queue.get() is None + + def test_close_no_rpc(self): + bidi_rpc = bidi.BidiRpc(None) + bidi_rpc.close() + + def test_send(self): + rpc, call = make_rpc() + bidi_rpc = bidi.BidiRpc(rpc) + bidi_rpc.open() + + bidi_rpc.send(mock.sentinel.request) + + assert bidi_rpc.pending_requests == 1 + assert bidi_rpc._request_queue.get() is mock.sentinel.request + + def test_send_not_open(self): + rpc, call = make_rpc() + bidi_rpc = bidi.BidiRpc(rpc) + + with pytest.raises(ValueError): + bidi_rpc.send(mock.sentinel.request) + + def test_send_dead_rpc(self): + error = ValueError() + bidi_rpc = bidi.BidiRpc(None) + bidi_rpc.call = ClosedCall(error) + + with pytest.raises(ValueError) as exc_info: + bidi_rpc.send(mock.sentinel.request) + + assert exc_info.value == error + + def test_recv(self): + bidi_rpc = bidi.BidiRpc(None) + bidi_rpc.call = iter([mock.sentinel.response]) + + response = bidi_rpc.recv() + + assert response == mock.sentinel.response + + def test_recv_not_open(self): + rpc, call = make_rpc() + bidi_rpc = bidi.BidiRpc(rpc) + + with pytest.raises(ValueError): + bidi_rpc.recv() + + +class CallStub(object): + def __init__(self, values, active=True): + self.values = iter(values) + self._is_active = active + self.cancelled = False + + def __next__(self): + item = next(self.values) + if isinstance(item, Exception): + self._is_active = False + raise item + return item + + def is_active(self): + return self._is_active + + def add_done_callback(self, callback): + pass + + def cancel(self): + self.cancelled = True + + +class TestResumableBidiRpc(object): + def test_ctor_defaults(self): + start_rpc = mock.Mock() + should_recover = mock.Mock() + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + + assert bidi_rpc.is_active is False + assert bidi_rpc._finalized is False + assert bidi_rpc._start_rpc is start_rpc + assert bidi_rpc._should_recover is should_recover + assert bidi_rpc._should_terminate is bidi._never_terminate + assert bidi_rpc._initial_request is None + assert bidi_rpc._rpc_metadata is None + assert bidi_rpc._reopen_throttle is None + + def test_ctor_explicit(self): + start_rpc = mock.Mock() + should_recover = mock.Mock() + should_terminate = mock.Mock() + initial_request = mock.Mock() + metadata = {"x-foo": "bar"} + bidi_rpc = bidi.ResumableBidiRpc( + start_rpc, + should_recover, + should_terminate=should_terminate, + initial_request=initial_request, + metadata=metadata, + throttle_reopen=True, + ) + + assert bidi_rpc.is_active is False + assert bidi_rpc._finalized is False + assert bidi_rpc._should_recover is should_recover + assert bidi_rpc._should_terminate is should_terminate + assert bidi_rpc._initial_request is initial_request + assert bidi_rpc._rpc_metadata == metadata + assert isinstance(bidi_rpc._reopen_throttle, bidi._Throttle) + + def test_done_callbacks_terminate(self): + cancellation = mock.Mock() + start_rpc = mock.Mock() + should_recover = mock.Mock(spec=["__call__"], return_value=True) + should_terminate = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc( + start_rpc, should_recover, should_terminate=should_terminate + ) + callback = mock.Mock(spec=["__call__"]) + + bidi_rpc.add_done_callback(callback) + bidi_rpc._on_call_done(cancellation) + + should_terminate.assert_called_once_with(cancellation) + should_recover.assert_not_called() + callback.assert_called_once_with(cancellation) + assert not bidi_rpc.is_active + + def test_done_callbacks_recoverable(self): + start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True) + should_recover = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + callback = mock.Mock(spec=["__call__"]) + + bidi_rpc.add_done_callback(callback) + bidi_rpc._on_call_done(mock.sentinel.future) + + callback.assert_not_called() + start_rpc.assert_called_once() + should_recover.assert_called_once_with(mock.sentinel.future) + assert bidi_rpc.is_active + + def test_done_callbacks_non_recoverable(self): + start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + callback = mock.Mock(spec=["__call__"]) + + bidi_rpc.add_done_callback(callback) + bidi_rpc._on_call_done(mock.sentinel.future) + + callback.assert_called_once_with(mock.sentinel.future) + should_recover.assert_called_once_with(mock.sentinel.future) + assert not bidi_rpc.is_active + + def test_send_terminate(self): + cancellation = ValueError() + call_1 = CallStub([cancellation], active=False) + call_2 = CallStub([]) + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2] + ) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + should_terminate = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc( + start_rpc, should_recover, should_terminate=should_terminate + ) + + bidi_rpc.open() + + bidi_rpc.send(mock.sentinel.request) + + assert bidi_rpc.pending_requests == 1 + assert bidi_rpc._request_queue.get() is None + + should_recover.assert_not_called() + should_terminate.assert_called_once_with(cancellation) + assert bidi_rpc.call == call_1 + assert bidi_rpc.is_active is False + assert call_1.cancelled is True + + def test_send_recover(self): + error = ValueError() + call_1 = CallStub([error], active=False) + call_2 = CallStub([]) + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2] + ) + should_recover = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + + bidi_rpc.open() + + bidi_rpc.send(mock.sentinel.request) + + assert bidi_rpc.pending_requests == 1 + assert bidi_rpc._request_queue.get() is mock.sentinel.request + + should_recover.assert_called_once_with(error) + assert bidi_rpc.call == call_2 + assert bidi_rpc.is_active is True + + def test_send_failure(self): + error = ValueError() + call = CallStub([error], active=False) + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, return_value=call + ) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + + bidi_rpc.open() + + with pytest.raises(ValueError) as exc_info: + bidi_rpc.send(mock.sentinel.request) + + assert exc_info.value == error + should_recover.assert_called_once_with(error) + assert bidi_rpc.call == call + assert bidi_rpc.is_active is False + assert call.cancelled is True + assert bidi_rpc.pending_requests == 1 + assert bidi_rpc._request_queue.get() is None + + def test_recv_terminate(self): + cancellation = ValueError() + call = CallStub([cancellation]) + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, return_value=call + ) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + should_terminate = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc( + start_rpc, should_recover, should_terminate=should_terminate + ) + + bidi_rpc.open() + + bidi_rpc.recv() + + should_recover.assert_not_called() + should_terminate.assert_called_once_with(cancellation) + assert bidi_rpc.call == call + assert bidi_rpc.is_active is False + assert call.cancelled is True + + def test_recv_recover(self): + error = ValueError() + call_1 = CallStub([1, error]) + call_2 = CallStub([2, 3]) + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2] + ) + should_recover = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + + bidi_rpc.open() + + values = [] + for n in range(3): + values.append(bidi_rpc.recv()) + + assert values == [1, 2, 3] + should_recover.assert_called_once_with(error) + assert bidi_rpc.call == call_2 + assert bidi_rpc.is_active is True + + def test_recv_recover_already_recovered(self): + call_1 = CallStub([]) + call_2 = CallStub([]) + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2] + ) + callback = mock.Mock() + callback.return_value = True + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, callback) + + bidi_rpc.open() + + bidi_rpc._reopen() + + assert bidi_rpc.call is call_1 + assert bidi_rpc.is_active is True + + def test_recv_failure(self): + error = ValueError() + call = CallStub([error]) + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, return_value=call + ) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + + bidi_rpc.open() + + with pytest.raises(ValueError) as exc_info: + bidi_rpc.recv() + + assert exc_info.value == error + should_recover.assert_called_once_with(error) + assert bidi_rpc.call == call + assert bidi_rpc.is_active is False + assert call.cancelled is True + + def test_close(self): + call = mock.create_autospec(_CallAndFuture, instance=True) + + def cancel_side_effect(): + call.is_active.return_value = False + + call.cancel.side_effect = cancel_side_effect + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, return_value=call + ) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + bidi_rpc.open() + + bidi_rpc.close() + + should_recover.assert_not_called() + call.cancel.assert_called_once() + assert bidi_rpc.call == call + assert bidi_rpc.is_active is False + # ensure the request queue was signaled to stop. + assert bidi_rpc.pending_requests == 1 + assert bidi_rpc._request_queue.get() is None + assert bidi_rpc._finalized + + def test_reopen_failure_on_rpc_restart(self): + error1 = ValueError("1") + error2 = ValueError("2") + call = CallStub([error1]) + # Invoking start RPC a second time will trigger an error. + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, side_effect=[call, error2] + ) + should_recover = mock.Mock(spec=["__call__"], return_value=True) + callback = mock.Mock(spec=["__call__"]) + + bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover) + bidi_rpc.add_done_callback(callback) + + bidi_rpc.open() + + with pytest.raises(ValueError) as exc_info: + bidi_rpc.recv() + + assert exc_info.value == error2 + should_recover.assert_called_once_with(error1) + assert bidi_rpc.call is None + assert bidi_rpc.is_active is False + callback.assert_called_once_with(error2) + + def test_using_throttle_on_reopen_requests(self): + call = CallStub([]) + start_rpc = mock.create_autospec( + grpc.StreamStreamMultiCallable, instance=True, return_value=call + ) + should_recover = mock.Mock(spec=["__call__"], return_value=True) + bidi_rpc = bidi.ResumableBidiRpc( + start_rpc, should_recover, throttle_reopen=True + ) + + patcher = mock.patch.object(bidi_rpc._reopen_throttle.__class__, "__enter__") + with patcher as mock_enter: + bidi_rpc._reopen() + + mock_enter.assert_called_once() + + def test_send_not_open(self): + bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False) + + with pytest.raises(ValueError): + bidi_rpc.send(mock.sentinel.request) + + def test_recv_not_open(self): + bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False) + + with pytest.raises(ValueError): + bidi_rpc.recv() + + def test_finalize_idempotent(self): + error1 = ValueError("1") + error2 = ValueError("2") + callback = mock.Mock(spec=["__call__"]) + should_recover = mock.Mock(spec=["__call__"], return_value=False) + + bidi_rpc = bidi.ResumableBidiRpc(mock.sentinel.start_rpc, should_recover) + + bidi_rpc.add_done_callback(callback) + + bidi_rpc._on_call_done(error1) + bidi_rpc._on_call_done(error2) + + callback.assert_called_once_with(error1) + + +class TestBackgroundConsumer(object): + def test_consume_once_then_exit(self): + bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True) + bidi_rpc.is_active = True + bidi_rpc.recv.side_effect = [mock.sentinel.response_1] + recved = threading.Event() + + def on_response(response): + assert response == mock.sentinel.response_1 + bidi_rpc.is_active = False + recved.set() + + consumer = bidi.BackgroundConsumer(bidi_rpc, on_response) + + consumer.start() + + recved.wait() + + bidi_rpc.recv.assert_called_once() + assert bidi_rpc.is_active is False + + consumer.stop() + + bidi_rpc.close.assert_called_once() + assert consumer.is_active is False + + def test_pause_resume_and_close(self): + # This test is relatively complex. It attempts to start the consumer, + # consume one item, pause the consumer, check the state of the world, + # then resume the consumer. Doing this in a deterministic fashion + # requires a bit more mocking and patching than usual. + + bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True) + bidi_rpc.is_active = True + + def close_side_effect(): + bidi_rpc.is_active = False + + bidi_rpc.close.side_effect = close_side_effect + + # These are used to coordinate the two threads to ensure deterministic + # execution. + should_continue = threading.Event() + responses_and_events = { + mock.sentinel.response_1: threading.Event(), + mock.sentinel.response_2: threading.Event(), + } + bidi_rpc.recv.side_effect = [mock.sentinel.response_1, mock.sentinel.response_2] + + recved_responses = [] + consumer = None + + def on_response(response): + if response == mock.sentinel.response_1: + consumer.pause() + + recved_responses.append(response) + responses_and_events[response].set() + should_continue.wait() + + consumer = bidi.BackgroundConsumer(bidi_rpc, on_response) + + consumer.start() + + # Wait for the first response to be recved. + responses_and_events[mock.sentinel.response_1].wait() + + # Ensure only one item has been recved and that the consumer is paused. + assert recved_responses == [mock.sentinel.response_1] + assert consumer.is_paused is True + assert consumer.is_active is True + + # Unpause the consumer, wait for the second item, then close the + # consumer. + should_continue.set() + consumer.resume() + + responses_and_events[mock.sentinel.response_2].wait() + + assert recved_responses == [mock.sentinel.response_1, mock.sentinel.response_2] + + consumer.stop() + + assert consumer.is_active is False + + def test_wake_on_error(self): + should_continue = threading.Event() + + bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True) + bidi_rpc.is_active = True + bidi_rpc.add_done_callback.side_effect = lambda _: should_continue.set() + + consumer = bidi.BackgroundConsumer(bidi_rpc, mock.sentinel.on_response) + + # Start the consumer paused, which should immediately put it into wait + # state. + consumer.pause() + consumer.start() + + # Wait for add_done_callback to be called + should_continue.wait() + bidi_rpc.add_done_callback.assert_called_once_with(consumer._on_call_done) + + # The consumer should now be blocked on waiting to be unpaused. + assert consumer.is_active + assert consumer.is_paused + + # Trigger the done callback, it should unpause the consumer and cause + # it to exit. + bidi_rpc.is_active = False + consumer._on_call_done(bidi_rpc) + + # It may take a few cycles for the thread to exit. + while consumer.is_active: + pass + + def test_consumer_expected_error(self, caplog): + caplog.set_level(logging.DEBUG) + + bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True) + bidi_rpc.is_active = True + bidi_rpc.recv.side_effect = exceptions.ServiceUnavailable("Gone away") + + on_response = mock.Mock(spec=["__call__"]) + + consumer = bidi.BackgroundConsumer(bidi_rpc, on_response) + + consumer.start() + + # Wait for the consumer's thread to exit. + while consumer.is_active: + pass + + on_response.assert_not_called() + bidi_rpc.recv.assert_called_once() + assert "caught error" in caplog.text + + def test_consumer_unexpected_error(self, caplog): + caplog.set_level(logging.DEBUG) + + bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True) + bidi_rpc.is_active = True + bidi_rpc.recv.side_effect = ValueError() + + on_response = mock.Mock(spec=["__call__"]) + + consumer = bidi.BackgroundConsumer(bidi_rpc, on_response) + + consumer.start() + + # Wait for the consumer's thread to exit. + while consumer.is_active: + pass # pragma: NO COVER (race condition) + + on_response.assert_not_called() + bidi_rpc.recv.assert_called_once() + assert "caught unexpected exception" in caplog.text + + def test_double_stop(self, caplog): + caplog.set_level(logging.DEBUG) + bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True) + bidi_rpc.is_active = True + on_response = mock.Mock(spec=["__call__"]) + + def close_side_effect(): + bidi_rpc.is_active = False + + bidi_rpc.close.side_effect = close_side_effect + + consumer = bidi.BackgroundConsumer(bidi_rpc, on_response) + + consumer.start() + assert consumer.is_active is True + + consumer.stop() + assert consumer.is_active is False + + # calling stop twice should not result in an error. + consumer.stop() diff --git a/tests/unit/test_client_info.py b/tests/unit/test_client_info.py new file mode 100644 index 0000000..f5eebfb --- /dev/null +++ b/tests/unit/test_client_info.py @@ -0,0 +1,98 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +try: + import grpc +except ImportError: + grpc = None + +from google.api_core import client_info + + +def test_constructor_defaults(): + info = client_info.ClientInfo() + + assert info.python_version is not None + + if grpc is not None: + assert info.grpc_version is not None + else: + assert info.grpc_version is None + + assert info.api_core_version is not None + assert info.gapic_version is None + assert info.client_library_version is None + assert info.rest_version is None + + +def test_constructor_options(): + info = client_info.ClientInfo( + python_version="1", + grpc_version="2", + api_core_version="3", + gapic_version="4", + client_library_version="5", + user_agent="6", + rest_version="7", + ) + + assert info.python_version == "1" + assert info.grpc_version == "2" + assert info.api_core_version == "3" + assert info.gapic_version == "4" + assert info.client_library_version == "5" + assert info.user_agent == "6" + assert info.rest_version == "7" + + +def test_to_user_agent_minimal(): + info = client_info.ClientInfo( + python_version="1", api_core_version="2", grpc_version=None + ) + + user_agent = info.to_user_agent() + + assert user_agent == "gl-python/1 gax/2" + + +def test_to_user_agent_full(): + info = client_info.ClientInfo( + python_version="1", + grpc_version="2", + api_core_version="3", + gapic_version="4", + client_library_version="5", + user_agent="app-name/1.0", + ) + + user_agent = info.to_user_agent() + + assert user_agent == "app-name/1.0 gl-python/1 grpc/2 gax/3 gapic/4 gccl/5" + + +def test_to_user_agent_rest(): + info = client_info.ClientInfo( + python_version="1", + grpc_version=None, + rest_version="2", + api_core_version="3", + gapic_version="4", + client_library_version="5", + user_agent="app-name/1.0", + ) + + user_agent = info.to_user_agent() + + assert user_agent == "app-name/1.0 gl-python/1 rest/2 gax/3 gapic/4 gccl/5" diff --git a/tests/unit/test_client_options.py b/tests/unit/test_client_options.py new file mode 100644 index 0000000..38b9ad0 --- /dev/null +++ b/tests/unit/test_client_options.py @@ -0,0 +1,117 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.api_core import client_options + + +def get_client_cert(): + return b"cert", b"key" + + +def get_client_encrypted_cert(): + return "cert_path", "key_path", b"passphrase" + + +def test_constructor(): + + options = client_options.ClientOptions( + api_endpoint="foo.googleapis.com", + client_cert_source=get_client_cert, + quota_project_id="quote-proj", + credentials_file="path/to/credentials.json", + scopes=[ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ], + ) + + assert options.api_endpoint == "foo.googleapis.com" + assert options.client_cert_source() == (b"cert", b"key") + assert options.quota_project_id == "quote-proj" + assert options.credentials_file == "path/to/credentials.json" + assert options.scopes == [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ] + + +def test_constructor_with_encrypted_cert_source(): + + options = client_options.ClientOptions( + api_endpoint="foo.googleapis.com", + client_encrypted_cert_source=get_client_encrypted_cert, + ) + + assert options.api_endpoint == "foo.googleapis.com" + assert options.client_encrypted_cert_source() == ( + "cert_path", + "key_path", + b"passphrase", + ) + + +def test_constructor_with_both_cert_sources(): + with pytest.raises(ValueError): + client_options.ClientOptions( + api_endpoint="foo.googleapis.com", + client_cert_source=get_client_cert, + client_encrypted_cert_source=get_client_encrypted_cert, + ) + + +def test_from_dict(): + options = client_options.from_dict( + { + "api_endpoint": "foo.googleapis.com", + "client_cert_source": get_client_cert, + "quota_project_id": "quote-proj", + "credentials_file": "path/to/credentials.json", + "scopes": [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ], + } + ) + + assert options.api_endpoint == "foo.googleapis.com" + assert options.client_cert_source() == (b"cert", b"key") + assert options.quota_project_id == "quote-proj" + assert options.credentials_file == "path/to/credentials.json" + assert options.scopes == [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ] + + +def test_from_dict_bad_argument(): + with pytest.raises(ValueError): + client_options.from_dict( + { + "api_endpoint": "foo.googleapis.com", + "bad_arg": "1234", + "client_cert_source": get_client_cert, + } + ) + + +def test_repr(): + options = client_options.ClientOptions(api_endpoint="foo.googleapis.com") + + assert ( + repr(options) + == "ClientOptions: {'api_endpoint': 'foo.googleapis.com', 'client_cert_source': None, 'client_encrypted_cert_source': None}" + or "ClientOptions: {'client_encrypted_cert_source': None, 'client_cert_source': None, 'api_endpoint': 'foo.googleapis.com'}" + ) diff --git a/tests/unit/test_datetime_helpers.py b/tests/unit/test_datetime_helpers.py new file mode 100644 index 0000000..5f5470a --- /dev/null +++ b/tests/unit/test_datetime_helpers.py @@ -0,0 +1,396 @@ +# Copyright 2017, Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import calendar +import datetime + +import pytest + +from google.api_core import datetime_helpers +from google.protobuf import timestamp_pb2 + + +ONE_MINUTE_IN_MICROSECONDS = 60 * 1e6 + + +def test_utcnow(): + result = datetime_helpers.utcnow() + assert isinstance(result, datetime.datetime) + + +def test_to_milliseconds(): + dt = datetime.datetime(1970, 1, 1, 0, 0, 1, tzinfo=datetime.timezone.utc) + assert datetime_helpers.to_milliseconds(dt) == 1000 + + +def test_to_microseconds(): + microseconds = 314159 + dt = datetime.datetime(1970, 1, 1, 0, 0, 0, microsecond=microseconds) + assert datetime_helpers.to_microseconds(dt) == microseconds + + +def test_to_microseconds_non_utc(): + zone = datetime.timezone(datetime.timedelta(minutes=-1)) + dt = datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=zone) + assert datetime_helpers.to_microseconds(dt) == ONE_MINUTE_IN_MICROSECONDS + + +def test_to_microseconds_naive(): + microseconds = 314159 + dt = datetime.datetime(1970, 1, 1, 0, 0, 0, microsecond=microseconds, tzinfo=None) + assert datetime_helpers.to_microseconds(dt) == microseconds + + +def test_from_microseconds(): + five_mins_from_epoch_in_microseconds = 5 * ONE_MINUTE_IN_MICROSECONDS + five_mins_from_epoch_datetime = datetime.datetime( + 1970, 1, 1, 0, 5, 0, tzinfo=datetime.timezone.utc + ) + + result = datetime_helpers.from_microseconds(five_mins_from_epoch_in_microseconds) + + assert result == five_mins_from_epoch_datetime + + +def test_from_iso8601_date(): + today = datetime.date.today() + iso_8601_today = today.strftime("%Y-%m-%d") + + assert datetime_helpers.from_iso8601_date(iso_8601_today) == today + + +def test_from_iso8601_time(): + assert datetime_helpers.from_iso8601_time("12:09:42") == datetime.time(12, 9, 42) + + +def test_from_rfc3339(): + value = "2009-12-17T12:44:32.123456Z" + assert datetime_helpers.from_rfc3339(value) == datetime.datetime( + 2009, 12, 17, 12, 44, 32, 123456, datetime.timezone.utc + ) + + +def test_from_rfc3339_nanos(): + value = "2009-12-17T12:44:32.123456Z" + assert datetime_helpers.from_rfc3339_nanos(value) == datetime.datetime( + 2009, 12, 17, 12, 44, 32, 123456, datetime.timezone.utc + ) + + +def test_from_rfc3339_without_nanos(): + value = "2009-12-17T12:44:32Z" + assert datetime_helpers.from_rfc3339(value) == datetime.datetime( + 2009, 12, 17, 12, 44, 32, 0, datetime.timezone.utc + ) + + +def test_from_rfc3339_nanos_without_nanos(): + value = "2009-12-17T12:44:32Z" + assert datetime_helpers.from_rfc3339_nanos(value) == datetime.datetime( + 2009, 12, 17, 12, 44, 32, 0, datetime.timezone.utc + ) + + +@pytest.mark.parametrize( + "truncated, micros", + [ + ("12345678", 123456), + ("1234567", 123456), + ("123456", 123456), + ("12345", 123450), + ("1234", 123400), + ("123", 123000), + ("12", 120000), + ("1", 100000), + ], +) +def test_from_rfc3339_with_truncated_nanos(truncated, micros): + value = "2009-12-17T12:44:32.{}Z".format(truncated) + assert datetime_helpers.from_rfc3339(value) == datetime.datetime( + 2009, 12, 17, 12, 44, 32, micros, datetime.timezone.utc + ) + + +def test_from_rfc3339_nanos_is_deprecated(): + value = "2009-12-17T12:44:32.123456Z" + + result = datetime_helpers.from_rfc3339(value) + result_nanos = datetime_helpers.from_rfc3339_nanos(value) + + assert result == result_nanos + + +@pytest.mark.parametrize( + "truncated, micros", + [ + ("12345678", 123456), + ("1234567", 123456), + ("123456", 123456), + ("12345", 123450), + ("1234", 123400), + ("123", 123000), + ("12", 120000), + ("1", 100000), + ], +) +def test_from_rfc3339_nanos_with_truncated_nanos(truncated, micros): + value = "2009-12-17T12:44:32.{}Z".format(truncated) + assert datetime_helpers.from_rfc3339_nanos(value) == datetime.datetime( + 2009, 12, 17, 12, 44, 32, micros, datetime.timezone.utc + ) + + +def test_from_rfc3339_wo_nanos_raise_exception(): + value = "2009-12-17T12:44:32" + with pytest.raises(ValueError): + datetime_helpers.from_rfc3339(value) + + +def test_from_rfc3339_w_nanos_raise_exception(): + value = "2009-12-17T12:44:32.123456" + with pytest.raises(ValueError): + datetime_helpers.from_rfc3339(value) + + +def test_to_rfc3339(): + value = datetime.datetime(2016, 4, 5, 13, 30, 0) + expected = "2016-04-05T13:30:00.000000Z" + assert datetime_helpers.to_rfc3339(value) == expected + + +def test_to_rfc3339_with_utc(): + value = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=datetime.timezone.utc) + expected = "2016-04-05T13:30:00.000000Z" + assert datetime_helpers.to_rfc3339(value, ignore_zone=False) == expected + + +def test_to_rfc3339_with_non_utc(): + zone = datetime.timezone(datetime.timedelta(minutes=-60)) + value = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=zone) + expected = "2016-04-05T14:30:00.000000Z" + assert datetime_helpers.to_rfc3339(value, ignore_zone=False) == expected + + +def test_to_rfc3339_with_non_utc_ignore_zone(): + zone = datetime.timezone(datetime.timedelta(minutes=-60)) + value = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=zone) + expected = "2016-04-05T13:30:00.000000Z" + assert datetime_helpers.to_rfc3339(value, ignore_zone=True) == expected + + +class Test_DateTimeWithNanos(object): + @staticmethod + def test_ctor_wo_nanos(): + stamp = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, 123456 + ) + assert stamp.year == 2016 + assert stamp.month == 12 + assert stamp.day == 20 + assert stamp.hour == 21 + assert stamp.minute == 13 + assert stamp.second == 47 + assert stamp.microsecond == 123456 + assert stamp.nanosecond == 0 + + @staticmethod + def test_ctor_w_nanos(): + stamp = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, nanosecond=123456789 + ) + assert stamp.year == 2016 + assert stamp.month == 12 + assert stamp.day == 20 + assert stamp.hour == 21 + assert stamp.minute == 13 + assert stamp.second == 47 + assert stamp.microsecond == 123456 + assert stamp.nanosecond == 123456789 + + @staticmethod + def test_ctor_w_micros_positional_and_nanos(): + with pytest.raises(TypeError): + datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, 123456, nanosecond=123456789 + ) + + @staticmethod + def test_ctor_w_micros_keyword_and_nanos(): + with pytest.raises(TypeError): + datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, microsecond=123456, nanosecond=123456789 + ) + + @staticmethod + def test_rfc3339_wo_nanos(): + stamp = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, 123456 + ) + assert stamp.rfc3339() == "2016-12-20T21:13:47.123456Z" + + @staticmethod + def test_rfc3339_wo_nanos_w_leading_zero(): + stamp = datetime_helpers.DatetimeWithNanoseconds(2016, 12, 20, 21, 13, 47, 1234) + assert stamp.rfc3339() == "2016-12-20T21:13:47.001234Z" + + @staticmethod + def test_rfc3339_w_nanos(): + stamp = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, nanosecond=123456789 + ) + assert stamp.rfc3339() == "2016-12-20T21:13:47.123456789Z" + + @staticmethod + def test_rfc3339_w_nanos_w_leading_zero(): + stamp = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, nanosecond=1234567 + ) + assert stamp.rfc3339() == "2016-12-20T21:13:47.001234567Z" + + @staticmethod + def test_rfc3339_w_nanos_no_trailing_zeroes(): + stamp = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, nanosecond=100000000 + ) + assert stamp.rfc3339() == "2016-12-20T21:13:47.1Z" + + @staticmethod + def test_rfc3339_w_nanos_w_leading_zero_and_no_trailing_zeros(): + stamp = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, nanosecond=1234500 + ) + assert stamp.rfc3339() == "2016-12-20T21:13:47.0012345Z" + + @staticmethod + def test_from_rfc3339_w_invalid(): + stamp = "2016-12-20T21:13:47" + with pytest.raises(ValueError): + datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(stamp) + + @staticmethod + def test_from_rfc3339_wo_fraction(): + timestamp = "2016-12-20T21:13:47Z" + expected = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, tzinfo=datetime.timezone.utc + ) + stamp = datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(timestamp) + assert stamp == expected + + @staticmethod + def test_from_rfc3339_w_partial_precision(): + timestamp = "2016-12-20T21:13:47.1Z" + expected = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, microsecond=100000, tzinfo=datetime.timezone.utc + ) + stamp = datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(timestamp) + assert stamp == expected + + @staticmethod + def test_from_rfc3339_w_full_precision(): + timestamp = "2016-12-20T21:13:47.123456789Z" + expected = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=datetime.timezone.utc + ) + stamp = datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(timestamp) + assert stamp == expected + + @staticmethod + @pytest.mark.parametrize( + "fractional, nanos", + [ + ("12345678", 123456780), + ("1234567", 123456700), + ("123456", 123456000), + ("12345", 123450000), + ("1234", 123400000), + ("123", 123000000), + ("12", 120000000), + ("1", 100000000), + ], + ) + def test_from_rfc3339_test_nanoseconds(fractional, nanos): + value = "2009-12-17T12:44:32.{}Z".format(fractional) + assert ( + datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(value).nanosecond + == nanos + ) + + @staticmethod + def test_timestamp_pb_wo_nanos_naive(): + stamp = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, 123456 + ) + delta = ( + stamp.replace(tzinfo=datetime.timezone.utc) - datetime_helpers._UTC_EPOCH + ) + seconds = int(delta.total_seconds()) + nanos = 123456000 + timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + assert stamp.timestamp_pb() == timestamp + + @staticmethod + def test_timestamp_pb_w_nanos(): + stamp = datetime_helpers.DatetimeWithNanoseconds( + 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=datetime.timezone.utc + ) + delta = stamp - datetime_helpers._UTC_EPOCH + timestamp = timestamp_pb2.Timestamp( + seconds=int(delta.total_seconds()), nanos=123456789 + ) + assert stamp.timestamp_pb() == timestamp + + @staticmethod + def test_from_timestamp_pb_wo_nanos(): + when = datetime.datetime( + 2016, 12, 20, 21, 13, 47, 123456, tzinfo=datetime.timezone.utc + ) + delta = when - datetime_helpers._UTC_EPOCH + seconds = int(delta.total_seconds()) + timestamp = timestamp_pb2.Timestamp(seconds=seconds) + + stamp = datetime_helpers.DatetimeWithNanoseconds.from_timestamp_pb(timestamp) + + assert _to_seconds(when) == _to_seconds(stamp) + assert stamp.microsecond == 0 + assert stamp.nanosecond == 0 + assert stamp.tzinfo == datetime.timezone.utc + + @staticmethod + def test_from_timestamp_pb_w_nanos(): + when = datetime.datetime( + 2016, 12, 20, 21, 13, 47, 123456, tzinfo=datetime.timezone.utc + ) + delta = when - datetime_helpers._UTC_EPOCH + seconds = int(delta.total_seconds()) + timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=123456789) + + stamp = datetime_helpers.DatetimeWithNanoseconds.from_timestamp_pb(timestamp) + + assert _to_seconds(when) == _to_seconds(stamp) + assert stamp.microsecond == 123456 + assert stamp.nanosecond == 123456789 + assert stamp.tzinfo == datetime.timezone.utc + + +def _to_seconds(value): + """Convert a datetime to seconds since the unix epoch. + + Args: + value (datetime.datetime): The datetime to covert. + + Returns: + int: Microseconds since the unix epoch. + """ + assert value.tzinfo is datetime.timezone.utc + return calendar.timegm(value.timetuple()) diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py new file mode 100644 index 0000000..622f58a --- /dev/null +++ b/tests/unit/test_exceptions.py @@ -0,0 +1,353 @@ +# Copyright 2014 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import http.client +import json + +import mock +import pytest +import requests + +try: + import grpc + from grpc_status import rpc_status +except ImportError: + grpc = rpc_status = None + +from google.api_core import exceptions +from google.protobuf import any_pb2, json_format +from google.rpc import error_details_pb2, status_pb2 + + +def test_create_google_cloud_error(): + exception = exceptions.GoogleAPICallError("Testing") + exception.code = 600 + assert str(exception) == "600 Testing" + assert exception.message == "Testing" + assert exception.errors == [] + assert exception.response is None + + +def test_create_google_cloud_error_with_args(): + error = { + "code": 600, + "message": "Testing", + } + response = mock.sentinel.response + exception = exceptions.GoogleAPICallError("Testing", [error], response=response) + exception.code = 600 + assert str(exception) == "600 Testing" + assert exception.message == "Testing" + assert exception.errors == [error] + assert exception.response == response + + +def test_from_http_status(): + message = "message" + exception = exceptions.from_http_status(http.client.NOT_FOUND, message) + assert exception.code == http.client.NOT_FOUND + assert exception.message == message + assert exception.errors == [] + + +def test_from_http_status_with_errors_and_response(): + message = "message" + errors = ["1", "2"] + response = mock.sentinel.response + exception = exceptions.from_http_status( + http.client.NOT_FOUND, message, errors=errors, response=response + ) + + assert isinstance(exception, exceptions.NotFound) + assert exception.code == http.client.NOT_FOUND + assert exception.message == message + assert exception.errors == errors + assert exception.response == response + + +def test_from_http_status_unknown_code(): + message = "message" + status_code = 156 + exception = exceptions.from_http_status(status_code, message) + assert exception.code == status_code + assert exception.message == message + + +def make_response(content): + response = requests.Response() + response._content = content + response.status_code = http.client.NOT_FOUND + response.request = requests.Request( + method="POST", url="https://example.com" + ).prepare() + return response + + +def test_from_http_response_no_content(): + response = make_response(None) + + exception = exceptions.from_http_response(response) + + assert isinstance(exception, exceptions.NotFound) + assert exception.code == http.client.NOT_FOUND + assert exception.message == "POST https://example.com/: unknown error" + assert exception.response == response + + +def test_from_http_response_text_content(): + response = make_response(b"message") + response.encoding = "UTF8" # suppress charset_normalizer warning + + exception = exceptions.from_http_response(response) + + assert isinstance(exception, exceptions.NotFound) + assert exception.code == http.client.NOT_FOUND + assert exception.message == "POST https://example.com/: message" + + +def test_from_http_response_json_content(): + response = make_response( + json.dumps({"error": {"message": "json message", "errors": ["1", "2"]}}).encode( + "utf-8" + ) + ) + + exception = exceptions.from_http_response(response) + + assert isinstance(exception, exceptions.NotFound) + assert exception.code == http.client.NOT_FOUND + assert exception.message == "POST https://example.com/: json message" + assert exception.errors == ["1", "2"] + + +def test_from_http_response_bad_json_content(): + response = make_response(json.dumps({"meep": "moop"}).encode("utf-8")) + + exception = exceptions.from_http_response(response) + + assert isinstance(exception, exceptions.NotFound) + assert exception.code == http.client.NOT_FOUND + assert exception.message == "POST https://example.com/: unknown error" + + +def test_from_http_response_json_unicode_content(): + response = make_response( + json.dumps( + {"error": {"message": "\u2019 message", "errors": ["1", "2"]}} + ).encode("utf-8") + ) + + exception = exceptions.from_http_response(response) + + assert isinstance(exception, exceptions.NotFound) + assert exception.code == http.client.NOT_FOUND + assert exception.message == "POST https://example.com/: \u2019 message" + assert exception.errors == ["1", "2"] + + +@pytest.mark.skipif(grpc is None, reason="No grpc") +def test_from_grpc_status(): + message = "message" + exception = exceptions.from_grpc_status(grpc.StatusCode.OUT_OF_RANGE, message) + assert isinstance(exception, exceptions.BadRequest) + assert isinstance(exception, exceptions.OutOfRange) + assert exception.code == http.client.BAD_REQUEST + assert exception.grpc_status_code == grpc.StatusCode.OUT_OF_RANGE + assert exception.message == message + assert exception.errors == [] + + +@pytest.mark.skipif(grpc is None, reason="No grpc") +def test_from_grpc_status_as_int(): + message = "message" + exception = exceptions.from_grpc_status(11, message) + assert isinstance(exception, exceptions.BadRequest) + assert isinstance(exception, exceptions.OutOfRange) + assert exception.code == http.client.BAD_REQUEST + assert exception.grpc_status_code == grpc.StatusCode.OUT_OF_RANGE + assert exception.message == message + assert exception.errors == [] + + +@pytest.mark.skipif(grpc is None, reason="No grpc") +def test_from_grpc_status_with_errors_and_response(): + message = "message" + response = mock.sentinel.response + errors = ["1", "2"] + exception = exceptions.from_grpc_status( + grpc.StatusCode.OUT_OF_RANGE, message, errors=errors, response=response + ) + + assert isinstance(exception, exceptions.OutOfRange) + assert exception.message == message + assert exception.errors == errors + assert exception.response == response + + +@pytest.mark.skipif(grpc is None, reason="No grpc") +def test_from_grpc_status_unknown_code(): + message = "message" + exception = exceptions.from_grpc_status(grpc.StatusCode.OK, message) + assert exception.grpc_status_code == grpc.StatusCode.OK + assert exception.message == message + + +@pytest.mark.skipif(grpc is None, reason="No grpc") +def test_from_grpc_error(): + message = "message" + error = mock.create_autospec(grpc.Call, instance=True) + error.code.return_value = grpc.StatusCode.INVALID_ARGUMENT + error.details.return_value = message + + exception = exceptions.from_grpc_error(error) + + assert isinstance(exception, exceptions.BadRequest) + assert isinstance(exception, exceptions.InvalidArgument) + assert exception.code == http.client.BAD_REQUEST + assert exception.grpc_status_code == grpc.StatusCode.INVALID_ARGUMENT + assert exception.message == message + assert exception.errors == [error] + assert exception.response == error + + +@pytest.mark.skipif(grpc is None, reason="No grpc") +def test_from_grpc_error_non_call(): + message = "message" + error = mock.create_autospec(grpc.RpcError, instance=True) + error.__str__.return_value = message + + exception = exceptions.from_grpc_error(error) + + assert isinstance(exception, exceptions.GoogleAPICallError) + assert exception.code is None + assert exception.grpc_status_code is None + assert exception.message == message + assert exception.errors == [error] + assert exception.response == error + + +@pytest.mark.skipif(grpc is None, reason="No grpc") +def test_from_grpc_error_bare_call(): + message = "Testing" + + class TestingError(grpc.Call, grpc.RpcError): + def __init__(self, exception): + self.exception = exception + + def code(self): + return self.exception.grpc_status_code + + def details(self): + return message + + nested_message = "message" + error = TestingError(exceptions.GoogleAPICallError(nested_message)) + + exception = exceptions.from_grpc_error(error) + + assert isinstance(exception, exceptions.GoogleAPICallError) + assert exception.code is None + assert exception.grpc_status_code is None + assert exception.message == message + assert exception.errors == [error] + assert exception.response == error + assert exception.details == [] + + +def create_bad_request_details(): + bad_request_details = error_details_pb2.BadRequest() + field_violation = bad_request_details.field_violations.add() + field_violation.field = "document.content" + field_violation.description = "Must have some text content to annotate." + status_detail = any_pb2.Any() + status_detail.Pack(bad_request_details) + return status_detail + + +def test_error_details_from_rest_response(): + bad_request_detail = create_bad_request_details() + status = status_pb2.Status() + status.code = 3 + status.message = ( + "3 INVALID_ARGUMENT: One of content, or gcs_content_uri must be set." + ) + status.details.append(bad_request_detail) + + # See JSON schema in https://cloud.google.com/apis/design/errors#http_mapping + http_response = make_response( + json.dumps({"error": json.loads(json_format.MessageToJson(status))}).encode( + "utf-8" + ) + ) + exception = exceptions.from_http_response(http_response) + want_error_details = [json.loads(json_format.MessageToJson(bad_request_detail))] + assert want_error_details == exception.details + # 404 POST comes from make_response. + assert str(exception) == ( + "404 POST https://example.com/: 3 INVALID_ARGUMENT:" + " One of content, or gcs_content_uri must be set." + " [{'@type': 'type.googleapis.com/google.rpc.BadRequest'," + " 'fieldViolations': [{'field': 'document.content'," + " 'description': 'Must have some text content to annotate.'}]}]" + ) + + +def test_error_details_from_v1_rest_response(): + response = make_response( + json.dumps( + {"error": {"message": "\u2019 message", "errors": ["1", "2"]}} + ).encode("utf-8") + ) + exception = exceptions.from_http_response(response) + assert exception.details == [] + + +@pytest.mark.skipif(grpc is None, reason="gRPC not importable") +def test_error_details_from_grpc_response(): + status = rpc_status.status_pb2.Status() + status.code = 3 + status.message = ( + "3 INVALID_ARGUMENT: One of content, or gcs_content_uri must be set." + ) + status_detail = create_bad_request_details() + status.details.append(status_detail) + + # Actualy error doesn't matter as long as its grpc.Call, + # because from_call is mocked. + error = mock.create_autospec(grpc.Call, instance=True) + with mock.patch("grpc_status.rpc_status.from_call") as m: + m.return_value = status + exception = exceptions.from_grpc_error(error) + + bad_request_detail = error_details_pb2.BadRequest() + status_detail.Unpack(bad_request_detail) + assert exception.details == [bad_request_detail] + + +@pytest.mark.skipif(grpc is None, reason="gRPC not importable") +def test_error_details_from_grpc_response_unknown_error(): + status_detail = any_pb2.Any() + + status = rpc_status.status_pb2.Status() + status.code = 3 + status.message = ( + "3 INVALID_ARGUMENT: One of content, or gcs_content_uri must be set." + ) + status.details.append(status_detail) + + error = mock.create_autospec(grpc.Call, instance=True) + with mock.patch("grpc_status.rpc_status.from_call") as m: + m.return_value = status + exception = exceptions.from_grpc_error(error) + assert exception.details == [status_detail] diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py new file mode 100644 index 0000000..ca969e4 --- /dev/null +++ b/tests/unit/test_grpc_helpers.py @@ -0,0 +1,860 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import pytest + +try: + import grpc +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + +from google.api_core import exceptions +from google.api_core import grpc_helpers +import google.auth.credentials +from google.longrunning import operations_pb2 + + +def test__patch_callable_name(): + callable = mock.Mock(spec=["__class__"]) + callable.__class__ = mock.Mock(spec=["__name__"]) + callable.__class__.__name__ = "TestCallable" + + grpc_helpers._patch_callable_name(callable) + + assert callable.__name__ == "TestCallable" + + +def test__patch_callable_name_no_op(): + callable = mock.Mock(spec=["__name__"]) + callable.__name__ = "test_callable" + + grpc_helpers._patch_callable_name(callable) + + assert callable.__name__ == "test_callable" + + +class RpcErrorImpl(grpc.RpcError, grpc.Call): + def __init__(self, code): + super(RpcErrorImpl, self).__init__() + self._code = code + + def code(self): + return self._code + + def details(self): + return None + + def trailing_metadata(self): + return None + + +def test_wrap_unary_errors(): + grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) + callable_ = mock.Mock(spec=["__call__"], side_effect=grpc_error) + + wrapped_callable = grpc_helpers._wrap_unary_errors(callable_) + + with pytest.raises(exceptions.InvalidArgument) as exc_info: + wrapped_callable(1, 2, three="four") + + callable_.assert_called_once_with(1, 2, three="four") + assert exc_info.value.response == grpc_error + + +class Test_StreamingResponseIterator: + @staticmethod + def _make_wrapped(*items): + return iter(items) + + @staticmethod + def _make_one(wrapped, **kw): + return grpc_helpers._StreamingResponseIterator(wrapped, **kw) + + def test_ctor_defaults(self): + wrapped = self._make_wrapped("a", "b", "c") + iterator = self._make_one(wrapped) + assert iterator._stored_first_result == "a" + assert list(wrapped) == ["b", "c"] + + def test_ctor_explicit(self): + wrapped = self._make_wrapped("a", "b", "c") + iterator = self._make_one(wrapped, prefetch_first_result=False) + assert getattr(iterator, "_stored_first_result", self) is self + assert list(wrapped) == ["a", "b", "c"] + + def test_ctor_w_rpc_error_on_prefetch(self): + wrapped = mock.MagicMock() + wrapped.__next__.side_effect = grpc.RpcError() + + with pytest.raises(grpc.RpcError): + self._make_one(wrapped) + + def test___iter__(self): + wrapped = self._make_wrapped("a", "b", "c") + iterator = self._make_one(wrapped) + assert iter(iterator) is iterator + + def test___next___w_cached_first_result(self): + wrapped = self._make_wrapped("a", "b", "c") + iterator = self._make_one(wrapped) + assert next(iterator) == "a" + iterator = self._make_one(wrapped, prefetch_first_result=False) + assert next(iterator) == "b" + assert next(iterator) == "c" + + def test___next___wo_cached_first_result(self): + wrapped = self._make_wrapped("a", "b", "c") + iterator = self._make_one(wrapped, prefetch_first_result=False) + assert next(iterator) == "a" + assert next(iterator) == "b" + assert next(iterator) == "c" + + def test___next___w_rpc_error(self): + wrapped = mock.MagicMock() + wrapped.__next__.side_effect = grpc.RpcError() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + with pytest.raises(exceptions.GoogleAPICallError): + next(iterator) + + def test_add_callback(self): + wrapped = mock.MagicMock() + callback = mock.Mock(spec={}) + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.add_callback(callback) is wrapped.add_callback.return_value + + wrapped.add_callback.assert_called_once_with(callback) + + def test_cancel(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.cancel() is wrapped.cancel.return_value + + wrapped.cancel.assert_called_once_with() + + def test_code(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.code() is wrapped.code.return_value + + wrapped.code.assert_called_once_with() + + def test_details(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.details() is wrapped.details.return_value + + wrapped.details.assert_called_once_with() + + def test_initial_metadata(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.initial_metadata() is wrapped.initial_metadata.return_value + + wrapped.initial_metadata.assert_called_once_with() + + def test_is_active(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.is_active() is wrapped.is_active.return_value + + wrapped.is_active.assert_called_once_with() + + def test_time_remaining(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.time_remaining() is wrapped.time_remaining.return_value + + wrapped.time_remaining.assert_called_once_with() + + def test_trailing_metadata(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.trailing_metadata() is wrapped.trailing_metadata.return_value + + wrapped.trailing_metadata.assert_called_once_with() + + +def test_wrap_stream_okay(): + expected_responses = [1, 2, 3] + callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses)) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + + got_iterator = wrapped_callable(1, 2, three="four") + + responses = list(got_iterator) + + callable_.assert_called_once_with(1, 2, three="four") + assert responses == expected_responses + + +def test_wrap_stream_prefetch_disabled(): + responses = [1, 2, 3] + iter_responses = iter(responses) + callable_ = mock.Mock(spec=["__call__"], return_value=iter_responses) + callable_._prefetch_first_result_ = False + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + wrapped_callable(1, 2, three="four") + + assert list(iter_responses) == responses # no items should have been pre-fetched + callable_.assert_called_once_with(1, 2, three="four") + + +def test_wrap_stream_iterable_iterface(): + response_iter = mock.create_autospec(grpc.Call, instance=True) + callable_ = mock.Mock(spec=["__call__"], return_value=response_iter) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + + got_iterator = wrapped_callable() + + callable_.assert_called_once_with() + + # Check each aliased method in the grpc.Call interface + got_iterator.add_callback(mock.sentinel.callback) + response_iter.add_callback.assert_called_once_with(mock.sentinel.callback) + + got_iterator.cancel() + response_iter.cancel.assert_called_once_with() + + got_iterator.code() + response_iter.code.assert_called_once_with() + + got_iterator.details() + response_iter.details.assert_called_once_with() + + got_iterator.initial_metadata() + response_iter.initial_metadata.assert_called_once_with() + + got_iterator.is_active() + response_iter.is_active.assert_called_once_with() + + got_iterator.time_remaining() + response_iter.time_remaining.assert_called_once_with() + + got_iterator.trailing_metadata() + response_iter.trailing_metadata.assert_called_once_with() + + +def test_wrap_stream_errors_invocation(): + grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) + callable_ = mock.Mock(spec=["__call__"], side_effect=grpc_error) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + + with pytest.raises(exceptions.InvalidArgument) as exc_info: + wrapped_callable(1, 2, three="four") + + callable_.assert_called_once_with(1, 2, three="four") + assert exc_info.value.response == grpc_error + + +def test_wrap_stream_empty_iterator(): + expected_responses = [] + callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses)) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + + got_iterator = wrapped_callable() + + responses = list(got_iterator) + + callable_.assert_called_once_with() + assert responses == expected_responses + + +class RpcResponseIteratorImpl(object): + def __init__(self, iterable): + self._iterable = iter(iterable) + + def next(self): + next_item = next(self._iterable) + if isinstance(next_item, RpcErrorImpl): + raise next_item + return next_item + + __next__ = next + + +def test_wrap_stream_errors_iterator_initialization(): + grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE) + response_iter = RpcResponseIteratorImpl([grpc_error]) + callable_ = mock.Mock(spec=["__call__"], return_value=response_iter) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + + with pytest.raises(exceptions.ServiceUnavailable) as exc_info: + wrapped_callable(1, 2, three="four") + + callable_.assert_called_once_with(1, 2, three="four") + assert exc_info.value.response == grpc_error + + +def test_wrap_stream_errors_during_iteration(): + grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE) + response_iter = RpcResponseIteratorImpl([1, grpc_error]) + callable_ = mock.Mock(spec=["__call__"], return_value=response_iter) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + got_iterator = wrapped_callable(1, 2, three="four") + next(got_iterator) + + with pytest.raises(exceptions.ServiceUnavailable) as exc_info: + next(got_iterator) + + callable_.assert_called_once_with(1, 2, three="four") + assert exc_info.value.response == grpc_error + + +@mock.patch("google.api_core.grpc_helpers._wrap_unary_errors") +def test_wrap_errors_non_streaming(wrap_unary_errors): + callable_ = mock.create_autospec(grpc.UnaryUnaryMultiCallable) + + result = grpc_helpers.wrap_errors(callable_) + + assert result == wrap_unary_errors.return_value + wrap_unary_errors.assert_called_once_with(callable_) + + +@mock.patch("google.api_core.grpc_helpers._wrap_stream_errors") +def test_wrap_errors_streaming(wrap_stream_errors): + callable_ = mock.create_autospec(grpc.UnaryStreamMultiCallable) + + result = grpc_helpers.wrap_errors(callable_) + + assert result == wrap_stream_errors.return_value + wrap_stream_errors.assert_called_once_with(callable_) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.projet), +) +@mock.patch("grpc.secure_channel") +def test_create_channel_implicit(grpc_secure_channel, default, composite_creds_call): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel(target) + + assert channel is grpc_secure_channel.return_value + + default.assert_called_once_with(scopes=None, default_scopes=None) + + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@mock.patch("google.auth.transport.grpc.AuthMetadataPlugin", autospec=True) +@mock.patch( + "google.auth.transport.requests.Request", + autospec=True, + return_value=mock.sentinel.Request, +) +@mock.patch("grpc.composite_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +@mock.patch("grpc.secure_channel") +def test_create_channel_implicit_with_default_host( + grpc_secure_channel, default, composite_creds_call, request, auth_metadata_plugin +): + target = "example.com:443" + default_host = "example.com" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel(target, default_host=default_host) + + assert channel is grpc_secure_channel.return_value + + default.assert_called_once_with(scopes=None, default_scopes=None) + auth_metadata_plugin.assert_called_once_with( + mock.sentinel.credentials, mock.sentinel.Request, default_host=default_host + ) + + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.projet), +) +@mock.patch("grpc.secure_channel") +def test_create_channel_implicit_with_ssl_creds( + grpc_secure_channel, default, composite_creds_call +): + target = "example.com:443" + + ssl_creds = grpc.ssl_channel_credentials() + + grpc_helpers.create_channel(target, ssl_credentials=ssl_creds) + + default.assert_called_once_with(scopes=None, default_scopes=None) + + composite_creds_call.assert_called_once_with(ssl_creds, mock.ANY) + composite_creds = composite_creds_call.return_value + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.projet), +) +@mock.patch("grpc.secure_channel") +def test_create_channel_implicit_with_scopes( + grpc_secure_channel, default, composite_creds_call +): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel(target, scopes=["one", "two"]) + + assert channel is grpc_secure_channel.return_value + + default.assert_called_once_with(scopes=["one", "two"], default_scopes=None) + + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.projet), +) +@mock.patch("grpc.secure_channel") +def test_create_channel_implicit_with_default_scopes( + grpc_secure_channel, default, composite_creds_call +): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel(target, default_scopes=["three", "four"]) + + assert channel is grpc_secure_channel.return_value + + default.assert_called_once_with(scopes=None, default_scopes=["three", "four"]) + + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +def test_create_channel_explicit_with_duplicate_credentials(): + target = "example.com:443" + + with pytest.raises(exceptions.DuplicateCredentialArgs): + grpc_helpers.create_channel( + target, + credentials_file="credentials.json", + credentials=mock.sentinel.credentials, + ) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch("google.auth.credentials.with_scopes_if_required", autospec=True) +@mock.patch("grpc.secure_channel") +def test_create_channel_explicit(grpc_secure_channel, auth_creds, composite_creds_call): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel(target, credentials=mock.sentinel.credentials) + + auth_creds.assert_called_once_with( + mock.sentinel.credentials, scopes=None, default_scopes=None + ) + + assert channel is grpc_secure_channel.return_value + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch("grpc.secure_channel") +def test_create_channel_explicit_scoped(grpc_secure_channel, composite_creds_call): + target = "example.com:443" + scopes = ["1", "2"] + composite_creds = composite_creds_call.return_value + + credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True) + credentials.requires_scopes = True + + channel = grpc_helpers.create_channel( + target, credentials=credentials, scopes=scopes + ) + + credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None) + + assert channel is grpc_secure_channel.return_value + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch("grpc.secure_channel") +def test_create_channel_explicit_default_scopes( + grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + default_scopes = ["3", "4"] + composite_creds = composite_creds_call.return_value + + credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True) + credentials.requires_scopes = True + + channel = grpc_helpers.create_channel( + target, credentials=credentials, default_scopes=default_scopes + ) + + credentials.with_scopes.assert_called_once_with( + scopes=None, default_scopes=default_scopes + ) + + assert channel is grpc_secure_channel.return_value + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch("grpc.secure_channel") +def test_create_channel_explicit_with_quota_project( + grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + credentials = mock.create_autospec( + google.auth.credentials.CredentialsWithQuotaProject, instance=True + ) + + channel = grpc_helpers.create_channel( + target, credentials=credentials, quota_project_id="project-foo" + ) + + credentials.with_quota_project.assert_called_once_with("project-foo") + + assert channel is grpc_secure_channel.return_value + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch("grpc.secure_channel") +@mock.patch( + "google.auth.load_credentials_from_file", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +def test_create_channel_with_credentials_file( + load_credentials_from_file, grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + + credentials_file = "/path/to/credentials/file.json" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel(target, credentials_file=credentials_file) + + google.auth.load_credentials_from_file.assert_called_once_with( + credentials_file, scopes=None, default_scopes=None + ) + + assert channel is grpc_secure_channel.return_value + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch("grpc.secure_channel") +@mock.patch( + "google.auth.load_credentials_from_file", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +def test_create_channel_with_credentials_file_and_scopes( + load_credentials_from_file, grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + scopes = ["1", "2"] + + credentials_file = "/path/to/credentials/file.json" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel( + target, credentials_file=credentials_file, scopes=scopes + ) + + google.auth.load_credentials_from_file.assert_called_once_with( + credentials_file, scopes=scopes, default_scopes=None + ) + + assert channel is grpc_secure_channel.return_value + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@mock.patch("grpc.composite_channel_credentials") +@mock.patch("grpc.secure_channel") +@mock.patch( + "google.auth.load_credentials_from_file", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +def test_create_channel_with_credentials_file_and_default_scopes( + load_credentials_from_file, grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + default_scopes = ["3", "4"] + + credentials_file = "/path/to/credentials/file.json" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel( + target, credentials_file=credentials_file, default_scopes=default_scopes + ) + + load_credentials_from_file.assert_called_once_with( + credentials_file, scopes=None, default_scopes=default_scopes + ) + + assert channel is grpc_secure_channel.return_value + if grpc_helpers.HAS_GRPC_GCP: + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with(target, composite_creds) + + +@pytest.mark.skipif( + not grpc_helpers.HAS_GRPC_GCP, reason="grpc_gcp module not available" +) +@mock.patch("grpc_gcp.secure_channel") +def test_create_channel_with_grpc_gcp(grpc_gcp_secure_channel): + target = "example.com:443" + scopes = ["test_scope"] + + credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True) + credentials.requires_scopes = True + + grpc_helpers.create_channel(target, credentials=credentials, scopes=scopes) + grpc_gcp_secure_channel.assert_called() + + credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None) + + +@pytest.mark.skipif(grpc_helpers.HAS_GRPC_GCP, reason="grpc_gcp module not available") +@mock.patch("grpc.secure_channel") +def test_create_channel_without_grpc_gcp(grpc_secure_channel): + target = "example.com:443" + scopes = ["test_scope"] + + credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True) + credentials.requires_scopes = True + + grpc_helpers.create_channel(target, credentials=credentials, scopes=scopes) + grpc_secure_channel.assert_called() + + credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None) + + +class TestChannelStub(object): + def test_single_response(self): + channel = grpc_helpers.ChannelStub() + stub = operations_pb2.OperationsStub(channel) + expected_request = operations_pb2.GetOperationRequest(name="meep") + expected_response = operations_pb2.Operation(name="moop") + + channel.GetOperation.response = expected_response + + response = stub.GetOperation(expected_request) + + assert response == expected_response + assert channel.requests == [("GetOperation", expected_request)] + assert channel.GetOperation.requests == [expected_request] + + def test_no_response(self): + channel = grpc_helpers.ChannelStub() + stub = operations_pb2.OperationsStub(channel) + expected_request = operations_pb2.GetOperationRequest(name="meep") + + with pytest.raises(ValueError) as exc_info: + stub.GetOperation(expected_request) + + assert exc_info.match("GetOperation") + + def test_missing_method(self): + channel = grpc_helpers.ChannelStub() + + with pytest.raises(AttributeError): + channel.DoesNotExist.response + + def test_exception_response(self): + channel = grpc_helpers.ChannelStub() + stub = operations_pb2.OperationsStub(channel) + expected_request = operations_pb2.GetOperationRequest(name="meep") + + channel.GetOperation.response = RuntimeError() + + with pytest.raises(RuntimeError): + stub.GetOperation(expected_request) + + def test_callable_response(self): + channel = grpc_helpers.ChannelStub() + stub = operations_pb2.OperationsStub(channel) + expected_request = operations_pb2.GetOperationRequest(name="meep") + expected_response = operations_pb2.Operation(name="moop") + + on_get_operation = mock.Mock(spec=("__call__",), return_value=expected_response) + + channel.GetOperation.response = on_get_operation + + response = stub.GetOperation(expected_request) + + assert response == expected_response + on_get_operation.assert_called_once_with(expected_request) + + def test_multiple_responses(self): + channel = grpc_helpers.ChannelStub() + stub = operations_pb2.OperationsStub(channel) + expected_request = operations_pb2.GetOperationRequest(name="meep") + expected_responses = [ + operations_pb2.Operation(name="foo"), + operations_pb2.Operation(name="bar"), + operations_pb2.Operation(name="baz"), + ] + + channel.GetOperation.responses = iter(expected_responses) + + response1 = stub.GetOperation(expected_request) + response2 = stub.GetOperation(expected_request) + response3 = stub.GetOperation(expected_request) + + assert response1 == expected_responses[0] + assert response2 == expected_responses[1] + assert response3 == expected_responses[2] + assert channel.requests == [("GetOperation", expected_request)] * 3 + assert channel.GetOperation.requests == [expected_request] * 3 + + with pytest.raises(StopIteration): + stub.GetOperation(expected_request) + + def test_multiple_responses_and_single_response_error(self): + channel = grpc_helpers.ChannelStub() + stub = operations_pb2.OperationsStub(channel) + channel.GetOperation.responses = [] + channel.GetOperation.response = mock.sentinel.response + + with pytest.raises(ValueError): + stub.GetOperation(operations_pb2.GetOperationRequest()) + + def test_call_info(self): + channel = grpc_helpers.ChannelStub() + stub = operations_pb2.OperationsStub(channel) + expected_request = operations_pb2.GetOperationRequest(name="meep") + expected_response = operations_pb2.Operation(name="moop") + expected_metadata = [("red", "blue"), ("two", "shoe")] + expected_credentials = mock.sentinel.credentials + channel.GetOperation.response = expected_response + + response = stub.GetOperation( + expected_request, + timeout=42, + metadata=expected_metadata, + credentials=expected_credentials, + ) + + assert response == expected_response + assert channel.requests == [("GetOperation", expected_request)] + assert channel.GetOperation.calls == [ + (expected_request, 42, expected_metadata, expected_credentials) + ] + + def test_unary_unary(self): + channel = grpc_helpers.ChannelStub() + method_name = "GetOperation" + callable_stub = channel.unary_unary(method_name) + assert callable_stub._method == method_name + assert callable_stub._channel == channel + + def test_unary_stream(self): + channel = grpc_helpers.ChannelStub() + method_name = "GetOperation" + callable_stub = channel.unary_stream(method_name) + assert callable_stub._method == method_name + assert callable_stub._channel == channel + + def test_stream_unary(self): + channel = grpc_helpers.ChannelStub() + method_name = "GetOperation" + callable_stub = channel.stream_unary(method_name) + assert callable_stub._method == method_name + assert callable_stub._channel == channel + + def test_stream_stream(self): + channel = grpc_helpers.ChannelStub() + method_name = "GetOperation" + callable_stub = channel.stream_stream(method_name) + assert callable_stub._method == method_name + assert callable_stub._channel == channel + + def test_subscribe_unsubscribe(self): + channel = grpc_helpers.ChannelStub() + assert channel.subscribe(None) is None + assert channel.unsubscribe(None) is None + + def test_close(self): + channel = grpc_helpers.ChannelStub() + assert channel.close() is None diff --git a/tests/unit/test_iam.py b/tests/unit/test_iam.py new file mode 100644 index 0000000..fbd242e --- /dev/null +++ b/tests/unit/test_iam.py @@ -0,0 +1,382 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.api_core.iam import _DICT_ACCESS_MSG, InvalidOperationException + + +class TestPolicy: + @staticmethod + def _get_target_class(): + from google.api_core.iam import Policy + + return Policy + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def test_ctor_defaults(self): + empty = frozenset() + policy = self._make_one() + assert policy.etag is None + assert policy.version is None + assert policy.owners == empty + assert policy.editors == empty + assert policy.viewers == empty + assert len(policy) == 0 + assert dict(policy) == {} + + def test_ctor_explicit(self): + VERSION = 1 + ETAG = "ETAG" + empty = frozenset() + policy = self._make_one(ETAG, VERSION) + assert policy.etag == ETAG + assert policy.version == VERSION + assert policy.owners == empty + assert policy.editors == empty + assert policy.viewers == empty + assert len(policy) == 0 + assert dict(policy) == {} + + def test___getitem___miss(self): + policy = self._make_one() + assert policy["nonesuch"] == set() + + def test__getitem___and_set(self): + from google.api_core.iam import OWNER_ROLE + + policy = self._make_one() + + # get the policy using the getter and then modify it + policy[OWNER_ROLE].add("user:phred@example.com") + assert dict(policy) == {OWNER_ROLE: {"user:phred@example.com"}} + + def test___getitem___version3(self): + policy = self._make_one("DEADBEEF", 3) + with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG): + policy["role"] + + def test___getitem___with_conditions(self): + USER = "user:phred@example.com" + CONDITION = {"expression": "2 > 1"} + policy = self._make_one("DEADBEEF", 1) + policy.bindings = [ + {"role": "role/reader", "members": [USER], "condition": CONDITION} + ] + with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG): + policy["role/reader"] + + def test___setitem__(self): + USER = "user:phred@example.com" + PRINCIPALS = set([USER]) + policy = self._make_one() + policy["rolename"] = [USER] + assert policy["rolename"] == PRINCIPALS + assert len(policy) == 1 + assert dict(policy) == {"rolename": PRINCIPALS} + + def test__set_item__overwrite(self): + GROUP = "group:test@group.com" + USER = "user:phred@example.com" + ALL_USERS = "allUsers" + MEMBERS = set([ALL_USERS]) + GROUPS = set([GROUP]) + policy = self._make_one() + policy["first"] = [GROUP] + policy["second"] = [USER] + policy["second"] = [ALL_USERS] + assert policy["second"] == MEMBERS + assert len(policy) == 2 + assert dict(policy) == {"first": GROUPS, "second": MEMBERS} + + def test___setitem___version3(self): + policy = self._make_one("DEADBEEF", 3) + with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG): + policy["role/reader"] = ["user:phred@example.com"] + + def test___setitem___with_conditions(self): + USER = "user:phred@example.com" + CONDITION = {"expression": "2 > 1"} + policy = self._make_one("DEADBEEF", 1) + policy.bindings = [ + {"role": "role/reader", "members": set([USER]), "condition": CONDITION} + ] + with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG): + policy["role/reader"] = ["user:phred@example.com"] + + def test___delitem___hit(self): + policy = self._make_one() + policy.bindings = [ + {"role": "to/keep", "members": set(["phred@example.com"])}, + {"role": "to/remove", "members": set(["phred@example.com"])}, + ] + del policy["to/remove"] + assert len(policy) == 1 + assert dict(policy) == {"to/keep": set(["phred@example.com"])} + + def test___delitem___miss(self): + policy = self._make_one() + with pytest.raises(KeyError): + del policy["nonesuch"] + + def test___delitem___version3(self): + policy = self._make_one("DEADBEEF", 3) + with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG): + del policy["role/reader"] + + def test___delitem___with_conditions(self): + USER = "user:phred@example.com" + CONDITION = {"expression": "2 > 1"} + policy = self._make_one("DEADBEEF", 1) + policy.bindings = [ + {"role": "role/reader", "members": set([USER]), "condition": CONDITION} + ] + with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG): + del policy["role/reader"] + + def test_bindings_property(self): + USER = "user:phred@example.com" + CONDITION = {"expression": "2 > 1"} + policy = self._make_one() + BINDINGS = [ + {"role": "role/reader", "members": set([USER]), "condition": CONDITION} + ] + policy.bindings = BINDINGS + assert policy.bindings == BINDINGS + + def test_owners_getter(self): + from google.api_core.iam import OWNER_ROLE + + MEMBER = "user:phred@example.com" + expected = frozenset([MEMBER]) + policy = self._make_one() + policy[OWNER_ROLE] = [MEMBER] + assert policy.owners == expected + + def test_owners_setter(self): + import warnings + from google.api_core.iam import OWNER_ROLE + + MEMBER = "user:phred@example.com" + expected = set([MEMBER]) + policy = self._make_one() + + with warnings.catch_warnings(record=True) as warned: + policy.owners = [MEMBER] + + (warning,) = warned + assert warning.category is DeprecationWarning + assert policy[OWNER_ROLE] == expected + + def test_editors_getter(self): + from google.api_core.iam import EDITOR_ROLE + + MEMBER = "user:phred@example.com" + expected = frozenset([MEMBER]) + policy = self._make_one() + policy[EDITOR_ROLE] = [MEMBER] + assert policy.editors == expected + + def test_editors_setter(self): + import warnings + from google.api_core.iam import EDITOR_ROLE + + MEMBER = "user:phred@example.com" + expected = set([MEMBER]) + policy = self._make_one() + + with warnings.catch_warnings(record=True) as warned: + policy.editors = [MEMBER] + + (warning,) = warned + assert warning.category is DeprecationWarning + assert policy[EDITOR_ROLE] == expected + + def test_viewers_getter(self): + from google.api_core.iam import VIEWER_ROLE + + MEMBER = "user:phred@example.com" + expected = frozenset([MEMBER]) + policy = self._make_one() + policy[VIEWER_ROLE] = [MEMBER] + assert policy.viewers == expected + + def test_viewers_setter(self): + import warnings + from google.api_core.iam import VIEWER_ROLE + + MEMBER = "user:phred@example.com" + expected = set([MEMBER]) + policy = self._make_one() + + with warnings.catch_warnings(record=True) as warned: + policy.viewers = [MEMBER] + + (warning,) = warned + assert warning.category is DeprecationWarning + assert policy[VIEWER_ROLE] == expected + + def test_user(self): + EMAIL = "phred@example.com" + MEMBER = "user:%s" % (EMAIL,) + policy = self._make_one() + assert policy.user(EMAIL) == MEMBER + + def test_service_account(self): + EMAIL = "phred@example.com" + MEMBER = "serviceAccount:%s" % (EMAIL,) + policy = self._make_one() + assert policy.service_account(EMAIL) == MEMBER + + def test_group(self): + EMAIL = "phred@example.com" + MEMBER = "group:%s" % (EMAIL,) + policy = self._make_one() + assert policy.group(EMAIL) == MEMBER + + def test_domain(self): + DOMAIN = "example.com" + MEMBER = "domain:%s" % (DOMAIN,) + policy = self._make_one() + assert policy.domain(DOMAIN) == MEMBER + + def test_all_users(self): + policy = self._make_one() + assert policy.all_users() == "allUsers" + + def test_authenticated_users(self): + policy = self._make_one() + assert policy.authenticated_users() == "allAuthenticatedUsers" + + def test_from_api_repr_only_etag(self): + empty = frozenset() + RESOURCE = {"etag": "ACAB"} + klass = self._get_target_class() + policy = klass.from_api_repr(RESOURCE) + assert policy.etag == "ACAB" + assert policy.version is None + assert policy.owners == empty + assert policy.editors == empty + assert policy.viewers == empty + assert dict(policy) == {} + + def test_from_api_repr_complete(self): + from google.api_core.iam import OWNER_ROLE, EDITOR_ROLE, VIEWER_ROLE + + OWNER1 = "group:cloud-logs@google.com" + OWNER2 = "user:phred@example.com" + EDITOR1 = "domain:google.com" + EDITOR2 = "user:phred@example.com" + VIEWER1 = "serviceAccount:1234-abcdef@service.example.com" + VIEWER2 = "user:phred@example.com" + RESOURCE = { + "etag": "DEADBEEF", + "version": 1, + "bindings": [ + {"role": OWNER_ROLE, "members": [OWNER1, OWNER2]}, + {"role": EDITOR_ROLE, "members": [EDITOR1, EDITOR2]}, + {"role": VIEWER_ROLE, "members": [VIEWER1, VIEWER2]}, + ], + } + klass = self._get_target_class() + policy = klass.from_api_repr(RESOURCE) + assert policy.etag == "DEADBEEF" + assert policy.version == 1 + assert policy.owners, frozenset([OWNER1 == OWNER2]) + assert policy.editors, frozenset([EDITOR1 == EDITOR2]) + assert policy.viewers, frozenset([VIEWER1 == VIEWER2]) + assert dict(policy) == { + OWNER_ROLE: set([OWNER1, OWNER2]), + EDITOR_ROLE: set([EDITOR1, EDITOR2]), + VIEWER_ROLE: set([VIEWER1, VIEWER2]), + } + assert policy.bindings == [ + {"role": OWNER_ROLE, "members": set([OWNER1, OWNER2])}, + {"role": EDITOR_ROLE, "members": set([EDITOR1, EDITOR2])}, + {"role": VIEWER_ROLE, "members": set([VIEWER1, VIEWER2])}, + ] + + def test_from_api_repr_unknown_role(self): + USER = "user:phred@example.com" + GROUP = "group:cloud-logs@google.com" + RESOURCE = { + "etag": "DEADBEEF", + "version": 1, + "bindings": [{"role": "unknown", "members": [USER, GROUP]}], + } + klass = self._get_target_class() + policy = klass.from_api_repr(RESOURCE) + assert policy.etag == "DEADBEEF" + assert policy.version == 1 + assert dict(policy), {"unknown": set([GROUP == USER])} + + def test_to_api_repr_defaults(self): + policy = self._make_one() + assert policy.to_api_repr() == {} + + def test_to_api_repr_only_etag(self): + policy = self._make_one("DEADBEEF") + assert policy.to_api_repr() == {"etag": "DEADBEEF"} + + def test_to_api_repr_binding_wo_members(self): + policy = self._make_one() + policy["empty"] = [] + assert policy.to_api_repr() == {} + + def test_to_api_repr_binding_w_duplicates(self): + import warnings + from google.api_core.iam import OWNER_ROLE + + OWNER = "group:cloud-logs@google.com" + policy = self._make_one() + with warnings.catch_warnings(record=True): + policy.owners = [OWNER, OWNER] + assert policy.to_api_repr() == { + "bindings": [{"role": OWNER_ROLE, "members": [OWNER]}] + } + + def test_to_api_repr_full(self): + import operator + from google.api_core.iam import OWNER_ROLE, EDITOR_ROLE, VIEWER_ROLE + + OWNER1 = "group:cloud-logs@google.com" + OWNER2 = "user:phred@example.com" + EDITOR1 = "domain:google.com" + EDITOR2 = "user:phred@example.com" + VIEWER1 = "serviceAccount:1234-abcdef@service.example.com" + VIEWER2 = "user:phred@example.com" + CONDITION = { + "title": "title", + "description": "description", + "expression": "true", + } + BINDINGS = [ + {"role": OWNER_ROLE, "members": [OWNER1, OWNER2]}, + {"role": EDITOR_ROLE, "members": [EDITOR1, EDITOR2]}, + {"role": VIEWER_ROLE, "members": [VIEWER1, VIEWER2]}, + { + "role": VIEWER_ROLE, + "members": [VIEWER1, VIEWER2], + "condition": CONDITION, + }, + ] + policy = self._make_one("DEADBEEF", 1) + policy.bindings = BINDINGS + resource = policy.to_api_repr() + assert resource["etag"] == "DEADBEEF" + assert resource["version"] == 1 + key = operator.itemgetter("role") + assert sorted(resource["bindings"], key=key) == sorted(BINDINGS, key=key) diff --git a/tests/unit/test_operation.py b/tests/unit/test_operation.py new file mode 100644 index 0000000..22e23bc --- /dev/null +++ b/tests/unit/test_operation.py @@ -0,0 +1,326 @@ +# Copyright 2017, Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import mock +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + +from google.api_core import exceptions +from google.api_core import operation +from google.api_core import operations_v1 +from google.api_core import retry +from google.longrunning import operations_pb2 +from google.protobuf import struct_pb2 +from google.rpc import code_pb2 +from google.rpc import status_pb2 + +TEST_OPERATION_NAME = "test/operation" + + +def make_operation_proto( + name=TEST_OPERATION_NAME, metadata=None, response=None, error=None, **kwargs +): + operation_proto = operations_pb2.Operation(name=name, **kwargs) + + if metadata is not None: + operation_proto.metadata.Pack(metadata) + + if response is not None: + operation_proto.response.Pack(response) + + if error is not None: + operation_proto.error.CopyFrom(error) + + return operation_proto + + +def make_operation_future(client_operations_responses=None): + if client_operations_responses is None: + client_operations_responses = [make_operation_proto()] + + refresh = mock.Mock(spec=["__call__"], side_effect=client_operations_responses) + refresh.responses = client_operations_responses + cancel = mock.Mock(spec=["__call__"]) + operation_future = operation.Operation( + client_operations_responses[0], + refresh, + cancel, + result_type=struct_pb2.Struct, + metadata_type=struct_pb2.Struct, + ) + + return operation_future, refresh, cancel + + +def test_constructor(): + future, refresh, _ = make_operation_future() + + assert future.operation == refresh.responses[0] + assert future.operation.done is False + assert future.operation.name == TEST_OPERATION_NAME + assert future.metadata is None + assert future.running() + + +def test_metadata(): + expected_metadata = struct_pb2.Struct() + future, _, _ = make_operation_future( + [make_operation_proto(metadata=expected_metadata)] + ) + + assert future.metadata == expected_metadata + + +def test_cancellation(): + responses = [ + make_operation_proto(), + # Second response indicates that the operation was cancelled. + make_operation_proto( + done=True, error=status_pb2.Status(code=code_pb2.CANCELLED) + ), + ] + future, _, cancel = make_operation_future(responses) + + assert future.cancel() + assert future.cancelled() + cancel.assert_called_once_with() + + # Cancelling twice should have no effect. + assert not future.cancel() + cancel.assert_called_once_with() + + +def test_result(): + expected_result = struct_pb2.Struct() + responses = [ + make_operation_proto(), + # Second operation response includes the result. + make_operation_proto(done=True, response=expected_result), + ] + future, _, _ = make_operation_future(responses) + + result = future.result() + + assert result == expected_result + assert future.done() + + +def test_done_w_retry(): + RETRY_PREDICATE = retry.if_exception_type(exceptions.TooManyRequests) + test_retry = retry.Retry(predicate=RETRY_PREDICATE) + + expected_result = struct_pb2.Struct() + responses = [ + make_operation_proto(), + # Second operation response includes the result. + make_operation_proto(done=True, response=expected_result), + ] + future, _, _ = make_operation_future(responses) + future._refresh = mock.Mock() + + future.done(retry=test_retry) + future._refresh.assert_called_once_with(retry=test_retry) + + +def test_exception(): + expected_exception = status_pb2.Status(message="meep") + responses = [ + make_operation_proto(), + # Second operation response includes the error. + make_operation_proto(done=True, error=expected_exception), + ] + future, _, _ = make_operation_future(responses) + + exception = future.exception() + + assert expected_exception.message in "{!r}".format(exception) + + +def test_exception_with_error_code(): + expected_exception = status_pb2.Status(message="meep", code=5) + responses = [ + make_operation_proto(), + # Second operation response includes the error. + make_operation_proto(done=True, error=expected_exception), + ] + future, _, _ = make_operation_future(responses) + + exception = future.exception() + + assert expected_exception.message in "{!r}".format(exception) + # Status Code 5 maps to Not Found + # https://developers.google.com/maps-booking/reference/grpc-api/status_codes + assert isinstance(exception, exceptions.NotFound) + + +def test_unexpected_result(): + responses = [ + make_operation_proto(), + # Second operation response is done, but has not error or response. + make_operation_proto(done=True), + ] + future, _, _ = make_operation_future(responses) + + exception = future.exception() + + assert "Unexpected state" in "{!r}".format(exception) + + +def test__refresh_http(): + json_response = {"name": TEST_OPERATION_NAME, "done": True} + api_request = mock.Mock(return_value=json_response) + + result = operation._refresh_http(api_request, TEST_OPERATION_NAME) + + assert isinstance(result, operations_pb2.Operation) + assert result.name == TEST_OPERATION_NAME + assert result.done is True + + api_request.assert_called_once_with( + method="GET", path="operations/{}".format(TEST_OPERATION_NAME) + ) + + +def test__refresh_http_w_retry(): + json_response = {"name": TEST_OPERATION_NAME, "done": True} + api_request = mock.Mock() + retry = mock.Mock() + retry.return_value.return_value = json_response + + result = operation._refresh_http(api_request, TEST_OPERATION_NAME, retry=retry) + + assert isinstance(result, operations_pb2.Operation) + assert result.name == TEST_OPERATION_NAME + assert result.done is True + + api_request.assert_not_called() + retry.assert_called_once_with(api_request) + retry.return_value.assert_called_once_with( + method="GET", path="operations/{}".format(TEST_OPERATION_NAME) + ) + + +def test__cancel_http(): + api_request = mock.Mock() + + operation._cancel_http(api_request, TEST_OPERATION_NAME) + + api_request.assert_called_once_with( + method="POST", path="operations/{}:cancel".format(TEST_OPERATION_NAME) + ) + + +def test_from_http_json(): + operation_json = {"name": TEST_OPERATION_NAME, "done": True} + api_request = mock.sentinel.api_request + + future = operation.from_http_json( + operation_json, api_request, struct_pb2.Struct, metadata_type=struct_pb2.Struct + ) + + assert future._result_type == struct_pb2.Struct + assert future._metadata_type == struct_pb2.Struct + assert future.operation.name == TEST_OPERATION_NAME + assert future.done + + +def test__refresh_grpc(): + operations_stub = mock.Mock(spec=["GetOperation"]) + expected_result = make_operation_proto(done=True) + operations_stub.GetOperation.return_value = expected_result + + result = operation._refresh_grpc(operations_stub, TEST_OPERATION_NAME) + + assert result == expected_result + expected_request = operations_pb2.GetOperationRequest(name=TEST_OPERATION_NAME) + operations_stub.GetOperation.assert_called_once_with(expected_request) + + +def test__refresh_grpc_w_retry(): + operations_stub = mock.Mock(spec=["GetOperation"]) + expected_result = make_operation_proto(done=True) + retry = mock.Mock() + retry.return_value.return_value = expected_result + + result = operation._refresh_grpc(operations_stub, TEST_OPERATION_NAME, retry=retry) + + assert result == expected_result + expected_request = operations_pb2.GetOperationRequest(name=TEST_OPERATION_NAME) + operations_stub.GetOperation.assert_not_called() + retry.assert_called_once_with(operations_stub.GetOperation) + retry.return_value.assert_called_once_with(expected_request) + + +def test__cancel_grpc(): + operations_stub = mock.Mock(spec=["CancelOperation"]) + + operation._cancel_grpc(operations_stub, TEST_OPERATION_NAME) + + expected_request = operations_pb2.CancelOperationRequest(name=TEST_OPERATION_NAME) + operations_stub.CancelOperation.assert_called_once_with(expected_request) + + +def test_from_grpc(): + operation_proto = make_operation_proto(done=True) + operations_stub = mock.sentinel.operations_stub + + future = operation.from_grpc( + operation_proto, + operations_stub, + struct_pb2.Struct, + metadata_type=struct_pb2.Struct, + grpc_metadata=[("x-goog-request-params", "foo")], + ) + + assert future._result_type == struct_pb2.Struct + assert future._metadata_type == struct_pb2.Struct + assert future.operation.name == TEST_OPERATION_NAME + assert future.done + assert future._refresh.keywords["metadata"] == [("x-goog-request-params", "foo")] + assert future._cancel.keywords["metadata"] == [("x-goog-request-params", "foo")] + + +def test_from_gapic(): + operation_proto = make_operation_proto(done=True) + operations_client = mock.create_autospec( + operations_v1.OperationsClient, instance=True + ) + + future = operation.from_gapic( + operation_proto, + operations_client, + struct_pb2.Struct, + metadata_type=struct_pb2.Struct, + grpc_metadata=[("x-goog-request-params", "foo")], + ) + + assert future._result_type == struct_pb2.Struct + assert future._metadata_type == struct_pb2.Struct + assert future.operation.name == TEST_OPERATION_NAME + assert future.done + assert future._refresh.keywords["metadata"] == [("x-goog-request-params", "foo")] + assert future._cancel.keywords["metadata"] == [("x-goog-request-params", "foo")] + + +def test_deserialize(): + op = make_operation_proto(name="foobarbaz") + serialized = op.SerializeToString() + deserialized_op = operation.Operation.deserialize(serialized) + assert op.name == deserialized_op.name + assert type(op) is type(deserialized_op) diff --git a/tests/unit/test_page_iterator.py b/tests/unit/test_page_iterator.py new file mode 100644 index 0000000..a44e998 --- /dev/null +++ b/tests/unit/test_page_iterator.py @@ -0,0 +1,665 @@ +# Copyright 2015 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import types + +import mock +import pytest + +from google.api_core import page_iterator + + +def test__do_nothing_page_start(): + assert page_iterator._do_nothing_page_start(None, None, None) is None + + +class TestPage(object): + def test_constructor(self): + parent = mock.sentinel.parent + item_to_value = mock.sentinel.item_to_value + + page = page_iterator.Page(parent, (1, 2, 3), item_to_value) + + assert page.num_items == 3 + assert page.remaining == 3 + assert page._parent is parent + assert page._item_to_value is item_to_value + assert page.raw_page is None + + def test___iter__(self): + page = page_iterator.Page(None, (), None, None) + assert iter(page) is page + + def test_iterator_calls_parent_item_to_value(self): + parent = mock.sentinel.parent + + item_to_value = mock.Mock( + side_effect=lambda iterator, value: value, spec=["__call__"] + ) + + page = page_iterator.Page(parent, (10, 11, 12), item_to_value) + page._remaining = 100 + + assert item_to_value.call_count == 0 + assert page.remaining == 100 + + assert next(page) == 10 + assert item_to_value.call_count == 1 + item_to_value.assert_called_with(parent, 10) + assert page.remaining == 99 + + assert next(page) == 11 + assert item_to_value.call_count == 2 + item_to_value.assert_called_with(parent, 11) + assert page.remaining == 98 + + assert next(page) == 12 + assert item_to_value.call_count == 3 + item_to_value.assert_called_with(parent, 12) + assert page.remaining == 97 + + def test_raw_page(self): + parent = mock.sentinel.parent + item_to_value = mock.sentinel.item_to_value + + raw_page = mock.sentinel.raw_page + + page = page_iterator.Page(parent, (1, 2, 3), item_to_value, raw_page=raw_page) + assert page.raw_page is raw_page + + with pytest.raises(AttributeError): + page.raw_page = None + + +class PageIteratorImpl(page_iterator.Iterator): + def _next_page(self): + return mock.create_autospec(page_iterator.Page, instance=True) + + +class TestIterator(object): + def test_constructor(self): + client = mock.sentinel.client + item_to_value = mock.sentinel.item_to_value + token = "ab13nceor03" + max_results = 1337 + + iterator = PageIteratorImpl( + client, item_to_value, page_token=token, max_results=max_results + ) + + assert not iterator._started + assert iterator.client is client + assert iterator.item_to_value == item_to_value + assert iterator.max_results == max_results + # Changing attributes. + assert iterator.page_number == 0 + assert iterator.next_page_token == token + assert iterator.num_results == 0 + + def test_next(self): + iterator = PageIteratorImpl(None, None) + page_1 = page_iterator.Page( + iterator, ("item 1.1", "item 1.2"), page_iterator._item_to_value_identity + ) + page_2 = page_iterator.Page( + iterator, ("item 2.1",), page_iterator._item_to_value_identity + ) + iterator._next_page = mock.Mock(side_effect=[page_1, page_2, None]) + + result = next(iterator) + assert result == "item 1.1" + result = next(iterator) + assert result == "item 1.2" + result = next(iterator) + assert result == "item 2.1" + + with pytest.raises(StopIteration): + next(iterator) + + def test_pages_property_starts(self): + iterator = PageIteratorImpl(None, None) + + assert not iterator._started + + assert isinstance(iterator.pages, types.GeneratorType) + + assert iterator._started + + def test_pages_property_restart(self): + iterator = PageIteratorImpl(None, None) + + assert iterator.pages + + # Make sure we cannot restart. + with pytest.raises(ValueError): + assert iterator.pages + + def test__page_iter_increment(self): + iterator = PageIteratorImpl(None, None) + page = page_iterator.Page( + iterator, ("item",), page_iterator._item_to_value_identity + ) + iterator._next_page = mock.Mock(side_effect=[page, None]) + + assert iterator.num_results == 0 + + page_iter = iterator._page_iter(increment=True) + next(page_iter) + + assert iterator.num_results == 1 + + def test__page_iter_no_increment(self): + iterator = PageIteratorImpl(None, None) + + assert iterator.num_results == 0 + + page_iter = iterator._page_iter(increment=False) + next(page_iter) + + # results should still be 0 after fetching a page. + assert iterator.num_results == 0 + + def test__items_iter(self): + # Items to be returned. + item1 = 17 + item2 = 100 + item3 = 211 + + # Make pages from mock responses + parent = mock.sentinel.parent + page1 = page_iterator.Page( + parent, (item1, item2), page_iterator._item_to_value_identity + ) + page2 = page_iterator.Page( + parent, (item3,), page_iterator._item_to_value_identity + ) + + iterator = PageIteratorImpl(None, None) + iterator._next_page = mock.Mock(side_effect=[page1, page2, None]) + + items_iter = iterator._items_iter() + + assert isinstance(items_iter, types.GeneratorType) + + # Consume items and check the state of the iterator. + assert iterator.num_results == 0 + + assert next(items_iter) == item1 + assert iterator.num_results == 1 + + assert next(items_iter) == item2 + assert iterator.num_results == 2 + + assert next(items_iter) == item3 + assert iterator.num_results == 3 + + with pytest.raises(StopIteration): + next(items_iter) + + def test___iter__(self): + iterator = PageIteratorImpl(None, None) + iterator._next_page = mock.Mock(side_effect=[(1, 2), (3,), None]) + + assert not iterator._started + + result = list(iterator) + + assert result == [1, 2, 3] + assert iterator._started + + def test___iter__restart(self): + iterator = PageIteratorImpl(None, None) + + iter(iterator) + + # Make sure we cannot restart. + with pytest.raises(ValueError): + iter(iterator) + + def test___iter___restart_after_page(self): + iterator = PageIteratorImpl(None, None) + + assert iterator.pages + + # Make sure we cannot restart after starting the page iterator + with pytest.raises(ValueError): + iter(iterator) + + +class TestHTTPIterator(object): + def test_constructor(self): + client = mock.sentinel.client + path = "/foo" + iterator = page_iterator.HTTPIterator( + client, mock.sentinel.api_request, path, mock.sentinel.item_to_value + ) + + assert not iterator._started + assert iterator.client is client + assert iterator.path == path + assert iterator.item_to_value is mock.sentinel.item_to_value + assert iterator._items_key == "items" + assert iterator.max_results is None + assert iterator.extra_params == {} + assert iterator._page_start == page_iterator._do_nothing_page_start + # Changing attributes. + assert iterator.page_number == 0 + assert iterator.next_page_token is None + assert iterator.num_results == 0 + assert iterator._page_size is None + + def test_constructor_w_extra_param_collision(self): + extra_params = {"pageToken": "val"} + + with pytest.raises(ValueError): + page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + extra_params=extra_params, + ) + + def test_iterate(self): + path = "/foo" + item1 = {"name": "1"} + item2 = {"name": "2"} + api_request = mock.Mock(return_value={"items": [item1, item2]}) + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + api_request, + path=path, + item_to_value=page_iterator._item_to_value_identity, + ) + + assert iterator.num_results == 0 + + items_iter = iter(iterator) + + val1 = next(items_iter) + assert val1 == item1 + assert iterator.num_results == 1 + + val2 = next(items_iter) + assert val2 == item2 + assert iterator.num_results == 2 + + with pytest.raises(StopIteration): + next(items_iter) + + api_request.assert_called_once_with(method="GET", path=path, query_params={}) + + def test__has_next_page_new(self): + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + + # The iterator should *always* indicate that it has a next page + # when created so that it can fetch the initial page. + assert iterator._has_next_page() + + def test__has_next_page_without_token(self): + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + + iterator.page_number = 1 + + # The iterator should not indicate that it has a new page if the + # initial page has been requested and there's no page token. + assert not iterator._has_next_page() + + def test__has_next_page_w_number_w_token(self): + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + + iterator.page_number = 1 + iterator.next_page_token = mock.sentinel.token + + # The iterator should indicate that it has a new page if the + # initial page has been requested and there's is a page token. + assert iterator._has_next_page() + + def test__has_next_page_w_max_results_not_done(self): + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + max_results=3, + page_token=mock.sentinel.token, + ) + + iterator.page_number = 1 + + # The iterator should indicate that it has a new page if there + # is a page token and it has not consumed more than max_results. + assert iterator.num_results < iterator.max_results + assert iterator._has_next_page() + + def test__has_next_page_w_max_results_done(self): + + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + max_results=3, + page_token=mock.sentinel.token, + ) + + iterator.page_number = 1 + iterator.num_results = 3 + + # The iterator should not indicate that it has a new page if there + # if it has consumed more than max_results. + assert iterator.num_results == iterator.max_results + assert not iterator._has_next_page() + + def test__get_query_params_no_token(self): + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + + assert iterator._get_query_params() == {} + + def test__get_query_params_w_token(self): + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + iterator.next_page_token = "token" + + assert iterator._get_query_params() == {"pageToken": iterator.next_page_token} + + def test__get_query_params_w_max_results(self): + max_results = 3 + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + max_results=max_results, + ) + + iterator.num_results = 1 + local_max = max_results - iterator.num_results + + assert iterator._get_query_params() == {"maxResults": local_max} + + def test__get_query_params_extra_params(self): + extra_params = {"key": "val"} + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + extra_params=extra_params, + ) + + assert iterator._get_query_params() == extra_params + + def test__get_next_page_response_with_post(self): + path = "/foo" + page_response = {"items": ["one", "two"]} + api_request = mock.Mock(return_value=page_response) + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + api_request, + path=path, + item_to_value=page_iterator._item_to_value_identity, + ) + iterator._HTTP_METHOD = "POST" + + response = iterator._get_next_page_response() + + assert response == page_response + + api_request.assert_called_once_with(method="POST", path=path, data={}) + + def test__get_next_page_bad_http_method(self): + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + iterator._HTTP_METHOD = "NOT-A-VERB" + + with pytest.raises(ValueError): + iterator._get_next_page_response() + + @pytest.mark.parametrize( + "page_size,max_results,pages", + [(3, None, False), (3, 8, False), (3, None, True), (3, 8, True)], + ) + def test_page_size_items(self, page_size, max_results, pages): + path = "/foo" + NITEMS = 10 + + n = [0] # blast you python 2! + + def api_request(*args, **kw): + assert not args + query_params = dict( + maxResults=( + page_size + if max_results is None + else min(page_size, max_results - n[0]) + ) + ) + if n[0]: + query_params.update(pageToken="test") + assert kw == {"method": "GET", "path": "/foo", "query_params": query_params} + n_items = min(kw["query_params"]["maxResults"], NITEMS - n[0]) + items = [dict(name=str(i + n[0])) for i in range(n_items)] + n[0] += n_items + result = dict(items=items) + if n[0] < NITEMS: + result.update(nextPageToken="test") + return result + + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + api_request, + path=path, + item_to_value=page_iterator._item_to_value_identity, + page_size=page_size, + max_results=max_results, + ) + + assert iterator.num_results == 0 + + n_results = max_results if max_results is not None else NITEMS + if pages: + items_iter = iter(iterator.pages) + npages = int(math.ceil(float(n_results) / page_size)) + for ipage in range(npages): + assert list(next(items_iter)) == [ + dict(name=str(i)) + for i in range( + ipage * page_size, min((ipage + 1) * page_size, n_results), + ) + ] + else: + items_iter = iter(iterator) + for i in range(n_results): + assert next(items_iter) == dict(name=str(i)) + assert iterator.num_results == i + 1 + + with pytest.raises(StopIteration): + next(items_iter) + + +class TestGRPCIterator(object): + def test_constructor(self): + client = mock.sentinel.client + items_field = "items" + iterator = page_iterator.GRPCIterator( + client, mock.sentinel.method, mock.sentinel.request, items_field + ) + + assert not iterator._started + assert iterator.client is client + assert iterator.max_results is None + assert iterator.item_to_value is page_iterator._item_to_value_identity + assert iterator._method == mock.sentinel.method + assert iterator._request == mock.sentinel.request + assert iterator._items_field == items_field + assert ( + iterator._request_token_field + == page_iterator.GRPCIterator._DEFAULT_REQUEST_TOKEN_FIELD + ) + assert ( + iterator._response_token_field + == page_iterator.GRPCIterator._DEFAULT_RESPONSE_TOKEN_FIELD + ) + # Changing attributes. + assert iterator.page_number == 0 + assert iterator.next_page_token is None + assert iterator.num_results == 0 + + def test_constructor_options(self): + client = mock.sentinel.client + items_field = "items" + request_field = "request" + response_field = "response" + iterator = page_iterator.GRPCIterator( + client, + mock.sentinel.method, + mock.sentinel.request, + items_field, + item_to_value=mock.sentinel.item_to_value, + request_token_field=request_field, + response_token_field=response_field, + max_results=42, + ) + + assert iterator.client is client + assert iterator.max_results == 42 + assert iterator.item_to_value is mock.sentinel.item_to_value + assert iterator._method == mock.sentinel.method + assert iterator._request == mock.sentinel.request + assert iterator._items_field == items_field + assert iterator._request_token_field == request_field + assert iterator._response_token_field == response_field + + def test_iterate(self): + request = mock.Mock(spec=["page_token"], page_token=None) + response1 = mock.Mock(items=["a", "b"], next_page_token="1") + response2 = mock.Mock(items=["c"], next_page_token="2") + response3 = mock.Mock(items=["d"], next_page_token="") + method = mock.Mock(side_effect=[response1, response2, response3]) + iterator = page_iterator.GRPCIterator( + mock.sentinel.client, method, request, "items" + ) + + assert iterator.num_results == 0 + + items = list(iterator) + assert items == ["a", "b", "c", "d"] + + method.assert_called_with(request) + assert method.call_count == 3 + assert request.page_token == "2" + + def test_iterate_with_max_results(self): + request = mock.Mock(spec=["page_token"], page_token=None) + response1 = mock.Mock(items=["a", "b"], next_page_token="1") + response2 = mock.Mock(items=["c"], next_page_token="2") + response3 = mock.Mock(items=["d"], next_page_token="") + method = mock.Mock(side_effect=[response1, response2, response3]) + iterator = page_iterator.GRPCIterator( + mock.sentinel.client, method, request, "items", max_results=3 + ) + + assert iterator.num_results == 0 + + items = list(iterator) + + assert items == ["a", "b", "c"] + assert iterator.num_results == 3 + + method.assert_called_with(request) + assert method.call_count == 2 + assert request.page_token == "1" + + +class GAXPageIterator(object): + """Fake object that matches gax.PageIterator""" + + def __init__(self, pages, page_token=None): + self._pages = iter(pages) + self.page_token = page_token + + def next(self): + return next(self._pages) + + __next__ = next + + +class TestGAXIterator(object): + def test_constructor(self): + client = mock.sentinel.client + token = "zzzyy78kl" + page_iter = GAXPageIterator((), page_token=token) + item_to_value = page_iterator._item_to_value_identity + max_results = 1337 + iterator = page_iterator._GAXIterator( + client, page_iter, item_to_value, max_results=max_results + ) + + assert not iterator._started + assert iterator.client is client + assert iterator.item_to_value is item_to_value + assert iterator.max_results == max_results + assert iterator._gax_page_iter is page_iter + # Changing attributes. + assert iterator.page_number == 0 + assert iterator.next_page_token == token + assert iterator.num_results == 0 + + def test__next_page(self): + page_items = (29, 31) + page_token = "2sde98ds2s0hh" + page_iter = GAXPageIterator([page_items], page_token=page_token) + iterator = page_iterator._GAXIterator( + mock.sentinel.client, page_iter, page_iterator._item_to_value_identity + ) + + page = iterator._next_page() + + assert iterator.next_page_token == page_token + assert isinstance(page, page_iterator.Page) + assert list(page) == list(page_items) + + next_page = iterator._next_page() + + assert next_page is None diff --git a/tests/unit/test_path_template.py b/tests/unit/test_path_template.py new file mode 100644 index 0000000..2c5216e --- /dev/null +++ b/tests/unit/test_path_template.py @@ -0,0 +1,389 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import unicode_literals + +import mock +import pytest + +from google.api_core import path_template + + +@pytest.mark.parametrize( + "tmpl, args, kwargs, expected_result", + [ + # Basic positional params + ["/v1/*", ["a"], {}, "/v1/a"], + ["/v1/**", ["a/b"], {}, "/v1/a/b"], + ["/v1/*/*", ["a", "b"], {}, "/v1/a/b"], + ["/v1/*/*/**", ["a", "b", "c/d"], {}, "/v1/a/b/c/d"], + # Basic named params + ["/v1/{name}", [], {"name": "parent"}, "/v1/parent"], + ["/v1/{name=**}", [], {"name": "parent/child"}, "/v1/parent/child"], + # Named params with a sub-template + ["/v1/{name=parent/*}", [], {"name": "parent/child"}, "/v1/parent/child"], + [ + "/v1/{name=parent/**}", + [], + {"name": "parent/child/object"}, + "/v1/parent/child/object", + ], + # Combining positional and named params + ["/v1/*/{name}", ["a"], {"name": "parent"}, "/v1/a/parent"], + ["/v1/{name}/*", ["a"], {"name": "parent"}, "/v1/parent/a"], + [ + "/v1/{parent}/*/{child}/*", + ["a", "b"], + {"parent": "thor", "child": "thorson"}, + "/v1/thor/a/thorson/b", + ], + ["/v1/{name}/**", ["a/b"], {"name": "parent"}, "/v1/parent/a/b"], + # Combining positional and named params with sub-templates. + [ + "/v1/{name=parent/*}/*", + ["a"], + {"name": "parent/child"}, + "/v1/parent/child/a", + ], + [ + "/v1/*/{name=parent/**}", + ["a"], + {"name": "parent/child/object"}, + "/v1/a/parent/child/object", + ], + ], +) +def test_expand_success(tmpl, args, kwargs, expected_result): + result = path_template.expand(tmpl, *args, **kwargs) + assert result == expected_result + assert path_template.validate(tmpl, result) + + +@pytest.mark.parametrize( + "tmpl, args, kwargs, exc_match", + [ + # Missing positional arg. + ["v1/*", [], {}, "Positional"], + # Missing named arg. + ["v1/{name}", [], {}, "Named"], + ], +) +def test_expanded_failure(tmpl, args, kwargs, exc_match): + with pytest.raises(ValueError, match=exc_match): + path_template.expand(tmpl, *args, **kwargs) + + +@pytest.mark.parametrize( + "request_obj, field, expected_result", + [ + [{"field": "stringValue"}, "field", "stringValue"], + [{"field": "stringValue"}, "nosuchfield", None], + [{"field": "stringValue"}, "field.subfield", None], + [{"field": {"subfield": "stringValue"}}, "field", None], + [{"field": {"subfield": "stringValue"}}, "field.subfield", "stringValue"], + [{"field": {"subfield": [1, 2, 3]}}, "field.subfield", [1, 2, 3]], + [{"field": {"subfield": "stringValue"}}, "field", None], + [{"field": {"subfield": "stringValue"}}, "field.nosuchfield", None], + [ + {"field": {"subfield": {"subsubfield": "stringValue"}}}, + "field.subfield.subsubfield", + "stringValue", + ], + ["string", "field", None], + ], +) +def test_get_field(request_obj, field, expected_result): + result = path_template.get_field(request_obj, field) + assert result == expected_result + + +@pytest.mark.parametrize( + "request_obj, field, expected_result", + [ + [{"field": "stringValue"}, "field", {}], + [{"field": "stringValue"}, "nosuchfield", {"field": "stringValue"}], + [{"field": "stringValue"}, "field.subfield", {"field": "stringValue"}], + [{"field": {"subfield": "stringValue"}}, "field.subfield", {"field": {}}], + [ + {"field": {"subfield": "stringValue", "q": "w"}, "e": "f"}, + "field.subfield", + {"field": {"q": "w"}, "e": "f"}, + ], + [ + {"field": {"subfield": "stringValue"}}, + "field.nosuchfield", + {"field": {"subfield": "stringValue"}}, + ], + [ + {"field": {"subfield": {"subsubfield": "stringValue", "q": "w"}}}, + "field.subfield.subsubfield", + {"field": {"subfield": {"q": "w"}}}, + ], + ["string", "field", "string"], + ["string", "field.subfield", "string"], + ], +) +def test_delete_field(request_obj, field, expected_result): + path_template.delete_field(request_obj, field) + assert request_obj == expected_result + + +@pytest.mark.parametrize( + "tmpl, path", + [ + # Single segment template, but multi segment value + ["v1/*", "v1/a/b"], + ["v1/*/*", "v1/a/b/c"], + # Single segement named template, but multi segment value + ["v1/{name}", "v1/a/b"], + ["v1/{name}/{value}", "v1/a/b/c"], + # Named value with a sub-template but invalid value + ["v1/{name=parent/*}", "v1/grandparent/child"], + ], +) +def test_validate_failure(tmpl, path): + assert not path_template.validate(tmpl, path) + + +def test__expand_variable_match_unexpected(): + match = mock.Mock(spec=["group"]) + match.group.return_value = None + with pytest.raises(ValueError, match="Unknown"): + path_template._expand_variable_match([], {}, match) + + +def test__replace_variable_with_pattern(): + match = mock.Mock(spec=["group"]) + match.group.return_value = None + with pytest.raises(ValueError, match="Unknown"): + path_template._replace_variable_with_pattern(match) + + +@pytest.mark.parametrize( + "http_options, request_kwargs, expected_result", + [ + [ + [["get", "/v1/no/template", ""]], + {"foo": "bar"}, + ["get", "/v1/no/template", {}, {"foo": "bar"}], + ], + # Single templates + [ + [["get", "/v1/{field}", ""]], + {"field": "parent"}, + ["get", "/v1/parent", {}, {}], + ], + [ + [["get", "/v1/{field.sub}", ""]], + {"field": {"sub": "parent"}, "foo": "bar"}, + ["get", "/v1/parent", {}, {"field": {}, "foo": "bar"}], + ], + ], +) +def test_transcode_base_case(http_options, request_kwargs, expected_result): + http_options, expected_result = helper_test_transcode(http_options, expected_result) + result = path_template.transcode(http_options, **request_kwargs) + assert result == expected_result + + +@pytest.mark.parametrize( + "http_options, request_kwargs, expected_result", + [ + [ + [["get", "/v1/{field.subfield}", ""]], + {"field": {"subfield": "parent"}, "foo": "bar"}, + ["get", "/v1/parent", {}, {"field": {}, "foo": "bar"}], + ], + [ + [["get", "/v1/{field.subfield.subsubfield}", ""]], + {"field": {"subfield": {"subsubfield": "parent"}}, "foo": "bar"}, + ["get", "/v1/parent", {}, {"field": {"subfield": {}}, "foo": "bar"}], + ], + [ + [["get", "/v1/{field.subfield1}/{field.subfield2}", ""]], + {"field": {"subfield1": "parent", "subfield2": "child"}, "foo": "bar"}, + ["get", "/v1/parent/child", {}, {"field": {}, "foo": "bar"}], + ], + ], +) +def test_transcode_subfields(http_options, request_kwargs, expected_result): + http_options, expected_result = helper_test_transcode(http_options, expected_result) + result = path_template.transcode(http_options, **request_kwargs) + assert result == expected_result + + +@pytest.mark.parametrize( + "http_options, request_kwargs, expected_result", + [ + # Single segment wildcard + [ + [["get", "/v1/{field=*}", ""]], + {"field": "parent"}, + ["get", "/v1/parent", {}, {}], + ], + [ + [["get", "/v1/{field=a/*/b/*}", ""]], + {"field": "a/parent/b/child", "foo": "bar"}, + ["get", "/v1/a/parent/b/child", {}, {"foo": "bar"}], + ], + # Double segment wildcard + [ + [["get", "/v1/{field=**}", ""]], + {"field": "parent/p1"}, + ["get", "/v1/parent/p1", {}, {}], + ], + [ + [["get", "/v1/{field=a/**/b/**}", ""]], + {"field": "a/parent/p1/b/child/c1", "foo": "bar"}, + ["get", "/v1/a/parent/p1/b/child/c1", {}, {"foo": "bar"}], + ], + # Combined single and double segment wildcard + [ + [["get", "/v1/{field=a/*/b/**}", ""]], + {"field": "a/parent/b/child/c1"}, + ["get", "/v1/a/parent/b/child/c1", {}, {}], + ], + [ + [["get", "/v1/{field=a/**/b/*}/v2/{name}", ""]], + {"field": "a/parent/p1/b/child", "name": "first", "foo": "bar"}, + ["get", "/v1/a/parent/p1/b/child/v2/first", {}, {"foo": "bar"}], + ], + ], +) +def test_transcode_with_wildcard(http_options, request_kwargs, expected_result): + http_options, expected_result = helper_test_transcode(http_options, expected_result) + result = path_template.transcode(http_options, **request_kwargs) + assert result == expected_result + + +@pytest.mark.parametrize( + "http_options, request_kwargs, expected_result", + [ + # Single field body + [ + [["post", "/v1/no/template", "data"]], + {"data": {"id": 1, "info": "some info"}, "foo": "bar"}, + ["post", "/v1/no/template", {"id": 1, "info": "some info"}, {"foo": "bar"}], + ], + [ + [["post", "/v1/{field=a/*}/b/{name=**}", "data"]], + { + "field": "a/parent", + "name": "first/last", + "data": {"id": 1, "info": "some info"}, + "foo": "bar", + }, + [ + "post", + "/v1/a/parent/b/first/last", + {"id": 1, "info": "some info"}, + {"foo": "bar"}, + ], + ], + # Wildcard body + [ + [["post", "/v1/{field=a/*}/b/{name=**}", "*"]], + { + "field": "a/parent", + "name": "first/last", + "data": {"id": 1, "info": "some info"}, + "foo": "bar", + }, + [ + "post", + "/v1/a/parent/b/first/last", + {"data": {"id": 1, "info": "some info"}, "foo": "bar"}, + {}, + ], + ], + ], +) +def test_transcode_with_body(http_options, request_kwargs, expected_result): + http_options, expected_result = helper_test_transcode(http_options, expected_result) + result = path_template.transcode(http_options, **request_kwargs) + assert result == expected_result + + +@pytest.mark.parametrize( + "http_options, request_kwargs, expected_result", + [ + # Additional bindings + [ + [ + ["post", "/v1/{field=a/*}/b/{name=**}", "extra_data"], + ["post", "/v1/{field=a/*}/b/{name=**}", "*"], + ], + { + "field": "a/parent", + "name": "first/last", + "data": {"id": 1, "info": "some info"}, + "foo": "bar", + }, + [ + "post", + "/v1/a/parent/b/first/last", + {"data": {"id": 1, "info": "some info"}, "foo": "bar"}, + {}, + ], + ], + [ + [ + ["get", "/v1/{field=a/*}/b/{name=**}", ""], + ["get", "/v1/{field=a/*}/b/first/last", ""], + ], + {"field": "a/parent", "foo": "bar"}, + ["get", "/v1/a/parent/b/first/last", {}, {"foo": "bar"}], + ], + ], +) +def test_transcode_with_additional_bindings( + http_options, request_kwargs, expected_result +): + http_options, expected_result = helper_test_transcode(http_options, expected_result) + result = path_template.transcode(http_options, **request_kwargs) + assert result == expected_result + + +@pytest.mark.parametrize( + "http_options, request_kwargs", + [ + [[["get", "/v1/{name}", ""]], {"foo": "bar"}], + [[["get", "/v1/{name}", ""]], {"name": "first/last"}], + [[["get", "/v1/{name=mr/*/*}", ""]], {"name": "first/last"}], + [[["post", "/v1/{name}", "data"]], {"name": "first/last"}], + ], +) +def test_transcode_fails(http_options, request_kwargs): + http_options, _ = helper_test_transcode(http_options, range(4)) + with pytest.raises(ValueError): + path_template.transcode(http_options, **request_kwargs) + + +def helper_test_transcode(http_options_list, expected_result_list): + http_options = [] + for opt_list in http_options_list: + http_option = {"method": opt_list[0], "uri": opt_list[1]} + if opt_list[2]: + http_option["body"] = opt_list[2] + http_options.append(http_option) + + expected_result = { + "method": expected_result_list[0], + "uri": expected_result_list[1], + "query_params": expected_result_list[3], + } + if expected_result_list[2]: + expected_result["body"] = expected_result_list[2] + + return (http_options, expected_result) diff --git a/tests/unit/test_protobuf_helpers.py b/tests/unit/test_protobuf_helpers.py new file mode 100644 index 0000000..3df45df --- /dev/null +++ b/tests/unit/test_protobuf_helpers.py @@ -0,0 +1,518 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import pytest + +from google.api import http_pb2 +from google.api_core import protobuf_helpers +from google.longrunning import operations_pb2 +from google.protobuf import any_pb2 +from google.protobuf import message +from google.protobuf import source_context_pb2 +from google.protobuf import struct_pb2 +from google.protobuf import timestamp_pb2 +from google.protobuf import type_pb2 +from google.protobuf import wrappers_pb2 +from google.type import color_pb2 +from google.type import date_pb2 +from google.type import timeofday_pb2 + + +def test_from_any_pb_success(): + in_message = date_pb2.Date(year=1990) + in_message_any = any_pb2.Any() + in_message_any.Pack(in_message) + out_message = protobuf_helpers.from_any_pb(date_pb2.Date, in_message_any) + + assert in_message == out_message + + +def test_from_any_pb_wrapped_success(): + # Declare a message class conforming to wrapped messages. + class WrappedDate(object): + def __init__(self, **kwargs): + self._pb = date_pb2.Date(**kwargs) + + def __eq__(self, other): + return self._pb == other + + @classmethod + def pb(cls, msg): + return msg._pb + + # Run the same test as `test_from_any_pb_success`, but using the + # wrapped class. + in_message = date_pb2.Date(year=1990) + in_message_any = any_pb2.Any() + in_message_any.Pack(in_message) + out_message = protobuf_helpers.from_any_pb(WrappedDate, in_message_any) + + assert out_message == in_message + + +def test_from_any_pb_failure(): + in_message = any_pb2.Any() + in_message.Pack(date_pb2.Date(year=1990)) + + with pytest.raises(TypeError): + protobuf_helpers.from_any_pb(timeofday_pb2.TimeOfDay, in_message) + + +def test_check_protobuf_helpers_ok(): + assert protobuf_helpers.check_oneof() is None + assert protobuf_helpers.check_oneof(foo="bar") is None + assert protobuf_helpers.check_oneof(foo="bar", baz=None) is None + assert protobuf_helpers.check_oneof(foo=None, baz="bacon") is None + assert protobuf_helpers.check_oneof(foo="bar", spam=None, eggs=None) is None + + +def test_check_protobuf_helpers_failures(): + with pytest.raises(ValueError): + protobuf_helpers.check_oneof(foo="bar", spam="eggs") + with pytest.raises(ValueError): + protobuf_helpers.check_oneof(foo="bar", baz="bacon", spam="eggs") + with pytest.raises(ValueError): + protobuf_helpers.check_oneof(foo="bar", spam=0, eggs=None) + + +def test_get_messages(): + answer = protobuf_helpers.get_messages(date_pb2) + + # Ensure that Date was exported properly. + assert answer["Date"] is date_pb2.Date + + # Ensure that no non-Message objects were exported. + for value in answer.values(): + assert issubclass(value, message.Message) + + +def test_get_dict_absent(): + with pytest.raises(KeyError): + assert protobuf_helpers.get({}, "foo") + + +def test_get_dict_present(): + assert protobuf_helpers.get({"foo": "bar"}, "foo") == "bar" + + +def test_get_dict_default(): + assert protobuf_helpers.get({}, "foo", default="bar") == "bar" + + +def test_get_dict_nested(): + assert protobuf_helpers.get({"foo": {"bar": "baz"}}, "foo.bar") == "baz" + + +def test_get_dict_nested_default(): + assert protobuf_helpers.get({}, "foo.baz", default="bacon") == "bacon" + assert protobuf_helpers.get({"foo": {}}, "foo.baz", default="bacon") == "bacon" + + +def test_get_msg_sentinel(): + msg = timestamp_pb2.Timestamp() + with pytest.raises(KeyError): + assert protobuf_helpers.get(msg, "foo") + + +def test_get_msg_present(): + msg = timestamp_pb2.Timestamp(seconds=42) + assert protobuf_helpers.get(msg, "seconds") == 42 + + +def test_get_msg_default(): + msg = timestamp_pb2.Timestamp() + assert protobuf_helpers.get(msg, "foo", default="bar") == "bar" + + +def test_invalid_object(): + with pytest.raises(TypeError): + protobuf_helpers.get(object(), "foo", "bar") + + +def test_set_dict(): + mapping = {} + protobuf_helpers.set(mapping, "foo", "bar") + assert mapping == {"foo": "bar"} + + +def test_set_msg(): + msg = timestamp_pb2.Timestamp() + protobuf_helpers.set(msg, "seconds", 42) + assert msg.seconds == 42 + + +def test_set_dict_nested(): + mapping = {} + protobuf_helpers.set(mapping, "foo.bar", "baz") + assert mapping == {"foo": {"bar": "baz"}} + + +def test_set_invalid_object(): + with pytest.raises(TypeError): + protobuf_helpers.set(object(), "foo", "bar") + + +def test_set_list(): + list_ops_response = operations_pb2.ListOperationsResponse() + + protobuf_helpers.set( + list_ops_response, + "operations", + [{"name": "foo"}, operations_pb2.Operation(name="bar")], + ) + + assert len(list_ops_response.operations) == 2 + + for operation in list_ops_response.operations: + assert isinstance(operation, operations_pb2.Operation) + + assert list_ops_response.operations[0].name == "foo" + assert list_ops_response.operations[1].name == "bar" + + +def test_set_list_clear_existing(): + list_ops_response = operations_pb2.ListOperationsResponse( + operations=[{"name": "baz"}] + ) + + protobuf_helpers.set( + list_ops_response, + "operations", + [{"name": "foo"}, operations_pb2.Operation(name="bar")], + ) + + assert len(list_ops_response.operations) == 2 + for operation in list_ops_response.operations: + assert isinstance(operation, operations_pb2.Operation) + assert list_ops_response.operations[0].name == "foo" + assert list_ops_response.operations[1].name == "bar" + + +def test_set_msg_with_msg_field(): + rule = http_pb2.HttpRule() + pattern = http_pb2.CustomHttpPattern(kind="foo", path="bar") + + protobuf_helpers.set(rule, "custom", pattern) + + assert rule.custom.kind == "foo" + assert rule.custom.path == "bar" + + +def test_set_msg_with_dict_field(): + rule = http_pb2.HttpRule() + pattern = {"kind": "foo", "path": "bar"} + + protobuf_helpers.set(rule, "custom", pattern) + + assert rule.custom.kind == "foo" + assert rule.custom.path == "bar" + + +def test_set_msg_nested_key(): + rule = http_pb2.HttpRule(custom=http_pb2.CustomHttpPattern(kind="foo", path="bar")) + + protobuf_helpers.set(rule, "custom.kind", "baz") + + assert rule.custom.kind == "baz" + assert rule.custom.path == "bar" + + +def test_setdefault_dict_unset(): + mapping = {} + protobuf_helpers.setdefault(mapping, "foo", "bar") + assert mapping == {"foo": "bar"} + + +def test_setdefault_dict_falsy(): + mapping = {"foo": None} + protobuf_helpers.setdefault(mapping, "foo", "bar") + assert mapping == {"foo": "bar"} + + +def test_setdefault_dict_truthy(): + mapping = {"foo": "bar"} + protobuf_helpers.setdefault(mapping, "foo", "baz") + assert mapping == {"foo": "bar"} + + +def test_setdefault_pb2_falsy(): + operation = operations_pb2.Operation() + protobuf_helpers.setdefault(operation, "name", "foo") + assert operation.name == "foo" + + +def test_setdefault_pb2_truthy(): + operation = operations_pb2.Operation(name="bar") + protobuf_helpers.setdefault(operation, "name", "foo") + assert operation.name == "bar" + + +def test_field_mask_invalid_args(): + with pytest.raises(ValueError): + protobuf_helpers.field_mask("foo", any_pb2.Any()) + with pytest.raises(ValueError): + protobuf_helpers.field_mask(any_pb2.Any(), "bar") + with pytest.raises(ValueError): + protobuf_helpers.field_mask(any_pb2.Any(), operations_pb2.Operation()) + + +def test_field_mask_equal_values(): + assert protobuf_helpers.field_mask(None, None).paths == [] + + original = struct_pb2.Value(number_value=1.0) + modified = struct_pb2.Value(number_value=1.0) + assert protobuf_helpers.field_mask(original, modified).paths == [] + + original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0)) + modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0)) + assert protobuf_helpers.field_mask(original, modified).paths == [] + + original = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0)]) + modified = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0)]) + assert protobuf_helpers.field_mask(original, modified).paths == [] + + original = struct_pb2.Struct(fields={"bar": struct_pb2.Value(number_value=1.0)}) + modified = struct_pb2.Struct(fields={"bar": struct_pb2.Value(number_value=1.0)}) + assert protobuf_helpers.field_mask(original, modified).paths == [] + + +def test_field_mask_zero_values(): + # Singular Values + original = color_pb2.Color(red=0.0) + modified = None + assert protobuf_helpers.field_mask(original, modified).paths == [] + + original = None + modified = color_pb2.Color(red=0.0) + assert protobuf_helpers.field_mask(original, modified).paths == [] + + # Repeated Values + original = struct_pb2.ListValue(values=[]) + modified = None + assert protobuf_helpers.field_mask(original, modified).paths == [] + + original = None + modified = struct_pb2.ListValue(values=[]) + assert protobuf_helpers.field_mask(original, modified).paths == [] + + # Maps + original = struct_pb2.Struct(fields={}) + modified = None + assert protobuf_helpers.field_mask(original, modified).paths == [] + + original = None + modified = struct_pb2.Struct(fields={}) + assert protobuf_helpers.field_mask(original, modified).paths == [] + + # Oneofs + original = struct_pb2.Value(number_value=0.0) + modified = None + assert protobuf_helpers.field_mask(original, modified).paths == [] + + original = None + modified = struct_pb2.Value(number_value=0.0) + assert protobuf_helpers.field_mask(original, modified).paths == [] + + +def test_field_mask_singular_field_diffs(): + original = type_pb2.Type(name="name") + modified = type_pb2.Type() + assert protobuf_helpers.field_mask(original, modified).paths == ["name"] + + original = type_pb2.Type(name="name") + modified = type_pb2.Type() + assert protobuf_helpers.field_mask(original, modified).paths == ["name"] + + original = None + modified = type_pb2.Type(name="name") + assert protobuf_helpers.field_mask(original, modified).paths == ["name"] + + original = type_pb2.Type(name="name") + modified = None + assert protobuf_helpers.field_mask(original, modified).paths == ["name"] + + +def test_field_mask_message_diffs(): + original = type_pb2.Type() + modified = type_pb2.Type( + source_context=source_context_pb2.SourceContext(file_name="name") + ) + assert protobuf_helpers.field_mask(original, modified).paths == [ + "source_context.file_name" + ] + + original = type_pb2.Type( + source_context=source_context_pb2.SourceContext(file_name="name") + ) + modified = type_pb2.Type() + assert protobuf_helpers.field_mask(original, modified).paths == ["source_context"] + + original = type_pb2.Type( + source_context=source_context_pb2.SourceContext(file_name="name") + ) + modified = type_pb2.Type( + source_context=source_context_pb2.SourceContext(file_name="other_name") + ) + assert protobuf_helpers.field_mask(original, modified).paths == [ + "source_context.file_name" + ] + + original = None + modified = type_pb2.Type( + source_context=source_context_pb2.SourceContext(file_name="name") + ) + assert protobuf_helpers.field_mask(original, modified).paths == [ + "source_context.file_name" + ] + + original = type_pb2.Type( + source_context=source_context_pb2.SourceContext(file_name="name") + ) + modified = None + assert protobuf_helpers.field_mask(original, modified).paths == ["source_context"] + + +def test_field_mask_wrapper_type_diffs(): + original = color_pb2.Color() + modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0)) + assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"] + + original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0)) + modified = color_pb2.Color() + assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"] + + original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0)) + modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0)) + assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"] + + original = None + modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0)) + assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"] + + original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0)) + modified = None + assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"] + + +def test_field_mask_repeated_diffs(): + original = struct_pb2.ListValue() + modified = struct_pb2.ListValue( + values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)] + ) + assert protobuf_helpers.field_mask(original, modified).paths == ["values"] + + original = struct_pb2.ListValue( + values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)] + ) + modified = struct_pb2.ListValue() + assert protobuf_helpers.field_mask(original, modified).paths == ["values"] + + original = None + modified = struct_pb2.ListValue( + values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)] + ) + assert protobuf_helpers.field_mask(original, modified).paths == ["values"] + + original = struct_pb2.ListValue( + values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)] + ) + modified = None + assert protobuf_helpers.field_mask(original, modified).paths == ["values"] + + original = struct_pb2.ListValue( + values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)] + ) + modified = struct_pb2.ListValue( + values=[struct_pb2.Value(number_value=2.0), struct_pb2.Value(number_value=1.0)] + ) + assert protobuf_helpers.field_mask(original, modified).paths == ["values"] + + +def test_field_mask_map_diffs(): + original = struct_pb2.Struct() + modified = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)}) + assert protobuf_helpers.field_mask(original, modified).paths == ["fields"] + + original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)}) + modified = struct_pb2.Struct() + assert protobuf_helpers.field_mask(original, modified).paths == ["fields"] + + original = None + modified = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)}) + assert protobuf_helpers.field_mask(original, modified).paths == ["fields"] + + original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)}) + modified = None + assert protobuf_helpers.field_mask(original, modified).paths == ["fields"] + + original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)}) + modified = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=2.0)}) + assert protobuf_helpers.field_mask(original, modified).paths == ["fields"] + + original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)}) + modified = struct_pb2.Struct(fields={"bar": struct_pb2.Value(number_value=1.0)}) + assert protobuf_helpers.field_mask(original, modified).paths == ["fields"] + + +def test_field_mask_different_level_diffs(): + original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0)) + modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0), red=1.0) + assert sorted(protobuf_helpers.field_mask(original, modified).paths) == [ + "alpha", + "red", + ] + + +@pytest.mark.skipif( + sys.version_info.major == 2, + reason="Field names with trailing underscores can only be created" + "through proto-plus, which is Python 3 only.", +) +def test_field_mask_ignore_trailing_underscore(): + import proto + + class Foo(proto.Message): + type_ = proto.Field(proto.STRING, number=1) + input_config = proto.Field(proto.STRING, number=2) + + modified = Foo(type_="bar", input_config="baz") + + assert sorted(protobuf_helpers.field_mask(None, Foo.pb(modified)).paths) == [ + "input_config", + "type", + ] + + +@pytest.mark.skipif( + sys.version_info.major == 2, + reason="Field names with trailing underscores can only be created" + "through proto-plus, which is Python 3 only.", +) +def test_field_mask_ignore_trailing_underscore_with_nesting(): + import proto + + class Bar(proto.Message): + class Baz(proto.Message): + input_config = proto.Field(proto.STRING, number=1) + + type_ = proto.Field(Baz, number=1) + + modified = Bar() + modified.type_.input_config = "foo" + + assert sorted(protobuf_helpers.field_mask(None, Bar.pb(modified)).paths) == [ + "type.input_config", + ] diff --git a/tests/unit/test_rest_helpers.py b/tests/unit/test_rest_helpers.py new file mode 100644 index 0000000..5932fa5 --- /dev/null +++ b/tests/unit/test_rest_helpers.py @@ -0,0 +1,77 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.api_core import rest_helpers + + +def test_flatten_simple_value(): + with pytest.raises(TypeError): + rest_helpers.flatten_query_params("abc") + + +def test_flatten_list(): + with pytest.raises(TypeError): + rest_helpers.flatten_query_params(["abc", "def"]) + + +def test_flatten_none(): + assert rest_helpers.flatten_query_params(None) == [] + + +def test_flatten_empty_dict(): + assert rest_helpers.flatten_query_params({}) == [] + + +def test_flatten_simple_dict(): + assert rest_helpers.flatten_query_params({"a": "abc", "b": "def"}) == [ + ("a", "abc"), + ("b", "def"), + ] + + +def test_flatten_repeated_field(): + assert rest_helpers.flatten_query_params({"a": ["x", "y", "z", None]}) == [ + ("a", "x"), + ("a", "y"), + ("a", "z"), + ] + + +def test_flatten_nested_dict(): + obj = {"a": {"b": {"c": ["x", "y", "z"]}}, "d": {"e": "uvw"}} + expected_result = [("a.b.c", "x"), ("a.b.c", "y"), ("a.b.c", "z"), ("d.e", "uvw")] + + assert rest_helpers.flatten_query_params(obj) == expected_result + + +def test_flatten_repeated_dict(): + obj = { + "a": {"b": {"c": [{"v": 1}, {"v": 2}]}}, + "d": "uvw", + } + + with pytest.raises(ValueError): + rest_helpers.flatten_query_params(obj) + + +def test_flatten_repeated_list(): + obj = { + "a": {"b": {"c": [["e", "f"], ["g", "h"]]}}, + "d": "uvw", + } + + with pytest.raises(ValueError): + rest_helpers.flatten_query_params(obj) diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py new file mode 100644 index 0000000..199ca55 --- /dev/null +++ b/tests/unit/test_retry.py @@ -0,0 +1,458 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import itertools +import re + +import mock +import pytest +import requests.exceptions + +from google.api_core import exceptions +from google.api_core import retry +from google.auth import exceptions as auth_exceptions + + +def test_if_exception_type(): + predicate = retry.if_exception_type(ValueError) + + assert predicate(ValueError()) + assert not predicate(TypeError()) + + +def test_if_exception_type_multiple(): + predicate = retry.if_exception_type(ValueError, TypeError) + + assert predicate(ValueError()) + assert predicate(TypeError()) + assert not predicate(RuntimeError()) + + +def test_if_transient_error(): + assert retry.if_transient_error(exceptions.InternalServerError("")) + assert retry.if_transient_error(exceptions.TooManyRequests("")) + assert retry.if_transient_error(exceptions.ServiceUnavailable("")) + assert retry.if_transient_error(requests.exceptions.ConnectionError("")) + assert retry.if_transient_error(requests.exceptions.ChunkedEncodingError("")) + assert retry.if_transient_error(auth_exceptions.TransportError("")) + assert not retry.if_transient_error(exceptions.InvalidArgument("")) + + +# Make uniform return half of its maximum, which will be the calculated +# sleep time. +@mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0) +def test_exponential_sleep_generator_base_2(uniform): + gen = retry.exponential_sleep_generator(1, 60, multiplier=2) + + result = list(itertools.islice(gen, 8)) + assert result == [1, 2, 4, 8, 16, 32, 60, 60] + + +@mock.patch("time.sleep", autospec=True) +@mock.patch( + "google.api_core.datetime_helpers.utcnow", + return_value=datetime.datetime.min, + autospec=True, +) +def test_retry_target_success(utcnow, sleep): + predicate = retry.if_exception_type(ValueError) + call_count = [0] + + def target(): + call_count[0] += 1 + if call_count[0] < 3: + raise ValueError() + return 42 + + result = retry.retry_target(target, predicate, range(10), None) + + assert result == 42 + assert call_count[0] == 3 + sleep.assert_has_calls([mock.call(0), mock.call(1)]) + + +@mock.patch("time.sleep", autospec=True) +@mock.patch( + "google.api_core.datetime_helpers.utcnow", + return_value=datetime.datetime.min, + autospec=True, +) +def test_retry_target_w_on_error(utcnow, sleep): + predicate = retry.if_exception_type(ValueError) + call_count = {"target": 0} + to_raise = ValueError() + + def target(): + call_count["target"] += 1 + if call_count["target"] < 3: + raise to_raise + return 42 + + on_error = mock.Mock() + + result = retry.retry_target(target, predicate, range(10), None, on_error=on_error) + + assert result == 42 + assert call_count["target"] == 3 + + on_error.assert_has_calls([mock.call(to_raise), mock.call(to_raise)]) + sleep.assert_has_calls([mock.call(0), mock.call(1)]) + + +@mock.patch("time.sleep", autospec=True) +@mock.patch( + "google.api_core.datetime_helpers.utcnow", + return_value=datetime.datetime.min, + autospec=True, +) +def test_retry_target_non_retryable_error(utcnow, sleep): + predicate = retry.if_exception_type(ValueError) + exception = TypeError() + target = mock.Mock(side_effect=exception) + + with pytest.raises(TypeError) as exc_info: + retry.retry_target(target, predicate, range(10), None) + + assert exc_info.value == exception + sleep.assert_not_called() + + +@mock.patch("time.sleep", autospec=True) +@mock.patch("google.api_core.datetime_helpers.utcnow", autospec=True) +def test_retry_target_deadline_exceeded(utcnow, sleep): + predicate = retry.if_exception_type(ValueError) + exception = ValueError("meep") + target = mock.Mock(side_effect=exception) + # Setup the timeline so that the first call takes 5 seconds but the second + # call takes 6, which puts the retry over the deadline. + utcnow.side_effect = [ + # The first call to utcnow establishes the start of the timeline. + datetime.datetime.min, + datetime.datetime.min + datetime.timedelta(seconds=5), + datetime.datetime.min + datetime.timedelta(seconds=11), + ] + + with pytest.raises(exceptions.RetryError) as exc_info: + retry.retry_target(target, predicate, range(10), deadline=10) + + assert exc_info.value.cause == exception + assert exc_info.match("Deadline of 10.0s exceeded") + assert exc_info.match("last exception: meep") + assert target.call_count == 2 + + +def test_retry_target_bad_sleep_generator(): + with pytest.raises(ValueError, match="Sleep generator"): + retry.retry_target(mock.sentinel.target, mock.sentinel.predicate, [], None) + + +class TestRetry(object): + def test_constructor_defaults(self): + retry_ = retry.Retry() + assert retry_._predicate == retry.if_transient_error + assert retry_._initial == 1 + assert retry_._maximum == 60 + assert retry_._multiplier == 2 + assert retry_._deadline == 120 + assert retry_._on_error is None + assert retry_.deadline == 120 + + def test_constructor_options(self): + _some_function = mock.Mock() + + retry_ = retry.Retry( + predicate=mock.sentinel.predicate, + initial=1, + maximum=2, + multiplier=3, + deadline=4, + on_error=_some_function, + ) + assert retry_._predicate == mock.sentinel.predicate + assert retry_._initial == 1 + assert retry_._maximum == 2 + assert retry_._multiplier == 3 + assert retry_._deadline == 4 + assert retry_._on_error is _some_function + + def test_with_deadline(self): + retry_ = retry.Retry( + predicate=mock.sentinel.predicate, + initial=1, + maximum=2, + multiplier=3, + deadline=4, + on_error=mock.sentinel.on_error, + ) + new_retry = retry_.with_deadline(42) + assert retry_ is not new_retry + assert new_retry._deadline == 42 + + # the rest of the attributes should remain the same + assert new_retry._predicate is retry_._predicate + assert new_retry._initial == retry_._initial + assert new_retry._maximum == retry_._maximum + assert new_retry._multiplier == retry_._multiplier + assert new_retry._on_error is retry_._on_error + + def test_with_predicate(self): + retry_ = retry.Retry( + predicate=mock.sentinel.predicate, + initial=1, + maximum=2, + multiplier=3, + deadline=4, + on_error=mock.sentinel.on_error, + ) + new_retry = retry_.with_predicate(mock.sentinel.predicate) + assert retry_ is not new_retry + assert new_retry._predicate == mock.sentinel.predicate + + # the rest of the attributes should remain the same + assert new_retry._deadline == retry_._deadline + assert new_retry._initial == retry_._initial + assert new_retry._maximum == retry_._maximum + assert new_retry._multiplier == retry_._multiplier + assert new_retry._on_error is retry_._on_error + + def test_with_delay_noop(self): + retry_ = retry.Retry( + predicate=mock.sentinel.predicate, + initial=1, + maximum=2, + multiplier=3, + deadline=4, + on_error=mock.sentinel.on_error, + ) + new_retry = retry_.with_delay() + assert retry_ is not new_retry + assert new_retry._initial == retry_._initial + assert new_retry._maximum == retry_._maximum + assert new_retry._multiplier == retry_._multiplier + + def test_with_delay(self): + retry_ = retry.Retry( + predicate=mock.sentinel.predicate, + initial=1, + maximum=2, + multiplier=3, + deadline=4, + on_error=mock.sentinel.on_error, + ) + new_retry = retry_.with_delay(initial=5, maximum=6, multiplier=7) + assert retry_ is not new_retry + assert new_retry._initial == 5 + assert new_retry._maximum == 6 + assert new_retry._multiplier == 7 + + # the rest of the attributes should remain the same + assert new_retry._deadline == retry_._deadline + assert new_retry._predicate is retry_._predicate + assert new_retry._on_error is retry_._on_error + + def test_with_delay_partial_options(self): + retry_ = retry.Retry( + predicate=mock.sentinel.predicate, + initial=1, + maximum=2, + multiplier=3, + deadline=4, + on_error=mock.sentinel.on_error, + ) + new_retry = retry_.with_delay(initial=4) + assert retry_ is not new_retry + assert new_retry._initial == 4 + assert new_retry._maximum == 2 + assert new_retry._multiplier == 3 + + new_retry = retry_.with_delay(maximum=4) + assert retry_ is not new_retry + assert new_retry._initial == 1 + assert new_retry._maximum == 4 + assert new_retry._multiplier == 3 + + new_retry = retry_.with_delay(multiplier=4) + assert retry_ is not new_retry + assert new_retry._initial == 1 + assert new_retry._maximum == 2 + assert new_retry._multiplier == 4 + + # the rest of the attributes should remain the same + assert new_retry._deadline == retry_._deadline + assert new_retry._predicate is retry_._predicate + assert new_retry._on_error is retry_._on_error + + def test___str__(self): + def if_exception_type(exc): + return bool(exc) # pragma: NO COVER + + # Explicitly set all attributes as changed Retry defaults should not + # cause this test to start failing. + retry_ = retry.Retry( + predicate=if_exception_type, + initial=1.0, + maximum=60.0, + multiplier=2.0, + deadline=120.0, + on_error=None, + ) + assert re.match( + ( + r"<Retry predicate=<function.*?if_exception_type.*?>, " + r"initial=1.0, maximum=60.0, multiplier=2.0, deadline=120.0, " + r"on_error=None>" + ), + str(retry_), + ) + + @mock.patch("time.sleep", autospec=True) + def test___call___and_execute_success(self, sleep): + retry_ = retry.Retry() + target = mock.Mock(spec=["__call__"], return_value=42) + # __name__ is needed by functools.partial. + target.__name__ = "target" + + decorated = retry_(target) + target.assert_not_called() + + result = decorated("meep") + + assert result == 42 + target.assert_called_once_with("meep") + sleep.assert_not_called() + + # Make uniform return half of its maximum, which is the calculated sleep time. + @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0) + @mock.patch("time.sleep", autospec=True) + def test___call___and_execute_retry(self, sleep, uniform): + + on_error = mock.Mock(spec=["__call__"], side_effect=[None]) + retry_ = retry.Retry(predicate=retry.if_exception_type(ValueError)) + + target = mock.Mock(spec=["__call__"], side_effect=[ValueError(), 42]) + # __name__ is needed by functools.partial. + target.__name__ = "target" + + decorated = retry_(target, on_error=on_error) + target.assert_not_called() + + result = decorated("meep") + + assert result == 42 + assert target.call_count == 2 + target.assert_has_calls([mock.call("meep"), mock.call("meep")]) + sleep.assert_called_once_with(retry_._initial) + assert on_error.call_count == 1 + + # Make uniform return half of its maximum, which is the calculated sleep time. + @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0) + @mock.patch("time.sleep", autospec=True) + def test___call___and_execute_retry_hitting_deadline(self, sleep, uniform): + + on_error = mock.Mock(spec=["__call__"], side_effect=[None] * 10) + retry_ = retry.Retry( + predicate=retry.if_exception_type(ValueError), + initial=1.0, + maximum=1024.0, + multiplier=2.0, + deadline=9.9, + ) + + utcnow = datetime.datetime.utcnow() + utcnow_patcher = mock.patch( + "google.api_core.datetime_helpers.utcnow", return_value=utcnow + ) + + target = mock.Mock(spec=["__call__"], side_effect=[ValueError()] * 10) + # __name__ is needed by functools.partial. + target.__name__ = "target" + + decorated = retry_(target, on_error=on_error) + target.assert_not_called() + + with utcnow_patcher as patched_utcnow: + # Make sure that calls to fake time.sleep() also advance the mocked + # time clock. + def increase_time(sleep_delay): + patched_utcnow.return_value += datetime.timedelta(seconds=sleep_delay) + + sleep.side_effect = increase_time + + with pytest.raises(exceptions.RetryError): + decorated("meep") + + assert target.call_count == 5 + target.assert_has_calls([mock.call("meep")] * 5) + assert on_error.call_count == 5 + + # check the delays + assert sleep.call_count == 4 # once between each successive target calls + last_wait = sleep.call_args.args[0] + total_wait = sum(call_args.args[0] for call_args in sleep.call_args_list) + + assert last_wait == 2.9 # and not 8.0, because the last delay was shortened + assert total_wait == 9.9 # the same as the deadline + + @mock.patch("time.sleep", autospec=True) + def test___init___without_retry_executed(self, sleep): + _some_function = mock.Mock() + + retry_ = retry.Retry( + predicate=retry.if_exception_type(ValueError), on_error=_some_function + ) + # check the proper creation of the class + assert retry_._on_error is _some_function + + target = mock.Mock(spec=["__call__"], side_effect=[42]) + # __name__ is needed by functools.partial. + target.__name__ = "target" + + wrapped = retry_(target) + + result = wrapped("meep") + + assert result == 42 + target.assert_called_once_with("meep") + sleep.assert_not_called() + _some_function.assert_not_called() + + # Make uniform return half of its maximum, which is the calculated sleep time. + @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0) + @mock.patch("time.sleep", autospec=True) + def test___init___when_retry_is_executed(self, sleep, uniform): + _some_function = mock.Mock() + + retry_ = retry.Retry( + predicate=retry.if_exception_type(ValueError), on_error=_some_function + ) + # check the proper creation of the class + assert retry_._on_error is _some_function + + target = mock.Mock( + spec=["__call__"], side_effect=[ValueError(), ValueError(), 42] + ) + # __name__ is needed by functools.partial. + target.__name__ = "target" + + wrapped = retry_(target) + target.assert_not_called() + + result = wrapped("meep") + + assert result == 42 + assert target.call_count == 3 + assert _some_function.call_count == 2 + target.assert_has_calls([mock.call("meep"), mock.call("meep")]) + sleep.assert_any_call(retry_._initial) diff --git a/tests/unit/test_timeout.py b/tests/unit/test_timeout.py new file mode 100644 index 0000000..30d624e --- /dev/null +++ b/tests/unit/test_timeout.py @@ -0,0 +1,129 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import itertools + +import mock + +from google.api_core import timeout + + +def test__exponential_timeout_generator_base_2(): + gen = timeout._exponential_timeout_generator(1.0, 60.0, 2.0, deadline=None) + + result = list(itertools.islice(gen, 8)) + assert result == [1, 2, 4, 8, 16, 32, 60, 60] + + +@mock.patch("google.api_core.datetime_helpers.utcnow", autospec=True) +def test__exponential_timeout_generator_base_deadline(utcnow): + # Make each successive call to utcnow() advance one second. + utcnow.side_effect = [ + datetime.datetime.min + datetime.timedelta(seconds=n) for n in range(15) + ] + + gen = timeout._exponential_timeout_generator(1.0, 60.0, 2.0, deadline=30.0) + + result = list(itertools.islice(gen, 14)) + # Should grow until the cumulative time is > 30s, then start decreasing as + # the cumulative time approaches 60s. + assert result == [1, 2, 4, 8, 16, 24, 23, 22, 21, 20, 19, 18, 17, 16] + + +class TestConstantTimeout(object): + def test_constructor(self): + timeout_ = timeout.ConstantTimeout() + assert timeout_._timeout is None + + def test_constructor_args(self): + timeout_ = timeout.ConstantTimeout(42.0) + assert timeout_._timeout == 42.0 + + def test___str__(self): + timeout_ = timeout.ConstantTimeout(1) + assert str(timeout_) == "<ConstantTimeout timeout=1.0>" + + def test_apply(self): + target = mock.Mock(spec=["__call__", "__name__"], __name__="target") + timeout_ = timeout.ConstantTimeout(42.0) + wrapped = timeout_(target) + + wrapped() + + target.assert_called_once_with(timeout=42.0) + + def test_apply_passthrough(self): + target = mock.Mock(spec=["__call__", "__name__"], __name__="target") + timeout_ = timeout.ConstantTimeout(42.0) + wrapped = timeout_(target) + + wrapped(1, 2, meep="moop") + + target.assert_called_once_with(1, 2, meep="moop", timeout=42.0) + + +class TestExponentialTimeout(object): + def test_constructor(self): + timeout_ = timeout.ExponentialTimeout() + assert timeout_._initial == timeout._DEFAULT_INITIAL_TIMEOUT + assert timeout_._maximum == timeout._DEFAULT_MAXIMUM_TIMEOUT + assert timeout_._multiplier == timeout._DEFAULT_TIMEOUT_MULTIPLIER + assert timeout_._deadline == timeout._DEFAULT_DEADLINE + + def test_constructor_args(self): + timeout_ = timeout.ExponentialTimeout(1, 2, 3, 4) + assert timeout_._initial == 1 + assert timeout_._maximum == 2 + assert timeout_._multiplier == 3 + assert timeout_._deadline == 4 + + def test_with_timeout(self): + original_timeout = timeout.ExponentialTimeout() + timeout_ = original_timeout.with_deadline(42) + assert original_timeout is not timeout_ + assert timeout_._initial == timeout._DEFAULT_INITIAL_TIMEOUT + assert timeout_._maximum == timeout._DEFAULT_MAXIMUM_TIMEOUT + assert timeout_._multiplier == timeout._DEFAULT_TIMEOUT_MULTIPLIER + assert timeout_._deadline == 42 + + def test___str__(self): + timeout_ = timeout.ExponentialTimeout(1, 2, 3, 4) + assert str(timeout_) == ( + "<ExponentialTimeout initial=1.0, maximum=2.0, multiplier=3.0, " + "deadline=4.0>" + ) + + def test_apply(self): + target = mock.Mock(spec=["__call__", "__name__"], __name__="target") + timeout_ = timeout.ExponentialTimeout(1, 10, 2) + wrapped = timeout_(target) + + wrapped() + target.assert_called_with(timeout=1) + + wrapped() + target.assert_called_with(timeout=2) + + wrapped() + target.assert_called_with(timeout=4) + + def test_apply_passthrough(self): + target = mock.Mock(spec=["__call__", "__name__"], __name__="target") + timeout_ = timeout.ExponentialTimeout(42.0, 100, 2) + wrapped = timeout_(target) + + wrapped(1, 2, meep="moop") + + target.assert_called_once_with(1, 2, meep="moop", timeout=42.0) |