aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorbojeil-google <bojeil-google@users.noreply.github.com>2021-07-22 10:01:31 -0700
committerGitHub <noreply@github.com>2021-07-22 13:01:31 -0400
commitec2fb18e7f0f452fb20e43fd0bfbb788bcf7f46b (patch)
treeb093f3dd59c20c834da46e694b7bbca64ad7d9cb
parent63ac08afe44c3da51a6fb3bfc2fcdb6588e040ef (diff)
downloadgoogle-auth-library-python-ec2fb18e7f0f452fb20e43fd0bfbb788bcf7f46b.tar.gz
feat: support refresh callable on google.oauth2.credentials.Credentials (#812)
This is an optional parameter that can be set via the constructor. It is used to provide the credentials with new tokens and their expiration time on `refresh()` call. ``` def refresh_handler(request, scopes): # Generate a new token for the requested scopes by calling # an external process. return ( "ACCESS_TOKEN", _helpers.utcnow() + datetime.timedelta(seconds=3600)) creds = google.oauth2.credentials.Credentials( scopes=scopes, refresh_handler=refresh_handler) creds.refresh(request) ``` It is useful in the following cases: - Useful in general when tokens are obtained by calling some external process on demand. - Useful in particular for retrieving downscoped tokens from a token broker. This should have no impact on existing behavior. Refresh tokens will still have higher priority over refresh handlers. A getter and setter is exposed to make it easy to set the callable on unpickled credentials as the callable may not be easily serialized. ``` unpickled = pickle.loads(pickle.dumps(oauth_creds)) unpickled.refresh_handler = refresh_handler ```
-rw-r--r--google/oauth2/credentials.py71
-rw-r--r--tests/oauth2/test_credentials.py285
2 files changed, 353 insertions, 3 deletions
diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py
index dcfa5f9..158249e 100644
--- a/google/oauth2/credentials.py
+++ b/google/oauth2/credentials.py
@@ -74,6 +74,7 @@ class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaPr
quota_project_id=None,
expiry=None,
rapt_token=None,
+ refresh_handler=None,
):
"""
Args:
@@ -103,6 +104,13 @@ class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaPr
This project may be different from the project used to
create the credentials.
rapt_token (Optional[str]): The reauth Proof Token.
+ refresh_handler (Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]):
+ A callable which takes in the HTTP request callable and the list of
+ OAuth scopes and when called returns an access token string for the
+ requested scopes and its expiry datetime. This is useful when no
+ refresh tokens are provided and tokens are obtained by calling
+ some external process on demand. It is particularly useful for
+ retrieving downscoped tokens from a token broker.
"""
super(Credentials, self).__init__()
self.token = token
@@ -116,13 +124,20 @@ class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaPr
self._client_secret = client_secret
self._quota_project_id = quota_project_id
self._rapt_token = rapt_token
+ self.refresh_handler = refresh_handler
def __getstate__(self):
"""A __getstate__ method must exist for the __setstate__ to be called
This is identical to the default implementation.
See https://docs.python.org/3.7/library/pickle.html#object.__setstate__
"""
- return self.__dict__
+ state_dict = self.__dict__.copy()
+ # Remove _refresh_handler function as there are limitations pickling and
+ # unpickling certain callables (lambda, functools.partial instances)
+ # because they need to be importable.
+ # Instead, the refresh_handler setter should be used to repopulate this.
+ del state_dict["_refresh_handler"]
+ return state_dict
def __setstate__(self, d):
"""Credentials pickled with older versions of the class do not have
@@ -138,6 +153,8 @@ class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaPr
self._client_secret = d.get("_client_secret")
self._quota_project_id = d.get("_quota_project_id")
self._rapt_token = d.get("_rapt_token")
+ # The refresh_handler setter should be used to repopulate this.
+ self._refresh_handler = None
@property
def refresh_token(self):
@@ -187,6 +204,31 @@ class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaPr
"""Optional[str]: The reauth Proof Token."""
return self._rapt_token
+ @property
+ def refresh_handler(self):
+ """Returns the refresh handler if available.
+
+ Returns:
+ Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]:
+ The current refresh handler.
+ """
+ return self._refresh_handler
+
+ @refresh_handler.setter
+ def refresh_handler(self, value):
+ """Updates the current refresh handler.
+
+ Args:
+ value (Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]):
+ The updated value of the refresh handler.
+
+ Raises:
+ TypeError: If the value is not a callable or None.
+ """
+ if not callable(value) and value is not None:
+ raise TypeError("The provided refresh_handler is not a callable or None.")
+ self._refresh_handler = value
+
@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
@@ -205,6 +247,31 @@ class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaPr
@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
+ scopes = self._scopes if self._scopes is not None else self._default_scopes
+ # Use refresh handler if available and no refresh token is
+ # available. This is useful in general when tokens are obtained by calling
+ # some external process on demand. It is particularly useful for retrieving
+ # downscoped tokens from a token broker.
+ if self._refresh_token is None and self.refresh_handler:
+ token, expiry = self.refresh_handler(request, scopes=scopes)
+ # Validate returned data.
+ if not isinstance(token, str):
+ raise exceptions.RefreshError(
+ "The refresh_handler returned token is not a string."
+ )
+ if not isinstance(expiry, datetime):
+ raise exceptions.RefreshError(
+ "The refresh_handler returned expiry is not a datetime object."
+ )
+ if _helpers.utcnow() >= expiry - _helpers.CLOCK_SKEW:
+ raise exceptions.RefreshError(
+ "The credentials returned by the refresh_handler are "
+ "already expired."
+ )
+ self.token = token
+ self.expiry = expiry
+ return
+
if (
self._refresh_token is None
or self._token_uri is None
@@ -217,8 +284,6 @@ class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaPr
"token_uri, client_id, and client_secret."
)
- scopes = self._scopes if self._scopes is not None else self._default_scopes
-
(
access_token,
refresh_token,
diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py
index 4a387a5..4a7f66e 100644
--- a/tests/oauth2/test_credentials.py
+++ b/tests/oauth2/test_credentials.py
@@ -66,6 +66,50 @@ class TestCredentials(object):
assert credentials.client_id == self.CLIENT_ID
assert credentials.client_secret == self.CLIENT_SECRET
assert credentials.rapt_token == self.RAPT_TOKEN
+ assert credentials.refresh_handler is None
+
+ def test_refresh_handler_setter_and_getter(self):
+ scopes = ["email", "profile"]
+ original_refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN_1", None))
+ updated_refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN_2", None))
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=None,
+ refresh_handler=original_refresh_handler,
+ )
+
+ assert creds.refresh_handler is original_refresh_handler
+
+ creds.refresh_handler = updated_refresh_handler
+
+ assert creds.refresh_handler is updated_refresh_handler
+
+ creds.refresh_handler = None
+
+ assert creds.refresh_handler is None
+
+ def test_invalid_refresh_handler(self):
+ scopes = ["email", "profile"]
+ with pytest.raises(TypeError) as excinfo:
+ credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=None,
+ refresh_handler=object(),
+ )
+
+ assert excinfo.match("The provided refresh_handler is not a callable or None.")
@mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
@mock.patch(
@@ -131,6 +175,221 @@ class TestCredentials(object):
"google.auth._helpers.utcnow",
return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
)
+ def test_refresh_with_refresh_token_and_refresh_handler(
+ self, unused_utcnow, refresh_grant
+ ):
+ token = "token"
+ new_rapt_token = "new_rapt_token"
+ expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+ grant_response = {"id_token": mock.sentinel.id_token}
+ refresh_grant.return_value = (
+ # Access token
+ token,
+ # New refresh token
+ None,
+ # Expiry,
+ expiry,
+ # Extra data
+ grant_response,
+ # rapt_token
+ new_rapt_token,
+ )
+
+ refresh_handler = mock.Mock()
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ rapt_token=self.RAPT_TOKEN,
+ refresh_handler=refresh_handler,
+ )
+
+ # Refresh credentials
+ creds.refresh(request)
+
+ # Check jwt grant call.
+ refresh_grant.assert_called_with(
+ request,
+ self.TOKEN_URI,
+ self.REFRESH_TOKEN,
+ self.CLIENT_ID,
+ self.CLIENT_SECRET,
+ None,
+ self.RAPT_TOKEN,
+ )
+
+ # Check that the credentials have the token and expiry
+ assert creds.token == token
+ assert creds.expiry == expiry
+ assert creds.id_token == mock.sentinel.id_token
+ assert creds.rapt_token == new_rapt_token
+
+ # Check that the credentials are valid (have a token and are not
+ # expired)
+ assert creds.valid
+
+ # Assert refresh handler not called as the refresh token has
+ # higher priority.
+ refresh_handler.assert_not_called()
+
+ @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+ def test_refresh_with_refresh_handler_success_scopes(self, unused_utcnow):
+ expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800)
+ refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN", expected_expiry))
+ scopes = ["email", "profile"]
+ default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=default_scopes,
+ refresh_handler=refresh_handler,
+ )
+
+ creds.refresh(request)
+
+ assert creds.token == "ACCESS_TOKEN"
+ assert creds.expiry == expected_expiry
+ assert creds.valid
+ assert not creds.expired
+ # Confirm refresh handler called with the expected arguments.
+ refresh_handler.assert_called_with(request, scopes=scopes)
+
+ @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+ def test_refresh_with_refresh_handler_success_default_scopes(self, unused_utcnow):
+ expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800)
+ original_refresh_handler = mock.Mock(
+ return_value=("UNUSED_TOKEN", expected_expiry)
+ )
+ refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN", expected_expiry))
+ default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=None,
+ default_scopes=default_scopes,
+ refresh_handler=original_refresh_handler,
+ )
+
+ # Test newly set refresh_handler is used instead of the original one.
+ creds.refresh_handler = refresh_handler
+ creds.refresh(request)
+
+ assert creds.token == "ACCESS_TOKEN"
+ assert creds.expiry == expected_expiry
+ assert creds.valid
+ assert not creds.expired
+ # default_scopes should be used since no developer provided scopes
+ # are provided.
+ refresh_handler.assert_called_with(request, scopes=default_scopes)
+
+ @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+ def test_refresh_with_refresh_handler_invalid_token(self, unused_utcnow):
+ expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800)
+ # Simulate refresh handler does not return a valid token.
+ refresh_handler = mock.Mock(return_value=(None, expected_expiry))
+ scopes = ["email", "profile"]
+ default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=default_scopes,
+ refresh_handler=refresh_handler,
+ )
+
+ with pytest.raises(
+ exceptions.RefreshError, match="returned token is not a string"
+ ):
+ creds.refresh(request)
+
+ assert creds.token is None
+ assert creds.expiry is None
+ assert not creds.valid
+ # Confirm refresh handler called with the expected arguments.
+ refresh_handler.assert_called_with(request, scopes=scopes)
+
+ def test_refresh_with_refresh_handler_invalid_expiry(self):
+ # Simulate refresh handler returns expiration time in an invalid unit.
+ refresh_handler = mock.Mock(return_value=("TOKEN", 2800))
+ scopes = ["email", "profile"]
+ default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=default_scopes,
+ refresh_handler=refresh_handler,
+ )
+
+ with pytest.raises(
+ exceptions.RefreshError, match="returned expiry is not a datetime object"
+ ):
+ creds.refresh(request)
+
+ assert creds.token is None
+ assert creds.expiry is None
+ assert not creds.valid
+ # Confirm refresh handler called with the expected arguments.
+ refresh_handler.assert_called_with(request, scopes=scopes)
+
+ @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+ def test_refresh_with_refresh_handler_expired_token(self, unused_utcnow):
+ expected_expiry = datetime.datetime.min + _helpers.CLOCK_SKEW
+ # Simulate refresh handler returns an expired token.
+ refresh_handler = mock.Mock(return_value=("TOKEN", expected_expiry))
+ scopes = ["email", "profile"]
+ default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=default_scopes,
+ refresh_handler=refresh_handler,
+ )
+
+ with pytest.raises(exceptions.RefreshError, match="already expired"):
+ creds.refresh(request)
+
+ assert creds.token is None
+ assert creds.expiry is None
+ assert not creds.valid
+ # Confirm refresh handler called with the expected arguments.
+ refresh_handler.assert_called_with(request, scopes=scopes)
+
+ @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
+ @mock.patch(
+ "google.auth._helpers.utcnow",
+ return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+ )
def test_credentials_with_scopes_requested_refresh_success(
self, unused_utcnow, refresh_grant
):
@@ -527,6 +786,32 @@ class TestCredentials(object):
for attr in list(creds.__dict__):
assert getattr(creds, attr) == getattr(unpickled, attr)
+ def test_pickle_and_unpickle_with_refresh_handler(self):
+ expected_expiry = _helpers.utcnow() + datetime.timedelta(seconds=2800)
+ refresh_handler = mock.Mock(return_value=("TOKEN", expected_expiry))
+
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ refresh_handler=refresh_handler,
+ )
+ unpickled = pickle.loads(pickle.dumps(creds))
+
+ # make sure attributes aren't lost during pickling
+ assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort()
+
+ for attr in list(creds.__dict__):
+ # For the _refresh_handler property, the unpickled creds should be
+ # set to None.
+ if attr == "_refresh_handler":
+ assert getattr(unpickled, attr) is None
+ else:
+ assert getattr(creds, attr) == getattr(unpickled, attr)
+
def test_pickle_with_missing_attribute(self):
creds = self.make_credentials()