aboutsummaryrefslogtreecommitdiff
path: root/tests/test_impersonated_credentials.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_impersonated_credentials.py')
-rw-r--r--tests/test_impersonated_credentials.py553
1 files changed, 553 insertions, 0 deletions
diff --git a/tests/test_impersonated_credentials.py b/tests/test_impersonated_credentials.py
new file mode 100644
index 0000000..bc404e3
--- /dev/null
+++ b/tests/test_impersonated_credentials.py
@@ -0,0 +1,553 @@
+# Copyright 2018 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+import os
+
+import mock
+import pytest
+from six.moves import http_client
+
+from google.auth import _helpers
+from google.auth import crypt
+from google.auth import exceptions
+from google.auth import impersonated_credentials
+from google.auth import transport
+from google.auth.impersonated_credentials import Credentials
+from google.oauth2 import credentials
+from google.oauth2 import service_account
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), "", "data")
+
+with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
+ PRIVATE_KEY_BYTES = fh.read()
+
+SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json")
+
+ID_TOKEN_DATA = (
+ "eyJhbGciOiJSUzI1NiIsImtpZCI6ImRmMzc1ODkwOGI3OTIyOTNhZDk3N2Ew"
+ "Yjk5MWQ5OGE3N2Y0ZWVlY2QiLCJ0eXAiOiJKV1QifQ.eyJhdWQiOiJodHRwc"
+ "zovL2Zvby5iYXIiLCJhenAiOiIxMDIxMDE1NTA4MzQyMDA3MDg1NjgiLCJle"
+ "HAiOjE1NjQ0NzUwNTEsImlhdCI6MTU2NDQ3MTQ1MSwiaXNzIjoiaHR0cHM6L"
+ "y9hY2NvdW50cy5nb29nbGUuY29tIiwic3ViIjoiMTAyMTAxNTUwODM0MjAwN"
+ "zA4NTY4In0.redacted"
+)
+ID_TOKEN_EXPIRY = 1564475051
+
+with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
+ SERVICE_ACCOUNT_INFO = json.load(fh)
+
+SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1")
+TOKEN_URI = "https://example.com/oauth2/token"
+
+
+@pytest.fixture
+def mock_donor_credentials():
+ with mock.patch("google.oauth2._client.jwt_grant", autospec=True) as grant:
+ grant.return_value = (
+ "source token",
+ _helpers.utcnow() + datetime.timedelta(seconds=500),
+ {},
+ )
+ yield grant
+
+
+class MockResponse:
+ def __init__(self, json_data, status_code):
+ self.json_data = json_data
+ self.status_code = status_code
+
+ def json(self):
+ return self.json_data
+
+
+@pytest.fixture
+def mock_authorizedsession_sign():
+ with mock.patch(
+ "google.auth.transport.requests.AuthorizedSession.request", autospec=True
+ ) as auth_session:
+ data = {"keyId": "1", "signedBlob": "c2lnbmF0dXJl"}
+ auth_session.return_value = MockResponse(data, http_client.OK)
+ yield auth_session
+
+
+@pytest.fixture
+def mock_authorizedsession_idtoken():
+ with mock.patch(
+ "google.auth.transport.requests.AuthorizedSession.request", autospec=True
+ ) as auth_session:
+ data = {"token": ID_TOKEN_DATA}
+ auth_session.return_value = MockResponse(data, http_client.OK)
+ yield auth_session
+
+
+class TestImpersonatedCredentials(object):
+
+ SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
+ TARGET_PRINCIPAL = "impersonated@project.iam.gserviceaccount.com"
+ TARGET_SCOPES = ["https://www.googleapis.com/auth/devstorage.read_only"]
+ DELEGATES = []
+ LIFETIME = 3600
+ SOURCE_CREDENTIALS = service_account.Credentials(
+ SIGNER, SERVICE_ACCOUNT_EMAIL, TOKEN_URI
+ )
+ USER_SOURCE_CREDENTIALS = credentials.Credentials(token="ABCDE")
+ IAM_ENDPOINT_OVERRIDE = (
+ "https://us-east1-iamcredentials.googleapis.com/v1/projects/-"
+ + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL)
+ )
+
+ def make_credentials(
+ self,
+ source_credentials=SOURCE_CREDENTIALS,
+ lifetime=LIFETIME,
+ target_principal=TARGET_PRINCIPAL,
+ iam_endpoint_override=None,
+ ):
+
+ return Credentials(
+ source_credentials=source_credentials,
+ target_principal=target_principal,
+ target_scopes=self.TARGET_SCOPES,
+ delegates=self.DELEGATES,
+ lifetime=lifetime,
+ iam_endpoint_override=iam_endpoint_override,
+ )
+
+ def test_make_from_user_credentials(self):
+ credentials = self.make_credentials(
+ source_credentials=self.USER_SOURCE_CREDENTIALS
+ )
+ assert not credentials.valid
+ assert credentials.expired
+
+ def test_default_state(self):
+ credentials = self.make_credentials()
+ assert not credentials.valid
+ assert credentials.expired
+
+ def make_request(
+ self,
+ data,
+ status=http_client.OK,
+ headers=None,
+ side_effect=None,
+ use_data_bytes=True,
+ ):
+ response = mock.create_autospec(transport.Response, instance=False)
+ response.status = status
+ response.data = _helpers.to_bytes(data) if use_data_bytes else data
+ response.headers = headers or {}
+
+ request = mock.create_autospec(transport.Request, instance=False)
+ request.side_effect = side_effect
+ request.return_value = response
+
+ return request
+
+ @pytest.mark.parametrize("use_data_bytes", [True, False])
+ def test_refresh_success(self, use_data_bytes, mock_donor_credentials):
+ credentials = self.make_credentials(lifetime=None)
+ token = "token"
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body),
+ status=http_client.OK,
+ use_data_bytes=use_data_bytes,
+ )
+
+ credentials.refresh(request)
+
+ assert credentials.valid
+ assert not credentials.expired
+
+ @pytest.mark.parametrize("use_data_bytes", [True, False])
+ def test_refresh_success_iam_endpoint_override(
+ self, use_data_bytes, mock_donor_credentials
+ ):
+ credentials = self.make_credentials(
+ lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE
+ )
+ token = "token"
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body),
+ status=http_client.OK,
+ use_data_bytes=use_data_bytes,
+ )
+
+ credentials.refresh(request)
+
+ assert credentials.valid
+ assert not credentials.expired
+ # Confirm override endpoint used.
+ request_kwargs = request.call_args[1]
+ assert request_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE
+
+ @pytest.mark.parametrize("time_skew", [100, -100])
+ def test_refresh_source_credentials(self, time_skew):
+ credentials = self.make_credentials(lifetime=None)
+
+ # Source credentials is refreshed only if it is expired within
+ # _helpers.REFRESH_THRESHOLD from now. We add a time_skew to the expiry, so
+ # source credentials is refreshed only if time_skew <= 0.
+ credentials._source_credentials.expiry = (
+ _helpers.utcnow()
+ + _helpers.REFRESH_THRESHOLD
+ + datetime.timedelta(seconds=time_skew)
+ )
+ credentials._source_credentials.token = "Token"
+
+ with mock.patch(
+ "google.oauth2.service_account.Credentials.refresh", autospec=True
+ ) as source_cred_refresh:
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0)
+ + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": "token", "expireTime": expire_time}
+ request = self.make_request(
+ data=json.dumps(response_body), status=http_client.OK
+ )
+
+ credentials.refresh(request)
+
+ assert credentials.valid
+ assert not credentials.expired
+
+ # Source credentials is refreshed only if it is expired within
+ # _helpers.REFRESH_THRESHOLD
+ if time_skew > 0:
+ source_cred_refresh.assert_not_called()
+ else:
+ source_cred_refresh.assert_called_once()
+
+ def test_refresh_failure_malformed_expire_time(self, mock_donor_credentials):
+ credentials = self.make_credentials(lifetime=None)
+ token = "token"
+
+ expire_time = (_helpers.utcnow() + datetime.timedelta(seconds=500)).isoformat(
+ "T"
+ )
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body), status=http_client.OK
+ )
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.refresh(request)
+
+ assert excinfo.match(impersonated_credentials._REFRESH_ERROR)
+
+ assert not credentials.valid
+ assert credentials.expired
+
+ def test_refresh_failure_unauthorzed(self, mock_donor_credentials):
+ credentials = self.make_credentials(lifetime=None)
+
+ response_body = {
+ "error": {
+ "code": 403,
+ "message": "The caller does not have permission",
+ "status": "PERMISSION_DENIED",
+ }
+ }
+
+ request = self.make_request(
+ data=json.dumps(response_body), status=http_client.UNAUTHORIZED
+ )
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.refresh(request)
+
+ assert excinfo.match(impersonated_credentials._REFRESH_ERROR)
+
+ assert not credentials.valid
+ assert credentials.expired
+
+ def test_refresh_failure_http_error(self, mock_donor_credentials):
+ credentials = self.make_credentials(lifetime=None)
+
+ response_body = {}
+
+ request = self.make_request(
+ data=json.dumps(response_body), status=http_client.HTTPException
+ )
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.refresh(request)
+
+ assert excinfo.match(impersonated_credentials._REFRESH_ERROR)
+
+ assert not credentials.valid
+ assert credentials.expired
+
+ def test_expired(self):
+ credentials = self.make_credentials(lifetime=None)
+ assert credentials.expired
+
+ def test_signer(self):
+ credentials = self.make_credentials()
+ assert isinstance(credentials.signer, impersonated_credentials.Credentials)
+
+ def test_signer_email(self):
+ credentials = self.make_credentials(target_principal=self.TARGET_PRINCIPAL)
+ assert credentials.signer_email == self.TARGET_PRINCIPAL
+
+ def test_service_account_email(self):
+ credentials = self.make_credentials(target_principal=self.TARGET_PRINCIPAL)
+ assert credentials.service_account_email == self.TARGET_PRINCIPAL
+
+ def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign):
+ credentials = self.make_credentials(lifetime=None)
+ token = "token"
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ token_response_body = {"accessToken": token, "expireTime": expire_time}
+
+ response = mock.create_autospec(transport.Response, instance=False)
+ response.status = http_client.OK
+ response.data = _helpers.to_bytes(json.dumps(token_response_body))
+
+ request = mock.create_autospec(transport.Request, instance=False)
+ request.return_value = response
+
+ credentials.refresh(request)
+
+ assert credentials.valid
+ assert not credentials.expired
+
+ signature = credentials.sign_bytes(b"signed bytes")
+ assert signature == b"signature"
+
+ def test_sign_bytes_failure(self):
+ credentials = self.make_credentials(lifetime=None)
+
+ with mock.patch(
+ "google.auth.transport.requests.AuthorizedSession.request", autospec=True
+ ) as auth_session:
+ data = {"error": {"code": 403, "message": "unauthorized"}}
+ auth_session.return_value = MockResponse(data, http_client.FORBIDDEN)
+
+ with pytest.raises(exceptions.TransportError) as excinfo:
+ credentials.sign_bytes(b"foo")
+ assert excinfo.match("'code': 403")
+
+ def test_with_quota_project(self):
+ credentials = self.make_credentials()
+
+ quota_project_creds = credentials.with_quota_project("project-foo")
+ assert quota_project_creds._quota_project_id == "project-foo"
+
+ @pytest.mark.parametrize("use_data_bytes", [True, False])
+ def test_with_quota_project_iam_endpoint_override(
+ self, use_data_bytes, mock_donor_credentials
+ ):
+ credentials = self.make_credentials(
+ lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE
+ )
+ token = "token"
+ # iam_endpoint_override should be copied to created credentials.
+ quota_project_creds = credentials.with_quota_project("project-foo")
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body),
+ status=http_client.OK,
+ use_data_bytes=use_data_bytes,
+ )
+
+ quota_project_creds.refresh(request)
+
+ assert quota_project_creds.valid
+ assert not quota_project_creds.expired
+ # Confirm override endpoint used.
+ request_kwargs = request.call_args[1]
+ assert request_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE
+
+ def test_id_token_success(
+ self, mock_donor_credentials, mock_authorizedsession_idtoken
+ ):
+ credentials = self.make_credentials(lifetime=None)
+ token = "token"
+ target_audience = "https://foo.bar"
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body), status=http_client.OK
+ )
+
+ credentials.refresh(request)
+
+ assert credentials.valid
+ assert not credentials.expired
+
+ id_creds = impersonated_credentials.IDTokenCredentials(
+ credentials, target_audience=target_audience
+ )
+ id_creds.refresh(request)
+
+ assert id_creds.token == ID_TOKEN_DATA
+ assert id_creds.expiry == datetime.datetime.fromtimestamp(ID_TOKEN_EXPIRY)
+
+ def test_id_token_from_credential(
+ self, mock_donor_credentials, mock_authorizedsession_idtoken
+ ):
+ credentials = self.make_credentials(lifetime=None)
+ token = "token"
+ target_audience = "https://foo.bar"
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body), status=http_client.OK
+ )
+
+ credentials.refresh(request)
+
+ assert credentials.valid
+ assert not credentials.expired
+
+ id_creds = impersonated_credentials.IDTokenCredentials(
+ credentials, target_audience=target_audience, include_email=True
+ )
+ id_creds = id_creds.from_credentials(target_credentials=credentials)
+ id_creds.refresh(request)
+
+ assert id_creds.token == ID_TOKEN_DATA
+ assert id_creds._include_email is True
+
+ def test_id_token_with_target_audience(
+ self, mock_donor_credentials, mock_authorizedsession_idtoken
+ ):
+ credentials = self.make_credentials(lifetime=None)
+ token = "token"
+ target_audience = "https://foo.bar"
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body), status=http_client.OK
+ )
+
+ credentials.refresh(request)
+
+ assert credentials.valid
+ assert not credentials.expired
+
+ id_creds = impersonated_credentials.IDTokenCredentials(
+ credentials, include_email=True
+ )
+ id_creds = id_creds.with_target_audience(target_audience=target_audience)
+ id_creds.refresh(request)
+
+ assert id_creds.token == ID_TOKEN_DATA
+ assert id_creds.expiry == datetime.datetime.fromtimestamp(ID_TOKEN_EXPIRY)
+ assert id_creds._include_email is True
+
+ def test_id_token_invalid_cred(
+ self, mock_donor_credentials, mock_authorizedsession_idtoken
+ ):
+ credentials = None
+
+ with pytest.raises(exceptions.GoogleAuthError) as excinfo:
+ impersonated_credentials.IDTokenCredentials(credentials)
+
+ assert excinfo.match("Provided Credential must be" " impersonated_credentials")
+
+ def test_id_token_with_include_email(
+ self, mock_donor_credentials, mock_authorizedsession_idtoken
+ ):
+ credentials = self.make_credentials(lifetime=None)
+ token = "token"
+ target_audience = "https://foo.bar"
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body), status=http_client.OK
+ )
+
+ credentials.refresh(request)
+
+ assert credentials.valid
+ assert not credentials.expired
+
+ id_creds = impersonated_credentials.IDTokenCredentials(
+ credentials, target_audience=target_audience
+ )
+ id_creds = id_creds.with_include_email(True)
+ id_creds.refresh(request)
+
+ assert id_creds.token == ID_TOKEN_DATA
+
+ def test_id_token_with_quota_project(
+ self, mock_donor_credentials, mock_authorizedsession_idtoken
+ ):
+ credentials = self.make_credentials(lifetime=None)
+ token = "token"
+ target_audience = "https://foo.bar"
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body), status=http_client.OK
+ )
+
+ credentials.refresh(request)
+
+ assert credentials.valid
+ assert not credentials.expired
+
+ id_creds = impersonated_credentials.IDTokenCredentials(
+ credentials, target_audience=target_audience
+ )
+ id_creds = id_creds.with_quota_project("project-foo")
+ id_creds.refresh(request)
+
+ assert id_creds.quota_project_id == "project-foo"