aboutsummaryrefslogtreecommitdiff
path: root/tests/transport/test_grpc.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/transport/test_grpc.py')
-rw-r--r--tests/transport/test_grpc.py502
1 files changed, 502 insertions, 0 deletions
diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py
new file mode 100644
index 0000000..3437658
--- /dev/null
+++ b/tests/transport/test_grpc.py
@@ -0,0 +1,502 @@
+# Copyright 2016 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import os
+import time
+
+import mock
+import pytest
+
+from google.auth import _helpers
+from google.auth import credentials
+from google.auth import environment_vars
+from google.auth import exceptions
+from google.auth import transport
+from google.oauth2 import service_account
+
+try:
+ # pylint: disable=ungrouped-imports
+ import grpc
+ import google.auth.transport.grpc
+
+ HAS_GRPC = True
+except ImportError: # pragma: NO COVER
+ HAS_GRPC = False
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
+METADATA_PATH = os.path.join(DATA_DIR, "context_aware_metadata.json")
+with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
+ PRIVATE_KEY_BYTES = fh.read()
+with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh:
+ PUBLIC_CERT_BYTES = fh.read()
+
+pytestmark = pytest.mark.skipif(not HAS_GRPC, reason="gRPC is unavailable.")
+
+
+class CredentialsStub(credentials.Credentials):
+ def __init__(self, token="token"):
+ super(CredentialsStub, self).__init__()
+ self.token = token
+ self.expiry = None
+
+ def refresh(self, request):
+ self.token += "1"
+
+ def with_quota_project(self, quota_project_id):
+ raise NotImplementedError()
+
+
+class TestAuthMetadataPlugin(object):
+ def test_call_no_refresh(self):
+ credentials = CredentialsStub()
+ request = mock.create_autospec(transport.Request)
+
+ plugin = google.auth.transport.grpc.AuthMetadataPlugin(credentials, request)
+
+ context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
+ context.method_name = mock.sentinel.method_name
+ context.service_url = mock.sentinel.service_url
+ callback = mock.create_autospec(grpc.AuthMetadataPluginCallback)
+
+ plugin(context, callback)
+
+ time.sleep(2)
+
+ callback.assert_called_once_with(
+ [("authorization", "Bearer {}".format(credentials.token))], None
+ )
+
+ def test_call_refresh(self):
+ credentials = CredentialsStub()
+ credentials.expiry = datetime.datetime.min + _helpers.REFRESH_THRESHOLD
+ request = mock.create_autospec(transport.Request)
+
+ plugin = google.auth.transport.grpc.AuthMetadataPlugin(credentials, request)
+
+ context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
+ context.method_name = mock.sentinel.method_name
+ context.service_url = mock.sentinel.service_url
+ callback = mock.create_autospec(grpc.AuthMetadataPluginCallback)
+
+ plugin(context, callback)
+
+ time.sleep(2)
+
+ assert credentials.token == "token1"
+ callback.assert_called_once_with(
+ [("authorization", "Bearer {}".format(credentials.token))], None
+ )
+
+ def test__get_authorization_headers_with_service_account(self):
+ credentials = mock.create_autospec(service_account.Credentials)
+ request = mock.create_autospec(transport.Request)
+
+ plugin = google.auth.transport.grpc.AuthMetadataPlugin(credentials, request)
+
+ context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
+ context.method_name = "methodName"
+ context.service_url = "https://pubsub.googleapis.com/methodName"
+
+ plugin._get_authorization_headers(context)
+
+ credentials._create_self_signed_jwt.assert_called_once_with(None)
+
+ def test__get_authorization_headers_with_service_account_and_default_host(self):
+ credentials = mock.create_autospec(service_account.Credentials)
+ request = mock.create_autospec(transport.Request)
+
+ default_host = "pubsub.googleapis.com"
+ plugin = google.auth.transport.grpc.AuthMetadataPlugin(
+ credentials, request, default_host=default_host
+ )
+
+ context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
+ context.method_name = "methodName"
+ context.service_url = "https://pubsub.googleapis.com/methodName"
+
+ plugin._get_authorization_headers(context)
+
+ credentials._create_self_signed_jwt.assert_called_once_with(
+ "https://{}/".format(default_host)
+ )
+
+
+@mock.patch(
+ "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
+)
+@mock.patch("grpc.composite_channel_credentials", autospec=True)
+@mock.patch("grpc.metadata_call_credentials", autospec=True)
+@mock.patch("grpc.ssl_channel_credentials", autospec=True)
+@mock.patch("grpc.secure_channel", autospec=True)
+class TestSecureAuthorizedChannel(object):
+ @mock.patch(
+ "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
+ )
+ @mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+ )
+ def test_secure_authorized_channel_adc(
+ self,
+ check_dca_metadata_path,
+ read_dca_metadata_file,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ credentials = CredentialsStub()
+ request = mock.create_autospec(transport.Request)
+ target = "example.com:80"
+
+ # Mock the context aware metadata and client cert/key so mTLS SSL channel
+ # will be used.
+ check_dca_metadata_path.return_value = METADATA_PATH
+ read_dca_metadata_file.return_value = {
+ "cert_provider_command": ["some command"]
+ }
+ get_client_ssl_credentials.return_value = (
+ True,
+ PUBLIC_CERT_BYTES,
+ PRIVATE_KEY_BYTES,
+ None,
+ )
+
+ channel = None
+ with mock.patch.dict(
+ os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
+ ):
+ channel = google.auth.transport.grpc.secure_authorized_channel(
+ credentials, request, target, options=mock.sentinel.options
+ )
+
+ # Check the auth plugin construction.
+ auth_plugin = metadata_call_credentials.call_args[0][0]
+ assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin)
+ assert auth_plugin._credentials == credentials
+ assert auth_plugin._request == request
+
+ # Check the ssl channel call.
+ ssl_channel_credentials.assert_called_once_with(
+ certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
+ )
+
+ # Check the composite credentials call.
+ composite_channel_credentials.assert_called_once_with(
+ ssl_channel_credentials.return_value, metadata_call_credentials.return_value
+ )
+
+ # Check the channel call.
+ secure_channel.assert_called_once_with(
+ target,
+ composite_channel_credentials.return_value,
+ options=mock.sentinel.options,
+ )
+ assert channel == secure_channel.return_value
+
+ @mock.patch("google.auth.transport.grpc.SslCredentials", autospec=True)
+ def test_secure_authorized_channel_adc_without_client_cert_env(
+ self,
+ ssl_credentials_adc_method,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ # Test client cert won't be used if GOOGLE_API_USE_CLIENT_CERTIFICATE
+ # environment variable is not set.
+ credentials = CredentialsStub()
+ request = mock.create_autospec(transport.Request)
+ target = "example.com:80"
+
+ channel = google.auth.transport.grpc.secure_authorized_channel(
+ credentials, request, target, options=mock.sentinel.options
+ )
+
+ # Check the auth plugin construction.
+ auth_plugin = metadata_call_credentials.call_args[0][0]
+ assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin)
+ assert auth_plugin._credentials == credentials
+ assert auth_plugin._request == request
+
+ # Check the ssl channel call.
+ ssl_channel_credentials.assert_called_once()
+ ssl_credentials_adc_method.assert_not_called()
+
+ # Check the composite credentials call.
+ composite_channel_credentials.assert_called_once_with(
+ ssl_channel_credentials.return_value, metadata_call_credentials.return_value
+ )
+
+ # Check the channel call.
+ secure_channel.assert_called_once_with(
+ target,
+ composite_channel_credentials.return_value,
+ options=mock.sentinel.options,
+ )
+ assert channel == secure_channel.return_value
+
+ def test_secure_authorized_channel_explicit_ssl(
+ self,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ credentials = mock.Mock()
+ request = mock.Mock()
+ target = "example.com:80"
+ ssl_credentials = mock.Mock()
+
+ google.auth.transport.grpc.secure_authorized_channel(
+ credentials, request, target, ssl_credentials=ssl_credentials
+ )
+
+ # Since explicit SSL credentials are provided, get_client_ssl_credentials
+ # shouldn't be called.
+ assert not get_client_ssl_credentials.called
+
+ # Check the ssl channel call.
+ assert not ssl_channel_credentials.called
+
+ # Check the composite credentials call.
+ composite_channel_credentials.assert_called_once_with(
+ ssl_credentials, metadata_call_credentials.return_value
+ )
+
+ def test_secure_authorized_channel_mutual_exclusive(
+ self,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ credentials = mock.Mock()
+ request = mock.Mock()
+ target = "example.com:80"
+ ssl_credentials = mock.Mock()
+ client_cert_callback = mock.Mock()
+
+ with pytest.raises(ValueError):
+ google.auth.transport.grpc.secure_authorized_channel(
+ credentials,
+ request,
+ target,
+ ssl_credentials=ssl_credentials,
+ client_cert_callback=client_cert_callback,
+ )
+
+ def test_secure_authorized_channel_with_client_cert_callback_success(
+ self,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ credentials = mock.Mock()
+ request = mock.Mock()
+ target = "example.com:80"
+ client_cert_callback = mock.Mock()
+ client_cert_callback.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES)
+
+ with mock.patch.dict(
+ os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
+ ):
+ google.auth.transport.grpc.secure_authorized_channel(
+ credentials, request, target, client_cert_callback=client_cert_callback
+ )
+
+ client_cert_callback.assert_called_once()
+
+ # Check we are using the cert and key provided by client_cert_callback.
+ ssl_channel_credentials.assert_called_once_with(
+ certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
+ )
+
+ # Check the composite credentials call.
+ composite_channel_credentials.assert_called_once_with(
+ ssl_channel_credentials.return_value, metadata_call_credentials.return_value
+ )
+
+ @mock.patch(
+ "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
+ )
+ @mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+ )
+ def test_secure_authorized_channel_with_client_cert_callback_failure(
+ self,
+ check_dca_metadata_path,
+ read_dca_metadata_file,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ credentials = mock.Mock()
+ request = mock.Mock()
+ target = "example.com:80"
+
+ client_cert_callback = mock.Mock()
+ client_cert_callback.side_effect = Exception("callback exception")
+
+ with pytest.raises(Exception) as excinfo:
+ with mock.patch.dict(
+ os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
+ ):
+ google.auth.transport.grpc.secure_authorized_channel(
+ credentials,
+ request,
+ target,
+ client_cert_callback=client_cert_callback,
+ )
+
+ assert str(excinfo.value) == "callback exception"
+
+ def test_secure_authorized_channel_cert_callback_without_client_cert_env(
+ self,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ # Test client cert won't be used if GOOGLE_API_USE_CLIENT_CERTIFICATE
+ # environment variable is not set.
+ credentials = mock.Mock()
+ request = mock.Mock()
+ target = "example.com:80"
+ client_cert_callback = mock.Mock()
+
+ google.auth.transport.grpc.secure_authorized_channel(
+ credentials, request, target, client_cert_callback=client_cert_callback
+ )
+
+ # Check client_cert_callback is not called because GOOGLE_API_USE_CLIENT_CERTIFICATE
+ # is not set.
+ client_cert_callback.assert_not_called()
+
+ ssl_channel_credentials.assert_called_once()
+
+ # Check the composite credentials call.
+ composite_channel_credentials.assert_called_once_with(
+ ssl_channel_credentials.return_value, metadata_call_credentials.return_value
+ )
+
+
+@mock.patch("grpc.ssl_channel_credentials", autospec=True)
+@mock.patch(
+ "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
+)
+@mock.patch("google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True)
+@mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+)
+class TestSslCredentials(object):
+ def test_no_context_aware_metadata(
+ self,
+ mock_check_dca_metadata_path,
+ mock_read_dca_metadata_file,
+ mock_get_client_ssl_credentials,
+ mock_ssl_channel_credentials,
+ ):
+ # Mock that the metadata file doesn't exist.
+ mock_check_dca_metadata_path.return_value = None
+
+ with mock.patch.dict(
+ os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
+ ):
+ ssl_credentials = google.auth.transport.grpc.SslCredentials()
+
+ # Since no context aware metadata is found, we wouldn't call
+ # get_client_ssl_credentials, and the SSL channel credentials created is
+ # non mTLS.
+ assert ssl_credentials.ssl_credentials is not None
+ assert not ssl_credentials.is_mtls
+ mock_get_client_ssl_credentials.assert_not_called()
+ mock_ssl_channel_credentials.assert_called_once_with()
+
+ def test_get_client_ssl_credentials_failure(
+ self,
+ mock_check_dca_metadata_path,
+ mock_read_dca_metadata_file,
+ mock_get_client_ssl_credentials,
+ mock_ssl_channel_credentials,
+ ):
+ mock_check_dca_metadata_path.return_value = METADATA_PATH
+ mock_read_dca_metadata_file.return_value = {
+ "cert_provider_command": ["some command"]
+ }
+
+ # Mock that client cert and key are not loaded and exception is raised.
+ mock_get_client_ssl_credentials.side_effect = exceptions.ClientCertError()
+
+ with pytest.raises(exceptions.MutualTLSChannelError):
+ with mock.patch.dict(
+ os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
+ ):
+ assert google.auth.transport.grpc.SslCredentials().ssl_credentials
+
+ def test_get_client_ssl_credentials_success(
+ self,
+ mock_check_dca_metadata_path,
+ mock_read_dca_metadata_file,
+ mock_get_client_ssl_credentials,
+ mock_ssl_channel_credentials,
+ ):
+ mock_check_dca_metadata_path.return_value = METADATA_PATH
+ mock_read_dca_metadata_file.return_value = {
+ "cert_provider_command": ["some command"]
+ }
+ mock_get_client_ssl_credentials.return_value = (
+ True,
+ PUBLIC_CERT_BYTES,
+ PRIVATE_KEY_BYTES,
+ None,
+ )
+
+ with mock.patch.dict(
+ os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
+ ):
+ ssl_credentials = google.auth.transport.grpc.SslCredentials()
+
+ assert ssl_credentials.ssl_credentials is not None
+ assert ssl_credentials.is_mtls
+ mock_get_client_ssl_credentials.assert_called_once()
+ mock_ssl_channel_credentials.assert_called_once_with(
+ certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
+ )
+
+ def test_get_client_ssl_credentials_without_client_cert_env(
+ self,
+ mock_check_dca_metadata_path,
+ mock_read_dca_metadata_file,
+ mock_get_client_ssl_credentials,
+ mock_ssl_channel_credentials,
+ ):
+ # Test client cert won't be used if GOOGLE_API_USE_CLIENT_CERTIFICATE is not set.
+ ssl_credentials = google.auth.transport.grpc.SslCredentials()
+
+ assert ssl_credentials.ssl_credentials is not None
+ assert not ssl_credentials.is_mtls
+ mock_check_dca_metadata_path.assert_not_called()
+ mock_read_dca_metadata_file.assert_not_called()
+ mock_get_client_ssl_credentials.assert_not_called()
+ mock_ssl_channel_credentials.assert_called_once()