diff options
Diffstat (limited to 'google/auth/external_account.py')
-rw-r--r-- | google/auth/external_account.py | 67 |
1 files changed, 57 insertions, 10 deletions
diff --git a/google/auth/external_account.py b/google/auth/external_account.py index 24b93b4..f588981 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -73,6 +73,7 @@ class Credentials( quota_project_id=None, scopes=None, default_scopes=None, + workforce_pool_user_project=None, ): """Instantiates an external account credentials object. @@ -90,6 +91,11 @@ class Credentials( authorization grant. default_scopes (Optional[Sequence[str]]): Default scopes passed by a Google client library. Use 'scopes' for user-defined scopes. + workforce_pool_user_project (Optona[str]): The optional workforce pool user + project number when the credential corresponds to a workforce pool and not + a workload identity pool. The underlying principal must still have + serviceusage.services.use IAM permission to use the project for + billing/quota. Raises: google.auth.exceptions.RefreshError: If the generateAccessToken endpoint returned an error. @@ -105,6 +111,7 @@ class Credentials( self._quota_project_id = quota_project_id self._scopes = scopes self._default_scopes = default_scopes + self._workforce_pool_user_project = workforce_pool_user_project if self._client_id: self._client_auth = utils.ClientAuthentication( @@ -120,6 +127,13 @@ class Credentials( self._impersonated_credentials = None self._project_id = None + if not self.is_workforce_pool and self._workforce_pool_user_project: + # Workload identity pools do not support workforce pool user projects. + raise ValueError( + "workforce_pool_user_project should not be set for non-workforce pool " + "credentials" + ) + @property def info(self): """Generates the dictionary representation of the current credentials. @@ -140,6 +154,7 @@ class Credentials( "quota_project_id": self._quota_project_id, "client_id": self._client_id, "client_secret": self._client_secret, + "workforce_pool_user_project": self._workforce_pool_user_project, } return {key: value for key, value in config_info.items() if value is not None} @@ -178,12 +193,23 @@ class Credentials( # service account. if self._service_account_impersonation_url: return False + return self.is_workforce_pool + + @property + def is_workforce_pool(self): + """Returns whether the credentials represent a workforce pool (True) or + workload (False) based on the credentials' audience. + + This will also return True for impersonated workforce pool credentials. + + Returns: + bool: True if the credentials represent a workforce pool. False if they + represent a workload. + """ # Workforce pools representing users have the following audience format: # //iam.googleapis.com/locations/$location/workforcePools/$poolId/providers/$providerId p = re.compile(r"//iam\.googleapis\.com/locations/[^/]+/workforcePools/") - if p.match(self._audience): - return True - return False + return p.match(self._audience or "") is not None @property def requires_scopes(self): @@ -210,7 +236,7 @@ class Credentials( @_helpers.copy_docstring(credentials.Scoped) def with_scopes(self, scopes, default_scopes=None): - return self.__class__( + d = dict( audience=self._audience, subject_token_type=self._subject_token_type, token_url=self._token_url, @@ -221,7 +247,11 @@ class Credentials( quota_project_id=self._quota_project_id, scopes=scopes, default_scopes=default_scopes, + workforce_pool_user_project=self._workforce_pool_user_project, ) + if not self.is_workforce_pool: + d.pop("workforce_pool_user_project") + return self.__class__(**d) @abc.abstractmethod def retrieve_subject_token(self, request): @@ -238,7 +268,9 @@ class Credentials( raise NotImplementedError("retrieve_subject_token must be implemented") def get_project_id(self, request): - """Retrieves the project ID corresponding to the workload identity pool. + """Retrieves the project ID corresponding to the workload identity or workforce pool. + For workforce pool credentials, it returns the project ID corresponding to + the workforce_pool_user_project. When not determinable, None is returned. @@ -255,16 +287,17 @@ class Credentials( HTTP requests. Returns: Optional[str]: The project ID corresponding to the workload identity pool - if determinable. + or workforce pool if determinable. """ if self._project_id: # If already retrieved, return the cached project ID value. return self._project_id scopes = self._scopes if self._scopes is not None else self._default_scopes # Scopes are required in order to retrieve a valid access token. - if self.project_number and scopes: + project_number = self.project_number or self._workforce_pool_user_project + if project_number and scopes: headers = {} - url = _CLOUD_RESOURCE_MANAGER + self.project_number + url = _CLOUD_RESOURCE_MANAGER + project_number self.before_request(request, "GET", url, headers) response = request(url=url, method="GET", headers=headers) @@ -291,6 +324,11 @@ class Credentials( self.expiry = self._impersonated_credentials.expiry else: now = _helpers.utcnow() + additional_options = None + # Do not pass workforce_pool_user_project when client authentication + # is used. The client ID is sufficient for determining the user project. + if self._workforce_pool_user_project and not self._client_id: + additional_options = {"userProject": self._workforce_pool_user_project} response_data = self._sts_client.exchange_token( request=request, grant_type=_STS_GRANT_TYPE, @@ -299,6 +337,7 @@ class Credentials( audience=self._audience, scopes=scopes, requested_token_type=_STS_REQUESTED_TOKEN_TYPE, + additional_options=additional_options, ) self.token = response_data.get("access_token") lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) @@ -307,7 +346,7 @@ class Credentials( @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): # Return copy of instance with the provided quota project ID. - return self.__class__( + d = dict( audience=self._audience, subject_token_type=self._subject_token_type, token_url=self._token_url, @@ -318,7 +357,11 @@ class Credentials( quota_project_id=quota_project_id, scopes=self._scopes, default_scopes=self._default_scopes, + workforce_pool_user_project=self._workforce_pool_user_project, ) + if not self.is_workforce_pool: + d.pop("workforce_pool_user_project") + return self.__class__(**d) def _initialize_impersonated_credentials(self): """Generates an impersonated credentials. @@ -336,7 +379,7 @@ class Credentials( endpoint returned an error. """ # Return copy of instance with no service account impersonation. - source_credentials = self.__class__( + d = dict( audience=self._audience, subject_token_type=self._subject_token_type, token_url=self._token_url, @@ -347,7 +390,11 @@ class Credentials( quota_project_id=self._quota_project_id, scopes=self._scopes, default_scopes=self._default_scopes, + workforce_pool_user_project=self._workforce_pool_user_project, ) + if not self.is_workforce_pool: + d.pop("workforce_pool_user_project") + source_credentials = self.__class__(**d) # Determine target_principal. target_principal = self.service_account_email |