diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2024-04-04 07:45:46 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2024-04-04 07:53:28 +0300 |
commit | 51958dfd22674e02052c8a292ab70fc2d52a07fc (patch) | |
tree | 42d2b859c555e9045203791ed3f60fa16e1b267e /contrib/python/google-auth | |
parent | 7e62114667a5059c6b6a617d4cd9076928818478 (diff) | |
download | ydb-51958dfd22674e02052c8a292ab70fc2d52a07fc.tar.gz |
Intermediate changes
Diffstat (limited to 'contrib/python/google-auth')
9 files changed, 1174 insertions, 471 deletions
diff --git a/contrib/python/google-auth/py3/.dist-info/METADATA b/contrib/python/google-auth/py3/.dist-info/METADATA index b4012b87e3..e0785656a4 100644 --- a/contrib/python/google-auth/py3/.dist-info/METADATA +++ b/contrib/python/google-auth/py3/.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: google-auth -Version: 2.28.2 +Version: 2.29.0 Summary: Google Authentication Library Home-page: https://github.com/googleapis/google-auth-library-python Author: Google Cloud Platform diff --git a/contrib/python/google-auth/py3/google/auth/aws.py b/contrib/python/google-auth/py3/google/auth/aws.py index 6e0e4e864f..28c065d3c7 100644 --- a/contrib/python/google-auth/py3/google/auth/aws.py +++ b/contrib/python/google-auth/py3/google/auth/aws.py @@ -21,10 +21,11 @@ of long-live service account private keys. AWS Credentials are initialized using external_account arguments which are typically loaded from the external credentials JSON file. -Unlike other Credentials that can be initialized with a list of explicit -arguments, secrets or credentials, external account clients use the -environment and hints/guidelines provided by the external_account JSON -file to retrieve credentials and exchange them for Google access tokens. + +This module also provides a definition for an abstract AWS security credentials supplier. +This supplier can be implemented to return valid AWS security credentials and an AWS region +and used to create AWS credentials. The credentials will then call the +supplier instead of using pre-defined methods such as calling the EC2 metadata endpoints. This module also provides a basic implementation of the `AWS Signature Version 4`_ request signing algorithm. @@ -37,6 +38,8 @@ via the GCP STS endpoint. .. _AWS STS GetCallerIdentity: https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html """ +import abc +from dataclasses import dataclass import hashlib import hmac import http.client as http_client @@ -44,6 +47,7 @@ import json import os import posixpath import re +from typing import Optional import urllib from urllib.parse import urljoin @@ -61,6 +65,12 @@ _AWS_REQUEST_TYPE = "aws4_request" _AWS_SECURITY_TOKEN_HEADER = "x-amz-security-token" # The AWS authorization header name for the auto-generated date. _AWS_DATE_HEADER = "x-amz-date" +# The default AWS regional credential verification URL. +_DEFAULT_AWS_REGIONAL_CREDENTIAL_VERIFICATION_URL = ( + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" +) +# IMDSV2 session token lifetime. This is set to a low value because the session token is used immediately. +_IMDSV2_SESSION_TOKEN_TTL_SECONDS = "300" class RequestSigner(object): @@ -92,8 +102,7 @@ class RequestSigner(object): https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html Args: - aws_security_credentials (Mapping[str, str]): A dictionary containing - the AWS security credentials. + aws_security_credentials (AWSSecurityCredentials): The AWS security credentials. url (str): The AWS service URL containing the canonical URI and query string. method (str): The HTTP method used to call this API. @@ -105,10 +114,6 @@ class RequestSigner(object): Returns: Mapping[str, str]: The AWS signed request dictionary object. """ - # Get AWS credentials. - access_key = aws_security_credentials.get("access_key_id") - secret_key = aws_security_credentials.get("secret_access_key") - security_token = aws_security_credentials.get("security_token") additional_headers = additional_headers or {} @@ -129,9 +134,7 @@ class RequestSigner(object): canonical_querystring=_get_canonical_querystring(uri.query), method=method, region=self._region_name, - access_key=access_key, - secret_key=secret_key, - security_token=security_token, + aws_security_credentials=aws_security_credentials, request_payload=request_payload, additional_headers=additional_headers, ) @@ -147,8 +150,8 @@ class RequestSigner(object): headers[key] = additional_headers[key] # Add session token if available. - if security_token is not None: - headers[_AWS_SECURITY_TOKEN_HEADER] = security_token + if aws_security_credentials.session_token is not None: + headers[_AWS_SECURITY_TOKEN_HEADER] = aws_security_credentials.session_token signed_request = {"url": url, "method": method, "headers": headers} if request_payload: @@ -233,9 +236,7 @@ def _generate_authentication_header_map( canonical_querystring, method, region, - access_key, - secret_key, - security_token, + aws_security_credentials, request_payload="", additional_headers={}, ): @@ -248,10 +249,7 @@ def _generate_authentication_header_map( canonical_querystring (str): The AWS service URL query string. method (str): The HTTP method used to call this API. region (str): The AWS region. - access_key (str): The AWS access key ID. - secret_key (str): The AWS secret access key. - security_token (Optional[str]): The AWS security session token. This is - available for temporary sessions. + aws_security_credentials (AWSSecurityCredentials): The AWS security credentials. request_payload (Optional[str]): The optional request payload if available. additional_headers (Optional[Mapping[str, str]]): The optional @@ -274,8 +272,10 @@ def _generate_authentication_header_map( for key in additional_headers: full_headers[key.lower()] = additional_headers[key] # Add AWS session token if available. - if security_token is not None: - full_headers[_AWS_SECURITY_TOKEN_HEADER] = security_token + if aws_security_credentials.session_token is not None: + full_headers[ + _AWS_SECURITY_TOKEN_HEADER + ] = aws_security_credentials.session_token # Required headers full_headers["host"] = host @@ -321,14 +321,20 @@ def _generate_authentication_header_map( ) # https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html - signing_key = _get_signing_key(secret_key, date_stamp, region, service_name) + signing_key = _get_signing_key( + aws_security_credentials.secret_access_key, date_stamp, region, service_name + ) signature = hmac.new( signing_key, string_to_sign.encode("utf-8"), hashlib.sha256 ).hexdigest() # https://docs.aws.amazon.com/general/latest/gr/sigv4-add-signature-to-request.html authorization_header = "{} Credential={}/{}, SignedHeaders={}, Signature={}".format( - _AWS_ALGORITHM, access_key, credential_scope, signed_headers, signature + _AWS_ALGORITHM, + aws_security_credentials.access_key_id, + credential_scope, + signed_headers, + signature, ) authentication_header = {"authorization_header": authorization_header} @@ -338,211 +344,112 @@ def _generate_authentication_header_map( return authentication_header -class Credentials(external_account.Credentials): - """AWS external account credentials. - This is used to exchange serialized AWS signature v4 signed requests to - AWS STS GetCallerIdentity service for Google access tokens. - """ - - def __init__( - self, - audience, - subject_token_type, - token_url, - credential_source=None, - *args, - **kwargs - ): - """Instantiates an AWS workload external account credentials object. - - Args: - audience (str): The STS audience field. - subject_token_type (str): The subject token type. - token_url (str): The STS endpoint URL. - credential_source (Mapping): The credential source dictionary used - to provide instructions on how to retrieve external credential - to be exchanged for Google access tokens. - args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. - kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. - - Raises: - google.auth.exceptions.RefreshError: If an error is encountered during - access token retrieval logic. - ValueError: For invalid parameters. - - .. note:: Typically one of the helper constructors - :meth:`from_file` or - :meth:`from_info` are used instead of calling the constructor directly. - """ - super(Credentials, self).__init__( - audience=audience, - subject_token_type=subject_token_type, - token_url=token_url, - credential_source=credential_source, - *args, - **kwargs - ) - credential_source = credential_source or {} - self._environment_id = credential_source.get("environment_id") or "" - self._region_url = credential_source.get("region_url") - self._security_credentials_url = credential_source.get("url") - self._cred_verification_url = credential_source.get( - "regional_cred_verification_url" - ) - self._imdsv2_session_token_url = credential_source.get( - "imdsv2_session_token_url" - ) - self._region = None - self._request_signer = None - self._target_resource = audience +@dataclass +class AwsSecurityCredentials: + """A class that models AWS security credentials with an optional session token. - # Get the environment ID. Currently, only one version supported (v1). - matches = re.match(r"^(aws)([\d]+)$", self._environment_id) - if matches: - env_id, env_version = matches.groups() - else: - env_id, env_version = (None, None) + Attributes: + access_key_id (str): The AWS security credentials access key id. + secret_access_key (str): The AWS security credentials secret access key. + session_token (Optional[str]): The optional AWS security credentials session token. This should be set when using temporary credentials. + """ - if env_id != "aws" or self._cred_verification_url is None: - raise exceptions.InvalidResource( - "No valid AWS 'credential_source' provided" - ) - elif int(env_version or "") != 1: - raise exceptions.InvalidValue( - "aws version '{}' is not supported in the current build.".format( - env_version - ) - ) + access_key_id: str + secret_access_key: str + session_token: Optional[str] = None - def retrieve_subject_token(self, request): - """Retrieves the subject token using the credential_source object. - The subject token is a serialized `AWS GetCallerIdentity signed request`_. - The logic is summarized as: +class AwsSecurityCredentialsSupplier(metaclass=abc.ABCMeta): + """Base class for AWS security credential suppliers. This can be implemented with custom logic to retrieve + AWS security credentials to exchange for a Google Cloud access token. The AWS external account credential does + not cache the AWS security credentials, so caching logic should be added in the implementation. + """ - Retrieve the AWS region from the AWS_REGION or AWS_DEFAULT_REGION - environment variable or from the AWS metadata server availability-zone - if not found in the environment variable. + @abc.abstractmethod + def get_aws_security_credentials(self, context, request): + """Returns the AWS security credentials for the requested context. - Check AWS credentials in environment variables. If not found, retrieve - from the AWS metadata server security-credentials endpoint. + .. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier. - When retrieving AWS credentials from the metadata server - security-credentials endpoint, the AWS role needs to be determined by - calling the security-credentials endpoint without any argument. Then the - credentials can be retrieved via: security-credentials/role_name + Args: + context (google.auth.externalaccount.SupplierContext): The context object + containing information about the requested audience and subject token type. + request (google.auth.transport.Request): The object used to make + HTTP requests. - Generate the signed request to AWS STS GetCallerIdentity action. + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + security credential retrieval logic. - Inject x-goog-cloud-target-resource into header and serialize the - signed request. This will be the subject-token to pass to GCP STS. + Returns: + AwsSecurityCredentials: The requested AWS security credentials. + """ + raise NotImplementedError("") - .. _AWS GetCallerIdentity signed request: - https://cloud.google.com/iam/docs/access-resources-aws#exchange-token + @abc.abstractmethod + def get_aws_region(self, context, request): + """Returns the AWS region for the requested context. Args: - request (google.auth.transport.Request): A callable used to make + context (google.auth.externalaccount.SupplierContext): The context object + containing information about the requested audience and subject token type. + request (google.auth.transport.Request): The object used to make HTTP requests. - Returns: - str: The retrieved subject token. - """ - # Fetch the session token required to make meta data endpoint calls to aws. - if ( - request is not None - and self._imdsv2_session_token_url is not None - and self._should_use_metadata_server() - ): - headers = {"X-aws-ec2-metadata-token-ttl-seconds": "300"} - imdsv2_session_token_response = request( - url=self._imdsv2_session_token_url, method="PUT", headers=headers - ) + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + region retrieval logic. - if imdsv2_session_token_response.status != 200: - raise exceptions.RefreshError( - "Unable to retrieve AWS Session Token", - imdsv2_session_token_response.data, - ) + Returns: + str: The AWS region. + """ + raise NotImplementedError("") - imdsv2_session_token = imdsv2_session_token_response.data - else: - imdsv2_session_token = None - # Initialize the request signer if not yet initialized after determining - # the current AWS region. - if self._request_signer is None: - self._region = self._get_region( - request, self._region_url, imdsv2_session_token - ) - self._request_signer = RequestSigner(self._region) +class _DefaultAwsSecurityCredentialsSupplier(AwsSecurityCredentialsSupplier): + """Default implementation of AWS security credentials supplier. Supports retrieving + credentials and region via EC2 metadata endpoints and environment variables. + """ - # Retrieve the AWS security credentials needed to generate the signed - # request. - aws_security_credentials = self._get_security_credentials( - request, imdsv2_session_token - ) - # Generate the signed request to AWS STS GetCallerIdentity API. - # Use the required regional endpoint. Otherwise, the request will fail. - request_options = self._request_signer.get_request_options( - aws_security_credentials, - self._cred_verification_url.replace("{region}", self._region), - "POST", + def __init__(self, credential_source): + self._region_url = credential_source.get("region_url") + self._security_credentials_url = credential_source.get("url") + self._imdsv2_session_token_url = credential_source.get( + "imdsv2_session_token_url" ) - # The GCP STS endpoint expects the headers to be formatted as: - # [ - # {key: 'x-amz-date', value: '...'}, - # {key: 'Authorization', value: '...'}, - # ... - # ] - # And then serialized as: - # quote(json.dumps({ - # url: '...', - # method: 'POST', - # headers: [{key: 'x-amz-date', value: '...'}, ...] - # })) - request_headers = request_options.get("headers") - # The full, canonical resource name of the workload identity pool - # provider, with or without the HTTPS prefix. - # Including this header as part of the signature is recommended to - # ensure data integrity. - request_headers["x-goog-cloud-target-resource"] = self._target_resource - # Serialize AWS signed request. - # Keeping inner keys in sorted order makes testing easier for Python - # versions <=3.5 as the stringified JSON string would have a predictable - # key order. - aws_signed_req = {} - aws_signed_req["url"] = request_options.get("url") - aws_signed_req["method"] = request_options.get("method") - aws_signed_req["headers"] = [] - # Reformat header to GCP STS expected format. - for key in sorted(request_headers.keys()): - aws_signed_req["headers"].append( - {"key": key, "value": request_headers[key]} - ) + @_helpers.copy_docstring(AwsSecurityCredentialsSupplier) + def get_aws_security_credentials(self, context, request): - return urllib.parse.quote( - json.dumps(aws_signed_req, separators=(",", ":"), sort_keys=True) + # Check environment variables for permanent credentials first. + # https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html + env_aws_access_key_id = os.environ.get(environment_vars.AWS_ACCESS_KEY_ID) + env_aws_secret_access_key = os.environ.get( + environment_vars.AWS_SECRET_ACCESS_KEY ) + # This is normally not available for permanent credentials. + env_aws_session_token = os.environ.get(environment_vars.AWS_SESSION_TOKEN) + if env_aws_access_key_id and env_aws_secret_access_key: + return AwsSecurityCredentials( + env_aws_access_key_id, env_aws_secret_access_key, env_aws_session_token + ) - def _get_region(self, request, url, imdsv2_session_token): - """Retrieves the current AWS region from either the AWS_REGION or - AWS_DEFAULT_REGION environment variable or from the AWS metadata server. + imdsv2_session_token = self._get_imdsv2_session_token(request) + role_name = self._get_metadata_role_name(request, imdsv2_session_token) - Args: - request (google.auth.transport.Request): A callable used to make - HTTP requests. - url (str): The AWS metadata server region URL. - imdsv2_session_token (str): The AWS IMDSv2 session token to be added as a - header in the requests to AWS metadata endpoint. + # Get security credentials. + credentials = self._get_metadata_security_credentials( + request, role_name, imdsv2_session_token + ) - Returns: - str: The current AWS region. + return AwsSecurityCredentials( + credentials.get("AccessKeyId"), + credentials.get("SecretAccessKey"), + credentials.get("Token"), + ) - Raises: - google.auth.exceptions.RefreshError: If an error occurs while - retrieving the AWS region. - """ + @_helpers.copy_docstring(AwsSecurityCredentialsSupplier) + def get_aws_region(self, context, request): # The AWS metadata server is not available in some AWS environments # such as AWS lambda. Instead, it is available via environment # variable. @@ -558,6 +465,7 @@ class Credentials(external_account.Credentials): raise exceptions.RefreshError("Unable to determine AWS region") headers = None + imdsv2_session_token = self._get_imdsv2_session_token(request) if imdsv2_session_token is not None: headers = {"X-aws-ec2-metadata-token": imdsv2_session_token} @@ -570,62 +478,35 @@ class Credentials(external_account.Credentials): else response.data ) - if response.status != 200: + if response.status != http_client.OK: raise exceptions.RefreshError( - "Unable to retrieve AWS region", response_body + "Unable to retrieve AWS region: {}".format(response_body) ) # This endpoint will return the region in format: us-east-2b. # Only the us-east-2 part should be used. return response_body[:-1] - def _get_security_credentials(self, request, imdsv2_session_token): - """Retrieves the AWS security credentials required for signing AWS - requests from either the AWS security credentials environment variables - or from the AWS metadata server. - - Args: - request (google.auth.transport.Request): A callable used to make - HTTP requests. - imdsv2_session_token (str): The AWS IMDSv2 session token to be added as a - header in the requests to AWS metadata endpoint. - - Returns: - Mapping[str, str]: The AWS security credentials dictionary object. - - Raises: - google.auth.exceptions.RefreshError: If an error occurs while - retrieving the AWS security credentials. - """ - - # Check environment variables for permanent credentials first. - # https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html - env_aws_access_key_id = os.environ.get(environment_vars.AWS_ACCESS_KEY_ID) - env_aws_secret_access_key = os.environ.get( - environment_vars.AWS_SECRET_ACCESS_KEY - ) - # This is normally not available for permanent credentials. - env_aws_session_token = os.environ.get(environment_vars.AWS_SESSION_TOKEN) - if env_aws_access_key_id and env_aws_secret_access_key: - return { - "access_key_id": env_aws_access_key_id, - "secret_access_key": env_aws_secret_access_key, - "security_token": env_aws_session_token, + def _get_imdsv2_session_token(self, request): + if request is not None and self._imdsv2_session_token_url is not None: + headers = { + "X-aws-ec2-metadata-token-ttl-seconds": _IMDSV2_SESSION_TOKEN_TTL_SECONDS } - # Get role name. - role_name = self._get_metadata_role_name(request, imdsv2_session_token) + imdsv2_session_token_response = request( + url=self._imdsv2_session_token_url, method="PUT", headers=headers + ) - # Get security credentials. - credentials = self._get_metadata_security_credentials( - request, role_name, imdsv2_session_token - ) + if imdsv2_session_token_response.status != http_client.OK: + raise exceptions.RefreshError( + "Unable to retrieve AWS Session Token: {}".format( + imdsv2_session_token_response.data + ) + ) - return { - "access_key_id": credentials.get("AccessKeyId"), - "secret_access_key": credentials.get("SecretAccessKey"), - "security_token": credentials.get("Token"), - } + return imdsv2_session_token_response.data + else: + return None def _get_metadata_security_credentials( self, request, role_name, imdsv2_session_token @@ -669,7 +550,7 @@ class Credentials(external_account.Credentials): if response.status != http_client.OK: raise exceptions.RefreshError( - "Unable to retrieve AWS security credentials", response_body + "Unable to retrieve AWS security credentials: {}".format(response_body) ) credentials_response = json.loads(response_body) @@ -717,35 +598,232 @@ class Credentials(external_account.Credentials): if response.status != http_client.OK: raise exceptions.RefreshError( - "Unable to retrieve AWS role name", response_body + "Unable to retrieve AWS role name {}".format(response_body) ) return response_body - def _should_use_metadata_server(self): - # The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. - # The metadata server should be used if it cannot be retrieved from one of - # these environment variables. - if not os.environ.get(environment_vars.AWS_REGION) and not os.environ.get( - environment_vars.AWS_DEFAULT_REGION - ): - return True - # AWS security credentials can be retrieved from the AWS_ACCESS_KEY_ID - # and AWS_SECRET_ACCESS_KEY environment variables. The metadata server - # should be used if either of these are not available. - if not os.environ.get(environment_vars.AWS_ACCESS_KEY_ID) or not os.environ.get( - environment_vars.AWS_SECRET_ACCESS_KEY +class Credentials(external_account.Credentials): + """AWS external account credentials. + This is used to exchange serialized AWS signature v4 signed requests to + AWS STS GetCallerIdentity service for Google access tokens. + """ + + def __init__( + self, + audience, + subject_token_type, + token_url=external_account._DEFAULT_TOKEN_URL, + credential_source=None, + aws_security_credentials_supplier=None, + *args, + **kwargs + ): + """Instantiates an AWS workload external account credentials object. + + Args: + audience (str): The STS audience field. + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:aws:token-type:aws4_request” + + token_url (Optional [str]): The STS endpoint URL. If not provided, will default to "https://sts.googleapis.com/v1/token". + credential_source (Optional [Mapping]): The credential source dictionary used + to provide instructions on how to retrieve external credential to be exchanged for Google access tokens. + Either a credential source or an AWS security credentials supplier must be provided. + + Example credential_source for AWS credential:: + + { + "environment_id": "aws1", + "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone", + "url": "http://169.254.169.254/latest/meta-data/iam/security-credentials", + imdsv2_session_token_url": "http://169.254.169.254/latest/api/token" + } + + aws_security_credentials_supplier (Optional [AwsSecurityCredentialsSupplier]): Optional AWS security credentials supplier. + This will be called to supply valid AWS security credentails which will then + be exchanged for Google access tokens. Either an AWS security credentials supplier + or a credential source must be provided. + args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. + kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + access token retrieval logic. + ValueError: For invalid parameters. + + .. note:: Typically one of the helper constructors + :meth:`from_file` or + :meth:`from_info` are used instead of calling the constructor directly. + """ + super(Credentials, self).__init__( + audience=audience, + subject_token_type=subject_token_type, + token_url=token_url, + credential_source=credential_source, + *args, + **kwargs + ) + if credential_source is None and aws_security_credentials_supplier is None: + raise exceptions.InvalidValue( + "A valid credential source or AWS security credentials supplier must be provided." + ) + if ( + credential_source is not None + and aws_security_credentials_supplier is not None ): - return True + raise exceptions.InvalidValue( + "AWS credential cannot have both a credential source and an AWS security credentials supplier." + ) + + if aws_security_credentials_supplier: + self._aws_security_credentials_supplier = aws_security_credentials_supplier + # The regional cred verification URL would normally be provided through the credential source. So set it to the default one here. + self._cred_verification_url = ( + _DEFAULT_AWS_REGIONAL_CREDENTIAL_VERIFICATION_URL + ) + else: + environment_id = credential_source.get("environment_id") or "" + self._aws_security_credentials_supplier = _DefaultAwsSecurityCredentialsSupplier( + credential_source + ) + self._cred_verification_url = credential_source.get( + "regional_cred_verification_url" + ) + + # Get the environment ID, i.e. "aws1". Currently, only one version supported (1). + matches = re.match(r"^(aws)([\d]+)$", environment_id) + if matches: + env_id, env_version = matches.groups() + else: + env_id, env_version = (None, None) + + if env_id != "aws" or self._cred_verification_url is None: + raise exceptions.InvalidResource( + "No valid AWS 'credential_source' provided" + ) + elif env_version is None or int(env_version) != 1: + raise exceptions.InvalidValue( + "aws version '{}' is not supported in the current build.".format( + env_version + ) + ) + + self._target_resource = audience + self._request_signer = None + + def retrieve_subject_token(self, request): + """Retrieves the subject token using the credential_source object. + The subject token is a serialized `AWS GetCallerIdentity signed request`_. + + The logic is summarized as: + + Retrieve the AWS region from the AWS_REGION or AWS_DEFAULT_REGION + environment variable or from the AWS metadata server availability-zone + if not found in the environment variable. + + Check AWS credentials in environment variables. If not found, retrieve + from the AWS metadata server security-credentials endpoint. + + When retrieving AWS credentials from the metadata server + security-credentials endpoint, the AWS role needs to be determined by + calling the security-credentials endpoint without any argument. Then the + credentials can be retrieved via: security-credentials/role_name + + Generate the signed request to AWS STS GetCallerIdentity action. - return False + Inject x-goog-cloud-target-resource into header and serialize the + signed request. This will be the subject-token to pass to GCP STS. + + .. _AWS GetCallerIdentity signed request: + https://cloud.google.com/iam/docs/access-resources-aws#exchange-token + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + Returns: + str: The retrieved subject token. + """ + + # Initialize the request signer if not yet initialized after determining + # the current AWS region. + if self._request_signer is None: + self._region = self._aws_security_credentials_supplier.get_aws_region( + self._supplier_context, request + ) + self._request_signer = RequestSigner(self._region) + + # Retrieve the AWS security credentials needed to generate the signed + # request. + aws_security_credentials = self._aws_security_credentials_supplier.get_aws_security_credentials( + self._supplier_context, request + ) + # Generate the signed request to AWS STS GetCallerIdentity API. + # Use the required regional endpoint. Otherwise, the request will fail. + request_options = self._request_signer.get_request_options( + aws_security_credentials, + self._cred_verification_url.replace("{region}", self._region), + "POST", + ) + # The GCP STS endpoint expects the headers to be formatted as: + # [ + # {key: 'x-amz-date', value: '...'}, + # {key: 'Authorization', value: '...'}, + # ... + # ] + # And then serialized as: + # quote(json.dumps({ + # url: '...', + # method: 'POST', + # headers: [{key: 'x-amz-date', value: '...'}, ...] + # })) + request_headers = request_options.get("headers") + # The full, canonical resource name of the workload identity pool + # provider, with or without the HTTPS prefix. + # Including this header as part of the signature is recommended to + # ensure data integrity. + request_headers["x-goog-cloud-target-resource"] = self._target_resource + + # Serialize AWS signed request. + aws_signed_req = {} + aws_signed_req["url"] = request_options.get("url") + aws_signed_req["method"] = request_options.get("method") + aws_signed_req["headers"] = [] + # Reformat header to GCP STS expected format. + for key in request_headers.keys(): + aws_signed_req["headers"].append( + {"key": key, "value": request_headers[key]} + ) + + return urllib.parse.quote( + json.dumps(aws_signed_req, separators=(",", ":"), sort_keys=True) + ) def _create_default_metrics_options(self): metrics_options = super(Credentials, self)._create_default_metrics_options() metrics_options["source"] = "aws" + if self._has_custom_supplier(): + metrics_options["source"] = "programmatic" return metrics_options + def _has_custom_supplier(self): + return self._credential_source is None + + def _constructor_args(self): + args = super(Credentials, self)._constructor_args() + # If a custom supplier was used, append it to the args dict. + if self._has_custom_supplier(): + args.update( + { + "aws_security_credentials_supplier": self._aws_security_credentials_supplier + } + ) + return args + @classmethod def from_info(cls, info, **kwargs): """Creates an AWS Credentials instance from parsed external account info. @@ -761,6 +839,12 @@ class Credentials(external_account.Credentials): Raises: ValueError: For invalid parameters. """ + aws_security_credentials_supplier = info.get( + "aws_security_credentials_supplier" + ) + kwargs.update( + {"aws_security_credentials_supplier": aws_security_credentials_supplier} + ) return super(Credentials, cls).from_info(info, **kwargs) @classmethod diff --git a/contrib/python/google-auth/py3/google/auth/external_account.py b/contrib/python/google-auth/py3/google/auth/external_account.py index 0420883f86..c14001bc2b 100644 --- a/contrib/python/google-auth/py3/google/auth/external_account.py +++ b/contrib/python/google-auth/py3/google/auth/external_account.py @@ -29,6 +29,7 @@ token exchange endpoint following the `OAuth 2.0 Token Exchange`_ spec. import abc import copy +from dataclasses import dataclass import datetime import io import json @@ -50,6 +51,29 @@ _STS_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange" _STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" # Cloud resource manager URL used to retrieve project information. _CLOUD_RESOURCE_MANAGER = "https://cloudresourcemanager.googleapis.com/v1/projects/" +# Default Google sts token url. +_DEFAULT_TOKEN_URL = "https://sts.googleapis.com/v1/token" + + +@dataclass +class SupplierContext: + """A context class that contains information about the requested third party credential that is passed + to AWS security credential and subject token suppliers. + + Attributes: + subject_token_type (str): The requested subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + “urn:ietf:params:aws:token-type:aws4_request” + + audience (str): The requested audience for the subject token. + """ + + subject_token_type: str + audience: str class Credentials( @@ -88,7 +112,14 @@ class Credentials( Args: audience (str): The STS audience field. - subject_token_type (str): The subject token type. + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + “urn:ietf:params:aws:token-type:aws4_request” + token_url (str): The STS endpoint URL. credential_source (Mapping): The credential source dictionary. service_account_impersonation_url (Optional[str]): The optional service account @@ -145,11 +176,11 @@ class Credentials( self._metrics_options = self._create_default_metrics_options() - if self._service_account_impersonation_url: - self._impersonated_credentials = self._initialize_impersonated_credentials() - else: - self._impersonated_credentials = None + self._impersonated_credentials = None self._project_id = None + self._supplier_context = SupplierContext( + self._subject_token_type, self._audience + ) if not self.is_workforce_pool and self._workforce_pool_user_project: # Workload identity pools do not support workforce pool user projects. @@ -358,6 +389,10 @@ class Credentials( @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): scopes = self._scopes if self._scopes is not None else self._default_scopes + + if self._should_initialize_impersonated_credentials(): + self._impersonated_credentials = self._initialize_impersonated_credentials() + if self._impersonated_credentials: self._impersonated_credentials.refresh(request) self.token = self._impersonated_credentials.token @@ -421,6 +456,12 @@ class Credentials( new_cred._metrics_options = self._metrics_options return new_cred + def _should_initialize_impersonated_credentials(self): + return ( + self._service_account_impersonation_url is not None + and self._impersonated_credentials is None + ) + def _initialize_impersonated_credentials(self): """Generates an impersonated credentials. diff --git a/contrib/python/google-auth/py3/google/auth/identity_pool.py b/contrib/python/google-auth/py3/google/auth/identity_pool.py index a515353c37..a9ec577334 100644 --- a/contrib/python/google-auth/py3/google/auth/identity_pool.py +++ b/contrib/python/google-auth/py3/google/auth/identity_pool.py @@ -26,11 +26,13 @@ long-live service account private keys. Identity Pool Credentials are initialized using external_account arguments which are typically loaded from an external credentials file or -an external credentials URL. Unlike other Credentials that can be initialized -with a list of explicit arguments, secrets or credentials, external account -clients use the environment and hints/guidelines provided by the -external_account JSON file to retrieve credentials and exchange them for Google -access tokens. +an external credentials URL. + +This module also provides a definition for an abstract subject token supplier. +This supplier can be implemented to return a valid OIDC or SAML2.0 subject token +and used to create Identity Pool credentials. The credentials will then call the +supplier instead of using pre-defined methods such as reading a local file or +calling a URL. """ try: @@ -38,15 +40,129 @@ try: # Python 2.7 compatibility except ImportError: # pragma: NO COVER from collections import Mapping -import io +import abc import json import os +from typing import NamedTuple from google.auth import _helpers from google.auth import exceptions from google.auth import external_account +class SubjectTokenSupplier(metaclass=abc.ABCMeta): + """Base class for subject token suppliers. This can be implemented with custom logic to retrieve + a subject token to exchange for a Google Cloud access token when using Workload or + Workforce Identity Federation. The identity pool credential does not cache the subject token, + so caching logic should be added in the implementation. + """ + + @abc.abstractmethod + def get_subject_token(self, context, request): + """Returns the requested subject token. The subject token must be valid. + + .. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier. + + Args: + context (google.auth.externalaccount.SupplierContext): The context object + containing information about the requested audience and subject token type. + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + subject token retrieval logic. + + Returns: + str: The requested subject token string. + """ + raise NotImplementedError("") + + +class _TokenContent(NamedTuple): + """Models the token content response from file and url internal suppliers. + Attributes: + content (str): The string content of the file or URL response. + location (str): The location the content was retrieved from. This will either be a file location or a URL. + """ + + content: str + location: str + + +class _FileSupplier(SubjectTokenSupplier): + """ Internal implementation of subject token supplier which supports reading a subject token from a file.""" + + def __init__(self, path, format_type, subject_token_field_name): + self._path = path + self._format_type = format_type + self._subject_token_field_name = subject_token_field_name + + @_helpers.copy_docstring(SubjectTokenSupplier) + def get_subject_token(self, context, request): + if not os.path.exists(self._path): + raise exceptions.RefreshError("File '{}' was not found.".format(self._path)) + + with open(self._path, "r", encoding="utf-8") as file_obj: + token_content = _TokenContent(file_obj.read(), self._path) + + return _parse_token_data( + token_content, self._format_type, self._subject_token_field_name + ) + + +class _UrlSupplier(SubjectTokenSupplier): + """ Internal implementation of subject token supplier which supports retrieving a subject token by calling a URL endpoint.""" + + def __init__(self, url, format_type, subject_token_field_name, headers): + self._url = url + self._format_type = format_type + self._subject_token_field_name = subject_token_field_name + self._headers = headers + + @_helpers.copy_docstring(SubjectTokenSupplier) + def get_subject_token(self, context, request): + response = request(url=self._url, method="GET", headers=self._headers) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != 200: + raise exceptions.RefreshError( + "Unable to retrieve Identity Pool subject token", response_body + ) + token_content = _TokenContent(response_body, self._url) + return _parse_token_data( + token_content, self._format_type, self._subject_token_field_name + ) + + +def _parse_token_data(token_content, format_type="text", subject_token_field_name=None): + if format_type == "text": + token = token_content.content + else: + try: + # Parse file content as JSON. + response_data = json.loads(token_content.content) + # Get the subject_token. + token = response_data[subject_token_field_name] + except (KeyError, ValueError): + raise exceptions.RefreshError( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + token_content.location, subject_token_field_name + ) + ) + if not token: + raise exceptions.RefreshError( + "Missing subject_token in the credential_source file" + ) + return token + + class Credentials(external_account.Credentials): """External account credentials sourced from files and URLs.""" @@ -54,8 +170,9 @@ class Credentials(external_account.Credentials): self, audience, subject_token_type, - token_url, - credential_source, + token_url=external_account._DEFAULT_TOKEN_URL, + credential_source=None, + subject_token_supplier=None, *args, **kwargs ): @@ -63,11 +180,18 @@ class Credentials(external_account.Credentials): Args: audience (str): The STS audience field. - subject_token_type (str): The subject token type. - token_url (str): The STS endpoint URL. - credential_source (Mapping): The credential source dictionary used to + subject_token_type (str): The subject token type based on the Oauth2.0 token exchange spec. + Expected values include:: + + “urn:ietf:params:oauth:token-type:jwt” + “urn:ietf:params:oauth:token-type:id-token” + “urn:ietf:params:oauth:token-type:saml2” + + token_url (Optional [str]): The STS endpoint URL. If not provided, will default to "https://sts.googleapis.com/v1/token". + credential_source (Optional [Mapping]): The credential source dictionary used to provide instructions on how to retrieve external credential to be - exchanged for Google access tokens. + exchanged for Google access tokens. Either a credential source or + a subject token supplier must be provided. Example credential_source for url-sourced credential:: @@ -85,6 +209,10 @@ class Credentials(external_account.Credentials): { "file": "/path/to/token/file.txt" } + subject_token_supplier (Optional [SubjectTokenSupplier]): Optional subject token supplier. + This will be called to supply a valid subject token which will then + be exchanged for Google access tokens. Either a subject token supplier + or a credential source must be provided. args (List): Optional positional arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. kwargs (Mapping): Optional keyword arguments passed into the underlying :meth:`~external_account.Credentials.__init__` method. @@ -106,10 +234,25 @@ class Credentials(external_account.Credentials): *args, **kwargs ) - if not isinstance(credential_source, Mapping): + if credential_source is None and subject_token_supplier is None: + raise exceptions.InvalidValue( + "A valid credential source or a subject token supplier must be provided." + ) + if credential_source is not None and subject_token_supplier is not None: + raise exceptions.InvalidValue( + "Identity pool credential cannot have both a credential source and a subject token supplier." + ) + + if subject_token_supplier is not None: + self._subject_token_supplier = subject_token_supplier self._credential_source_file = None self._credential_source_url = None else: + if not isinstance(credential_source, Mapping): + self._credential_source_executable = None + raise exceptions.MalformedError( + "Invalid credential_source. The credential_source is not a dict." + ) self._credential_source_file = credential_source.get("file") self._credential_source_url = credential_source.get("url") self._credential_source_headers = credential_source.get("headers") @@ -143,79 +286,35 @@ class Credentials(external_account.Credentials): else: self._credential_source_field_name = None - if self._credential_source_file and self._credential_source_url: - raise exceptions.MalformedError( - "Ambiguous credential_source. 'file' is mutually exclusive with 'url'." - ) - if not self._credential_source_file and not self._credential_source_url: - raise exceptions.MalformedError( - "Missing credential_source. A 'file' or 'url' must be provided." - ) + if self._credential_source_file and self._credential_source_url: + raise exceptions.MalformedError( + "Ambiguous credential_source. 'file' is mutually exclusive with 'url'." + ) + if not self._credential_source_file and not self._credential_source_url: + raise exceptions.MalformedError( + "Missing credential_source. A 'file' or 'url' must be provided." + ) + + if self._credential_source_file: + self._subject_token_supplier = _FileSupplier( + self._credential_source_file, + self._credential_source_format_type, + self._credential_source_field_name, + ) + else: + self._subject_token_supplier = _UrlSupplier( + self._credential_source_url, + self._credential_source_format_type, + self._credential_source_field_name, + self._credential_source_headers, + ) @_helpers.copy_docstring(external_account.Credentials) def retrieve_subject_token(self, request): - return self._parse_token_data( - self._get_token_data(request), - self._credential_source_format_type, - self._credential_source_field_name, - ) - - def _get_token_data(self, request): - if self._credential_source_file: - return self._get_file_data(self._credential_source_file) - else: - return self._get_url_data( - request, self._credential_source_url, self._credential_source_headers - ) - - def _get_file_data(self, filename): - if not os.path.exists(filename): - raise exceptions.RefreshError("File '{}' was not found.".format(filename)) - - with io.open(filename, "r", encoding="utf-8") as file_obj: - return file_obj.read(), filename - - def _get_url_data(self, request, url, headers): - response = request(url=url, method="GET", headers=headers) - - # support both string and bytes type response.data - response_body = ( - response.data.decode("utf-8") - if hasattr(response.data, "decode") - else response.data + return self._subject_token_supplier.get_subject_token( + self._supplier_context, request ) - if response.status != 200: - raise exceptions.RefreshError( - "Unable to retrieve Identity Pool subject token", response_body - ) - - return response_body, url - - def _parse_token_data( - self, token_content, format_type="text", subject_token_field_name=None - ): - content, filename = token_content - if format_type == "text": - token = content - else: - try: - # Parse file content as JSON. - response_data = json.loads(content) - # Get the subject_token. - token = response_data[subject_token_field_name] - except (KeyError, ValueError): - raise exceptions.RefreshError( - "Unable to parse subject_token from JSON file '{}' using key '{}'".format( - filename, subject_token_field_name - ) - ) - if not token: - raise exceptions.RefreshError( - "Missing subject_token in the credential_source file" - ) - return token - def _create_default_metrics_options(self): metrics_options = super(Credentials, self)._create_default_metrics_options() # Check that credential source is a dict before checking for file vs url. This check needs to be done @@ -226,8 +325,20 @@ class Credentials(external_account.Credentials): metrics_options["source"] = "file" else: metrics_options["source"] = "url" + else: + metrics_options["source"] = "programmatic" return metrics_options + def _has_custom_supplier(self): + return self._credential_source is None + + def _constructor_args(self): + args = super(Credentials, self)._constructor_args() + # If a custom supplier was used, append it to the args dict. + if self._has_custom_supplier(): + args.update({"subject_token_supplier": self._subject_token_supplier}) + return args + @classmethod def from_info(cls, info, **kwargs): """Creates an Identity Pool Credentials instance from parsed external account info. @@ -244,6 +355,8 @@ class Credentials(external_account.Credentials): Raises: ValueError: For invalid parameters. """ + subject_token_supplier = info.get("subject_token_supplier") + kwargs.update({"subject_token_supplier": subject_token_supplier}) return super(Credentials, cls).from_info(info, **kwargs) @classmethod diff --git a/contrib/python/google-auth/py3/google/auth/version.py b/contrib/python/google-auth/py3/google/auth/version.py index 0959c75419..f0dd919dca 100644 --- a/contrib/python/google-auth/py3/google/auth/version.py +++ b/contrib/python/google-auth/py3/google/auth/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.28.2" +__version__ = "2.29.0" diff --git a/contrib/python/google-auth/py3/tests/test_aws.py b/contrib/python/google-auth/py3/tests/test_aws.py index 3f358d52b0..5614820312 100644 --- a/contrib/python/google-auth/py3/tests/test_aws.py +++ b/contrib/python/google-auth/py3/tests/test_aws.py @@ -21,7 +21,7 @@ import urllib.parse import mock import pytest # type: ignore -from google.auth import _helpers +from google.auth import _helpers, external_account from google.auth import aws from google.auth import environment_vars from google.auth import exceptions @@ -616,8 +616,13 @@ class TestRequestSigner(object): ): utcnow.return_value = datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%SZ") request_signer = aws.RequestSigner(region) + credentials_object = aws.AwsSecurityCredentials( + credentials.get("access_key_id"), + credentials.get("secret_access_key"), + credentials.get("security_token"), + ) actual_signed_request = request_signer.get_request_options( - credentials, + credentials_object, original_request.get("url"), original_request.get("method"), original_request.get("data"), @@ -631,10 +636,7 @@ class TestRequestSigner(object): with pytest.raises(ValueError) as excinfo: request_signer.get_request_options( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - }, + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY), "invalid", "POST", ) @@ -646,10 +648,7 @@ class TestRequestSigner(object): with pytest.raises(ValueError) as excinfo: request_signer.get_request_options( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - }, + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY), "http://invalid", "POST", ) @@ -661,10 +660,7 @@ class TestRequestSigner(object): with pytest.raises(ValueError) as excinfo: request_signer.get_request_options( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - }, + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY), "https://", "POST", ) @@ -672,6 +668,36 @@ class TestRequestSigner(object): assert excinfo.match(r"Invalid AWS service URL") +class TestAwsSecurityCredentialsSupplier(aws.AwsSecurityCredentialsSupplier): + def __init__( + self, + security_credentials=None, + region=None, + credentials_exception=None, + region_exception=None, + expected_context=None, + ): + self._security_credentials = security_credentials + self._region = region + self._credentials_exception = credentials_exception + self._region_exception = region_exception + self._expected_context = expected_context + + def get_aws_security_credentials(self, context, request): + if self._expected_context is not None: + assert self._expected_context == context + if self._credentials_exception is not None: + raise self._credentials_exception + return self._security_credentials + + def get_aws_region(self, context, request): + if self._expected_context is not None: + assert self._expected_context == context + if self._region_exception is not None: + raise self._region_exception + return self._region + + class TestCredentials(object): AWS_REGION = "us-east-2" AWS_ROLE = "gcp-aws-role" @@ -734,7 +760,7 @@ class TestCredentials(object): ], } # Include security token if available. - if "security_token" in aws_security_credentials: + if aws_security_credentials.session_token is not None: reformatted_signed_request.get("headers").append( { "key": "x-amz-security-token", @@ -773,16 +799,17 @@ class TestCredentials(object): in an AWS environment. """ responses = [] - if imdsv2_session_token_status: - # AWS session token request - imdsv2_session_response = mock.create_autospec( - transport.Response, instance=True - ) - imdsv2_session_response.status = imdsv2_session_token_status - imdsv2_session_response.data = imdsv2_session_token_data - responses.append(imdsv2_session_response) if region_status: + if imdsv2_session_token_status: + # AWS session token request + imdsv2_session_response = mock.create_autospec( + transport.Response, instance=True + ) + imdsv2_session_response.status = imdsv2_session_token_status + imdsv2_session_response.data = imdsv2_session_token_data + responses.append(imdsv2_session_response) + # AWS region request. region_response = mock.create_autospec(transport.Response, instance=True) region_response.status = region_status @@ -790,6 +817,15 @@ class TestCredentials(object): region_response.data = "{}b".format(region_name).encode("utf-8") responses.append(region_response) + if imdsv2_session_token_status: + # AWS session token request + imdsv2_session_response = mock.create_autospec( + transport.Response, instance=True + ) + imdsv2_session_response.status = imdsv2_session_token_status + imdsv2_session_response.data = imdsv2_session_token_data + responses.append(imdsv2_session_response) + if role_status: # AWS role name request. role_response = mock.create_autospec(transport.Response, instance=True) @@ -834,7 +870,8 @@ class TestCredentials(object): @classmethod def make_credentials( cls, - credential_source, + credential_source=None, + aws_security_credentials_supplier=None, token_url=TOKEN_URL, token_info_url=TOKEN_INFO_URL, client_id=None, @@ -851,6 +888,7 @@ class TestCredentials(object): token_info_url=token_info_url, service_account_impersonation_url=service_account_impersonation_url, credential_source=credential_source, + aws_security_credentials_supplier=aws_security_credentials_supplier, client_id=client_id, client_secret=client_secret, quota_project_id=quota_project_id, @@ -929,6 +967,7 @@ class TestCredentials(object): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -957,6 +996,38 @@ class TestCredentials(object): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, + quota_project_id=None, + workforce_pool_user_project=None, + universe_domain=DEFAULT_UNIVERSE_DOMAIN, + ) + + @mock.patch.object(aws.Credentials, "__init__", return_value=None) + def test_from_info_supplier(self, mock_init): + supplier = TestAwsSecurityCredentialsSupplier() + + credentials = aws.Credentials.from_info( + { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "aws_security_credentials_supplier": supplier, + } + ) + + # Confirm aws.Credentials instance initialized with the expected parameters. + assert isinstance(credentials, aws.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + token_info_url=None, + service_account_impersonation_url=None, + service_account_impersonation_options={}, + client_id=None, + client_secret=None, + credential_source=None, + aws_security_credentials_supplier=supplier, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -993,6 +1064,7 @@ class TestCredentials(object): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -1022,6 +1094,7 @@ class TestCredentials(object): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier=None, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -1036,6 +1109,27 @@ class TestCredentials(object): assert excinfo.match(r"No valid AWS 'credential_source' provided") + def test_constructor_invalid_credential_source_and_supplier(self): + # Provide both a credential source and supplier. + with pytest.raises(ValueError) as excinfo: + self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE, + aws_security_credentials_supplier="test", + ) + + assert excinfo.match( + r"AWS credential cannot have both a credential source and an AWS security credentials supplier." + ) + + def test_constructor_invalid_no_credential_source_or_supplier(self): + # Provide no credential source or supplier. + with pytest.raises(ValueError) as excinfo: + self.make_credentials() + + assert excinfo.match( + r"A valid credential source or AWS security credentials supplier must be provided." + ) + def test_constructor_invalid_environment_id(self): # Provide invalid environment_id. credential_source = self.CREDENTIAL_SOURCE.copy() @@ -1158,11 +1252,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert region request. self.assert_aws_metadata_request_kwargs( @@ -1231,11 +1321,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert session token request self.assert_aws_metadata_request_kwargs( @@ -1250,15 +1336,22 @@ class TestCredentials(object): REGION_URL, {"X-aws-ec2-metadata-token": self.AWS_IMDSV2_SESSION_TOKEN}, ) - # Assert role request. + # Assert session token request self.assert_aws_metadata_request_kwargs( request.call_args_list[2][1], + IMDSV2_SESSION_TOKEN_URL, + {"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + "PUT", + ) + # Assert role request. + self.assert_aws_metadata_request_kwargs( + request.call_args_list[3][1], SECURITY_CREDS_URL, {"X-aws-ec2-metadata-token": self.AWS_IMDSV2_SESSION_TOKEN}, ) # Assert security credentials request. self.assert_aws_metadata_request_kwargs( - request.call_args_list[3][1], + request.call_args_list[4][1], "{}/{}".format(SECURITY_CREDS_URL, self.AWS_ROLE), { "Content-Type": "application/json", @@ -1335,11 +1428,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert session token request. self.assert_aws_metadata_request_kwargs( @@ -1396,11 +1485,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert session token request. self.assert_aws_metadata_request_kwargs( @@ -1451,11 +1536,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert session token request. self.assert_aws_metadata_request_kwargs( @@ -1530,11 +1611,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) # Assert session token request. self.assert_aws_metadata_request_kwargs( @@ -1549,15 +1626,22 @@ class TestCredentials(object): REGION_URL_IPV6, {"X-aws-ec2-metadata-token": self.AWS_IMDSV2_SESSION_TOKEN}, ) - # Assert role request. + # Assert session token request. self.assert_aws_metadata_request_kwargs( request.call_args_list[2][1], + IMDSV2_SESSION_TOKEN_URL_IPV6, + {"X-aws-ec2-metadata-token-ttl-seconds": "300"}, + "PUT", + ) + # Assert role request. + self.assert_aws_metadata_request_kwargs( + request.call_args_list[3][1], SECURITY_CREDS_URL_IPV6, {"X-aws-ec2-metadata-token": self.AWS_IMDSV2_SESSION_TOKEN}, ) # Assert security credentials request. self.assert_aws_metadata_request_kwargs( - request.call_args_list[3][1], + request.call_args_list[4][1], "{}/{}".format(SECURITY_CREDS_URL_IPV6, self.AWS_ROLE), { "Content-Type": "application/json", @@ -1619,7 +1703,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - {"access_key_id": ACCESS_KEY_ID, "secret_access_key": SECRET_ACCESS_KEY} + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY) ) @mock.patch("google.auth._helpers.utcnow") @@ -1636,11 +1720,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(None) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) @mock.patch("google.auth._helpers.utcnow") @@ -1659,11 +1739,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(None) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) @mock.patch("google.auth._helpers.utcnow") @@ -1686,11 +1762,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(None) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) @mock.patch("google.auth._helpers.utcnow") @@ -1708,7 +1780,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(None) assert subject_token == self.make_serialized_aws_signed_request( - {"access_key_id": ACCESS_KEY_ID, "secret_access_key": SECRET_ACCESS_KEY} + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY) ) @mock.patch("google.auth._helpers.utcnow") @@ -1730,11 +1802,7 @@ class TestCredentials(object): subject_token = credentials.retrieve_subject_token(request) assert subject_token == self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) def test_retrieve_subject_token_error_determining_aws_region(self): @@ -1806,11 +1874,7 @@ class TestCredentials(object): self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" ) expected_subject_token = self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) token_headers = { "Content-Type": "application/x-www-form-urlencoded", @@ -1869,11 +1933,7 @@ class TestCredentials(object): self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" ) expected_subject_token = self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) token_headers = { "Content-Type": "application/x-www-form-urlencoded", @@ -1939,11 +1999,7 @@ class TestCredentials(object): _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) ).isoformat("T") + "Z" expected_subject_token = self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) token_headers = { "Content-Type": "application/x-www-form-urlencoded", @@ -2036,11 +2092,7 @@ class TestCredentials(object): _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) ).isoformat("T") + "Z" expected_subject_token = self.make_serialized_aws_signed_request( - { - "access_key_id": ACCESS_KEY_ID, - "secret_access_key": SECRET_ACCESS_KEY, - "security_token": TOKEN, - } + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) ) token_headers = { "Content-Type": "application/x-www-form-urlencoded", @@ -2122,3 +2174,249 @@ class TestCredentials(object): credentials.refresh(request) assert excinfo.match(r"Unable to retrieve AWS region") + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_with_supplier(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request() + + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, region=self.AWS_REGION + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + subject_token = credentials.retrieve_subject_token(request) + assert subject_token == self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY) + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_with_supplier_session_token(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request() + + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, region=self.AWS_REGION + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + subject_token = credentials.retrieve_subject_token(request) + assert subject_token == self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_with_supplier_correct_context(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request() + expected_context = external_account.SupplierContext( + SUBJECT_TOKEN_TYPE, AUDIENCE + ) + + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, + region=self.AWS_REGION, + expected_context=expected_context, + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + credentials.retrieve_subject_token(request) + + def test_retrieve_subject_token_error_with_supplier(self): + request = self.make_mock_request() + expected_exception = exceptions.RefreshError("Test error") + supplier = TestAwsSecurityCredentialsSupplier( + region=self.AWS_REGION, credentials_exception=expected_exception + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(request) + + assert excinfo.match(r"Test error") + + def test_retrieve_subject_token_error_with_supplier_region(self): + request = self.make_mock_request() + expected_exception = exceptions.RefreshError("Test error") + security_credentials = aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY + ) + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=security_credentials, + region_exception=expected_exception, + ) + + credentials = self.make_credentials(aws_security_credentials_supplier=supplier) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(request) + + assert excinfo.match(r"Test error") + + @mock.patch( + "google.auth.metrics.python_and_auth_lib_version", + return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, + ) + @mock.patch("google.auth._helpers.utcnow") + def test_refresh_success_with_supplier_with_impersonation( + self, utcnow, mock_auth_lib_value + ): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + expected_subject_token = self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) + ) + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + "x-goog-api-client": "gl-python/3.7 auth/1.1 google-byoid-sdk sa-impersonation/true config-lifetime/false source/programmatic", + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": "https://www.googleapis.com/auth/iam", + "subject_token": expected_subject_token, + "subject_token_type": SUBJECT_TOKEN_TYPE, + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + "x-goog-user-project": QUOTA_PROJECT_ID, + "x-goog-api-client": IMPERSONATE_ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, + "x-allowed-locations": "0x0", + } + impersonation_request_data = { + "delegates": None, + "scope": SCOPES, + "lifetime": "3600s", + } + request = self.make_mock_request( + token_status=http_client.OK, + token_data=self.SUCCESS_RESPONSE, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN + ), + region=self.AWS_REGION, + ) + + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + aws_security_credentials_supplier=supplier, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + quota_project_id=QUOTA_PROJECT_ID, + scopes=SCOPES, + # Default scopes should be ignored. + default_scopes=["ignored"], + ) + + credentials.refresh(request) + + assert len(request.call_args_list) == 2 + # First request should be sent to GCP STS endpoint. + self.assert_token_request_kwargs( + request.call_args_list[0][1], token_headers, token_request_data + ) + # Second request should be sent to iamcredentials endpoint for service + # account impersonation. + self.assert_impersonation_request_kwargs( + request.call_args_list[1][1], + impersonation_headers, + impersonation_request_data, + ) + assert credentials.token == impersonation_response["accessToken"] + assert credentials.quota_project_id == QUOTA_PROJECT_ID + assert credentials.scopes == SCOPES + assert credentials.default_scopes == ["ignored"] + + @mock.patch( + "google.auth.metrics.python_and_auth_lib_version", + return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, + ) + @mock.patch("google.auth._helpers.utcnow") + def test_refresh_success_with_supplier(self, utcnow, mock_auth_lib_value): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + expected_subject_token = self.make_serialized_aws_signed_request( + aws.AwsSecurityCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN) + ) + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + "x-goog-api-client": "gl-python/3.7 auth/1.1 google-byoid-sdk sa-impersonation/false config-lifetime/false source/programmatic", + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": " ".join(SCOPES), + "subject_token": expected_subject_token, + "subject_token_type": SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request( + token_status=http_client.OK, token_data=self.SUCCESS_RESPONSE + ) + + supplier = TestAwsSecurityCredentialsSupplier( + security_credentials=aws.AwsSecurityCredentials( + ACCESS_KEY_ID, SECRET_ACCESS_KEY, TOKEN + ), + region=self.AWS_REGION, + ) + + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + aws_security_credentials_supplier=supplier, + quota_project_id=QUOTA_PROJECT_ID, + scopes=SCOPES, + # Default scopes should be ignored. + default_scopes=["ignored"], + ) + + credentials.refresh(request) + + assert len(request.call_args_list) == 1 + # First request should be sent to GCP STS endpoint. + self.assert_token_request_kwargs( + request.call_args_list[0][1], token_headers, token_request_data + ) + assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + assert credentials.quota_project_id == QUOTA_PROJECT_ID + assert credentials.scopes == SCOPES + assert credentials.default_scopes == ["ignored"] diff --git a/contrib/python/google-auth/py3/tests/test_external_account.py b/contrib/python/google-auth/py3/tests/test_external_account.py index 03a5014ce5..c458b21b64 100644 --- a/contrib/python/google-auth/py3/tests/test_external_account.py +++ b/contrib/python/google-auth/py3/tests/test_external_account.py @@ -477,16 +477,6 @@ class TestCredentials(object): universe_domain=DEFAULT_UNIVERSE_DOMAIN, ) - def test_with_invalid_impersonation_target_principal(self): - invalid_url = "https://iamcredentials.googleapis.com/v1/invalid" - - with pytest.raises(exceptions.RefreshError) as excinfo: - self.make_credentials(service_account_impersonation_url=invalid_url) - - assert excinfo.match( - r"Unable to determine target principal from service account impersonation URL." - ) - def test_info(self): credentials = self.make_credentials(universe_domain="dummy_universe.com") @@ -1069,6 +1059,21 @@ class TestCredentials(object): assert not credentials.expired assert credentials.token is None + def test_refresh_impersonation_invalid_impersonated_url_error(self): + credentials = self.make_credentials( + service_account_impersonation_url="https://iamcredentials.googleapis.com/v1/invalid", + scopes=self.SCOPES, + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(None) + + assert excinfo.match( + r"Unable to determine target principal from service account impersonation URL." + ) + assert not credentials.expired + assert credentials.token is None + @mock.patch( "google.auth.metrics.python_and_auth_lib_version", return_value=LANG_LIBRARY_METRICS_HEADER_VALUE, @@ -1913,3 +1918,10 @@ class TestCredentials(object): assert project_id is None # Only 2 requests to STS and cloud resource manager should be sent. assert len(request.call_args_list) == 2 + + +def test_supplier_context(): + context = external_account.SupplierContext("TestTokenType", "TestAudience") + + assert context.subject_token_type == "TestTokenType" + assert context.audience == "TestAudience" diff --git a/contrib/python/google-auth/py3/tests/test_identity_pool.py b/contrib/python/google-auth/py3/tests/test_identity_pool.py index 96be1d61c2..0de711832f 100644 --- a/contrib/python/google-auth/py3/tests/test_identity_pool.py +++ b/contrib/python/google-auth/py3/tests/test_identity_pool.py @@ -21,7 +21,7 @@ import urllib import mock import pytest # type: ignore -from google.auth import _helpers +from google.auth import _helpers, external_account from google.auth import exceptions from google.auth import identity_pool from google.auth import metrics @@ -152,6 +152,22 @@ INVALID_SERVICE_ACCOUNT_IMPERSONATION_URLS = [ ] +class TestSubjectTokenSupplier(identity_pool.SubjectTokenSupplier): + def __init__( + self, subject_token=None, subject_token_exception=None, expected_context=None + ): + self._subject_token = subject_token + self._subject_token_exception = subject_token_exception + self._expected_context = expected_context + + def get_subject_token(self, context, request): + if self._expected_context is not None: + assert self._expected_context == context + if self._subject_token_exception is not None: + raise self._subject_token_exception + return self._subject_token + + class TestCredentials(object): CREDENTIAL_SOURCE_TEXT = {"file": SUBJECT_TOKEN_TEXT_FILE} CREDENTIAL_SOURCE_JSON = { @@ -274,10 +290,13 @@ class TestCredentials(object): else: metrics_options["sa-impersonation"] = "false" metrics_options["config-lifetime"] = "false" - if credentials._credential_source_file: - metrics_options["source"] = "file" + if credentials._credential_source: + if credentials._credential_source_file: + metrics_options["source"] = "file" + else: + metrics_options["source"] = "url" else: - metrics_options["source"] = "url" + metrics_options["source"] = "programmatic" token_headers["x-goog-api-client"] = metrics.byoid_metrics_header( metrics_options @@ -387,6 +406,7 @@ class TestCredentials(object): default_scopes=None, service_account_impersonation_url=None, credential_source=None, + subject_token_supplier=None, workforce_pool_user_project=None, ): return identity_pool.Credentials( @@ -396,6 +416,7 @@ class TestCredentials(object): token_info_url=token_info_url, service_account_impersonation_url=service_account_impersonation_url, credential_source=credential_source, + subject_token_supplier=subject_token_supplier, client_id=client_id, client_secret=client_secret, quota_project_id=quota_project_id, @@ -433,6 +454,7 @@ class TestCredentials(object): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -461,6 +483,38 @@ class TestCredentials(object): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, + quota_project_id=None, + workforce_pool_user_project=None, + universe_domain=DEFAULT_UNIVERSE_DOMAIN, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_info_supplier(self, mock_init): + supplier = TestSubjectTokenSupplier() + + credentials = identity_pool.Credentials.from_info( + { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "subject_token_supplier": supplier, + } + ) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + token_info_url=None, + service_account_impersonation_url=None, + service_account_impersonation_options={}, + client_id=None, + client_secret=None, + credential_source=None, + subject_token_supplier=supplier, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -490,6 +544,7 @@ class TestCredentials(object): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=None, workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -525,6 +580,7 @@ class TestCredentials(object): client_id=CLIENT_ID, client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=QUOTA_PROJECT_ID, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -554,6 +610,7 @@ class TestCredentials(object): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=None, workforce_pool_user_project=None, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -584,6 +641,7 @@ class TestCredentials(object): client_id=None, client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=None, quota_project_id=None, workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, universe_domain=DEFAULT_UNIVERSE_DOMAIN, @@ -634,7 +692,29 @@ class TestCredentials(object): with pytest.raises(ValueError) as excinfo: self.make_credentials(credential_source="non-dict") - assert excinfo.match(r"Missing credential_source") + assert excinfo.match( + r"Invalid credential_source. The credential_source is not a dict." + ) + + def test_constructor_invalid_no_credential_source_or_supplier(self): + with pytest.raises(ValueError) as excinfo: + self.make_credentials() + + assert excinfo.match( + r"A valid credential source or a subject token supplier must be provided." + ) + + def test_constructor_invalid_both_credential_source_and_supplier(self): + supplier = TestSubjectTokenSupplier() + with pytest.raises(ValueError) as excinfo: + self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT, + subject_token_supplier=supplier, + ) + + assert excinfo.match( + r"Identity pool credential cannot have both a credential source and a subject token supplier." + ) def test_constructor_invalid_credential_source_format_type(self): credential_source = {"format": {"type": "xml"}} @@ -1298,3 +1378,78 @@ class TestCredentials(object): self.CREDENTIAL_URL, "not_found" ) ) + + def test_retrieve_subject_token_supplier(self): + supplier = TestSubjectTokenSupplier(subject_token=JSON_FILE_SUBJECT_TOKEN) + + credentials = self.make_credentials(subject_token_supplier=supplier) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == JSON_FILE_SUBJECT_TOKEN + + def test_retrieve_subject_token_supplier_correct_context(self): + supplier = TestSubjectTokenSupplier( + subject_token=JSON_FILE_SUBJECT_TOKEN, + expected_context=external_account.SupplierContext( + SUBJECT_TOKEN_TYPE, AUDIENCE + ), + ) + + credentials = self.make_credentials(subject_token_supplier=supplier) + + credentials.retrieve_subject_token(None) + + def test_retrieve_subject_token_supplier_error(self): + expected_exception = exceptions.RefreshError("test error") + supplier = TestSubjectTokenSupplier(subject_token_exception=expected_exception) + + credentials = self.make_credentials(subject_token_supplier=supplier) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(self.make_mock_request(token_data=JSON_FILE_CONTENT)) + + assert excinfo.match("test error") + + def test_refresh_success_supplier_with_impersonation_url(self): + # Initialize credentials with service account impersonation and a supplier. + supplier = TestSubjectTokenSupplier(subject_token=JSON_FILE_SUBJECT_TOKEN) + credentials = self.make_credentials( + subject_token_supplier=supplier, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + ) + + def test_refresh_success_supplier_without_impersonation_url(self): + # Initialize supplier credentials without service account impersonation. + supplier = TestSubjectTokenSupplier(subject_token=JSON_FILE_SUBJECT_TOKEN) + credentials = self.make_credentials( + subject_token_supplier=supplier, scopes=SCOPES + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + ) diff --git a/contrib/python/google-auth/py3/ya.make b/contrib/python/google-auth/py3/ya.make index 6b3decfc4e..952d1ebdd3 100644 --- a/contrib/python/google-auth/py3/ya.make +++ b/contrib/python/google-auth/py3/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(2.28.2) +VERSION(2.29.0) LICENSE(Apache-2.0) |