diff options
author | bojeil-google <bojeil-google@users.noreply.github.com> | 2021-07-20 10:43:13 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-20 10:43:13 -0700 |
commit | dfad66128c6ee7513e5565d39bc7b002055dd0d5 (patch) | |
tree | 5de987d030d79bb69ca9d964ab0d11c3835a860e /tests/test_downscoped.py | |
parent | df9f2f9e9ad0797c4f708c9c8c6da382af7910f3 (diff) | |
download | google-auth-library-python-dfad66128c6ee7513e5565d39bc7b002055dd0d5.tar.gz |
fix: fallback to source creds expiration in downscoped tokens (#805)
For downscoping CAB flow, the STS endpoint may not return the expiration
field for certain source credentials. The generated downscoped token
should always have the same expiration time as the source credentials.
When no `expires_in` field is returned in the response, we can just get
the expiration time from the source credentials.
Co-authored-by: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com>
Diffstat (limited to 'tests/test_downscoped.py')
-rw-r--r-- | tests/test_downscoped.py | 46 |
1 files changed, 44 insertions, 2 deletions
diff --git a/tests/test_downscoped.py b/tests/test_downscoped.py index ac60e5b..795ec29 100644 --- a/tests/test_downscoped.py +++ b/tests/test_downscoped.py @@ -80,10 +80,11 @@ CREDENTIAL_ACCESS_BOUNDARY_JSON = { class SourceCredentials(credentials.Credentials): - def __init__(self, raise_error=False): + def __init__(self, raise_error=False, expires_in=3600): super(SourceCredentials, self).__init__() self._counter = 0 self._raise_error = raise_error + self._expires_in = expires_in def refresh(self, request): if self._raise_error: @@ -93,7 +94,7 @@ class SourceCredentials(credentials.Credentials): now = _helpers.utcnow() self._counter += 1 self.token = "ACCESS_TOKEN_{}".format(self._counter) - self.expiry = now + datetime.timedelta(seconds=3600) + self.expiry = now + datetime.timedelta(seconds=self._expires_in) def make_availability_condition(expression, title=None, description=None): @@ -539,6 +540,47 @@ class TestCredentials(object): # Confirm source credentials called with the same request instance. wrapped_souce_cred_refresh.assert_called_with(request) + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_without_response_expires_in(self, unused_utcnow): + response = SUCCESS_RESPONSE.copy() + # Simulate the response is missing the expires_in field. + # The downscoped token expiration should match the source credentials + # expiration. + del response["expires_in"] + expected_expires_in = 1800 + # Simulate the source credentials generates a token with 1800 second + # expiration time. The generated downscoped token should have the same + # expiration time. + source_credentials = SourceCredentials(expires_in=expected_expires_in) + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=expected_expires_in + ) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + request_data = { + "grant_type": GRANT_TYPE, + "subject_token": "ACCESS_TOKEN_1", + "subject_token_type": SUBJECT_TOKEN_TYPE, + "requested_token_type": REQUESTED_TOKEN_TYPE, + "options": urllib.parse.quote(json.dumps(CREDENTIAL_ACCESS_BOUNDARY_JSON)), + } + request = self.make_mock_request(status=http_client.OK, data=response) + credentials = self.make_credentials(source_credentials=source_credentials) + + # Spy on calls to source credentials refresh to confirm the expected request + # instance is used. + with mock.patch.object( + source_credentials, "refresh", wraps=source_credentials.refresh + ) as wrapped_souce_cred_refresh: + credentials.refresh(request) + + self.assert_request_kwargs(request.call_args[1], headers, request_data) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + # Confirm source credentials called with the same request instance. + wrapped_souce_cred_refresh.assert_called_with(request) + def test_refresh_token_exchange_error(self): request = self.make_mock_request( status=http_client.BAD_REQUEST, data=ERROR_RESPONSE |