aboutsummaryrefslogtreecommitdiff
path: root/google/auth/external_account.py
diff options
context:
space:
mode:
Diffstat (limited to 'google/auth/external_account.py')
-rw-r--r--google/auth/external_account.py67
1 files changed, 57 insertions, 10 deletions
diff --git a/google/auth/external_account.py b/google/auth/external_account.py
index 24b93b4..f588981 100644
--- a/google/auth/external_account.py
+++ b/google/auth/external_account.py
@@ -73,6 +73,7 @@ class Credentials(
quota_project_id=None,
scopes=None,
default_scopes=None,
+ workforce_pool_user_project=None,
):
"""Instantiates an external account credentials object.
@@ -90,6 +91,11 @@ class Credentials(
authorization grant.
default_scopes (Optional[Sequence[str]]): Default scopes passed by a
Google client library. Use 'scopes' for user-defined scopes.
+ workforce_pool_user_project (Optona[str]): The optional workforce pool user
+ project number when the credential corresponds to a workforce pool and not
+ a workload identity pool. The underlying principal must still have
+ serviceusage.services.use IAM permission to use the project for
+ billing/quota.
Raises:
google.auth.exceptions.RefreshError: If the generateAccessToken
endpoint returned an error.
@@ -105,6 +111,7 @@ class Credentials(
self._quota_project_id = quota_project_id
self._scopes = scopes
self._default_scopes = default_scopes
+ self._workforce_pool_user_project = workforce_pool_user_project
if self._client_id:
self._client_auth = utils.ClientAuthentication(
@@ -120,6 +127,13 @@ class Credentials(
self._impersonated_credentials = None
self._project_id = None
+ if not self.is_workforce_pool and self._workforce_pool_user_project:
+ # Workload identity pools do not support workforce pool user projects.
+ raise ValueError(
+ "workforce_pool_user_project should not be set for non-workforce pool "
+ "credentials"
+ )
+
@property
def info(self):
"""Generates the dictionary representation of the current credentials.
@@ -140,6 +154,7 @@ class Credentials(
"quota_project_id": self._quota_project_id,
"client_id": self._client_id,
"client_secret": self._client_secret,
+ "workforce_pool_user_project": self._workforce_pool_user_project,
}
return {key: value for key, value in config_info.items() if value is not None}
@@ -178,12 +193,23 @@ class Credentials(
# service account.
if self._service_account_impersonation_url:
return False
+ return self.is_workforce_pool
+
+ @property
+ def is_workforce_pool(self):
+ """Returns whether the credentials represent a workforce pool (True) or
+ workload (False) based on the credentials' audience.
+
+ This will also return True for impersonated workforce pool credentials.
+
+ Returns:
+ bool: True if the credentials represent a workforce pool. False if they
+ represent a workload.
+ """
# Workforce pools representing users have the following audience format:
# //iam.googleapis.com/locations/$location/workforcePools/$poolId/providers/$providerId
p = re.compile(r"//iam\.googleapis\.com/locations/[^/]+/workforcePools/")
- if p.match(self._audience):
- return True
- return False
+ return p.match(self._audience or "") is not None
@property
def requires_scopes(self):
@@ -210,7 +236,7 @@ class Credentials(
@_helpers.copy_docstring(credentials.Scoped)
def with_scopes(self, scopes, default_scopes=None):
- return self.__class__(
+ d = dict(
audience=self._audience,
subject_token_type=self._subject_token_type,
token_url=self._token_url,
@@ -221,7 +247,11 @@ class Credentials(
quota_project_id=self._quota_project_id,
scopes=scopes,
default_scopes=default_scopes,
+ workforce_pool_user_project=self._workforce_pool_user_project,
)
+ if not self.is_workforce_pool:
+ d.pop("workforce_pool_user_project")
+ return self.__class__(**d)
@abc.abstractmethod
def retrieve_subject_token(self, request):
@@ -238,7 +268,9 @@ class Credentials(
raise NotImplementedError("retrieve_subject_token must be implemented")
def get_project_id(self, request):
- """Retrieves the project ID corresponding to the workload identity pool.
+ """Retrieves the project ID corresponding to the workload identity or workforce pool.
+ For workforce pool credentials, it returns the project ID corresponding to
+ the workforce_pool_user_project.
When not determinable, None is returned.
@@ -255,16 +287,17 @@ class Credentials(
HTTP requests.
Returns:
Optional[str]: The project ID corresponding to the workload identity pool
- if determinable.
+ or workforce pool if determinable.
"""
if self._project_id:
# If already retrieved, return the cached project ID value.
return self._project_id
scopes = self._scopes if self._scopes is not None else self._default_scopes
# Scopes are required in order to retrieve a valid access token.
- if self.project_number and scopes:
+ project_number = self.project_number or self._workforce_pool_user_project
+ if project_number and scopes:
headers = {}
- url = _CLOUD_RESOURCE_MANAGER + self.project_number
+ url = _CLOUD_RESOURCE_MANAGER + project_number
self.before_request(request, "GET", url, headers)
response = request(url=url, method="GET", headers=headers)
@@ -291,6 +324,11 @@ class Credentials(
self.expiry = self._impersonated_credentials.expiry
else:
now = _helpers.utcnow()
+ additional_options = None
+ # Do not pass workforce_pool_user_project when client authentication
+ # is used. The client ID is sufficient for determining the user project.
+ if self._workforce_pool_user_project and not self._client_id:
+ additional_options = {"userProject": self._workforce_pool_user_project}
response_data = self._sts_client.exchange_token(
request=request,
grant_type=_STS_GRANT_TYPE,
@@ -299,6 +337,7 @@ class Credentials(
audience=self._audience,
scopes=scopes,
requested_token_type=_STS_REQUESTED_TOKEN_TYPE,
+ additional_options=additional_options,
)
self.token = response_data.get("access_token")
lifetime = datetime.timedelta(seconds=response_data.get("expires_in"))
@@ -307,7 +346,7 @@ class Credentials(
@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
# Return copy of instance with the provided quota project ID.
- return self.__class__(
+ d = dict(
audience=self._audience,
subject_token_type=self._subject_token_type,
token_url=self._token_url,
@@ -318,7 +357,11 @@ class Credentials(
quota_project_id=quota_project_id,
scopes=self._scopes,
default_scopes=self._default_scopes,
+ workforce_pool_user_project=self._workforce_pool_user_project,
)
+ if not self.is_workforce_pool:
+ d.pop("workforce_pool_user_project")
+ return self.__class__(**d)
def _initialize_impersonated_credentials(self):
"""Generates an impersonated credentials.
@@ -336,7 +379,7 @@ class Credentials(
endpoint returned an error.
"""
# Return copy of instance with no service account impersonation.
- source_credentials = self.__class__(
+ d = dict(
audience=self._audience,
subject_token_type=self._subject_token_type,
token_url=self._token_url,
@@ -347,7 +390,11 @@ class Credentials(
quota_project_id=self._quota_project_id,
scopes=self._scopes,
default_scopes=self._default_scopes,
+ workforce_pool_user_project=self._workforce_pool_user_project,
)
+ if not self.is_workforce_pool:
+ d.pop("workforce_pool_user_project")
+ source_credentials = self.__class__(**d)
# Determine target_principal.
target_principal = self.service_account_email