diff --git a/packages/google-auth/google/auth/_credentials_async.py b/packages/google-auth/google/auth/_credentials_async.py index 760758d851b0..937f6e8fb6df 100644 --- a/packages/google-auth/google/auth/_credentials_async.py +++ b/packages/google-auth/google/auth/_credentials_async.py @@ -18,6 +18,7 @@ import abc import inspect +from google.auth import _regional_access_boundary_utils from google.auth import credentials @@ -64,8 +65,28 @@ async def before_request(self, request, method, url, headers): await self.refresh(request) else: self.refresh(request) + + if inspect.iscoroutinefunction(self._after_refresh): + await self._after_refresh(request, method, url, headers) + else: + self._after_refresh(request, method, url, headers) + self.apply(headers) + def _after_refresh(self, request, method, url, headers): + """Hook for subclasses to perform actions after refresh but before + applying credentials to headers. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping[str, str]): The request's headers. + """ + pass + class CredentialsWithQuotaProject(credentials.CredentialsWithQuotaProject): """Abstract base for credentials supporting ``with_quota_project`` factory""" @@ -169,3 +190,74 @@ def with_scopes_if_required(credentials, scopes): class Signing(credentials.Signing, metaclass=abc.ABCMeta): """Interface for credentials that can cryptographically sign messages.""" + + +class CredentialsWithRegionalAccessBoundary( + Credentials, credentials.CredentialsWithRegionalAccessBoundary +): + """Async base for credentials supporting regional access boundary configuration.""" + + def __init__(self): + super().__init__() + self._rab_manager.refresh_manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + def __setstate__(self, state): + super().__setstate__(state) + self._rab_manager.refresh_manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + async def _after_refresh(self, request, method, url, headers): + """Triggers the Regional Access Boundary lookup asynchronously if necessary.""" + await self._maybe_start_regional_access_boundary_refresh_async(request, url) + + async def _maybe_start_regional_access_boundary_refresh_async(self, request, url): + """Starts a background refresh or performs a blocking refresh asynchronously. + + Args: + request (google.auth.aio.transport.Request): The object used to make + HTTP requests. + url (str): The URL of the request. + """ + # Do not perform a lookup if the request is for a regional endpoint. + if self._is_regional_endpoint(url): + return + + # A refresh is only needed if the feature is enabled. + if not self._is_regional_access_boundary_lookup_required(): + return + + # Trigger background or blocking refresh if needed. + await self._rab_manager.maybe_start_refresh_async(self, request) + + async def _lookup_regional_access_boundary(self, request, fail_fast=False): + """Calls the Regional Access Boundary lookup API asynchronously. + + Args: + request (google.auth.aio.transport.Request): The object used to make + HTTP requests. + fail_fast (bool): Whether the lookup should fail fast (short timeout, no retries). + + Returns: + Optional[Dict[str, str]]: The Regional Access Boundary information + returned by the lookup API, or None if the lookup failed. + """ + url_builder = self._build_regional_access_boundary_lookup_url + if inspect.iscoroutinefunction(url_builder): + url = await url_builder(request=request) + else: + url = url_builder(request=request) + + if not url: + return None + + headers = {} + self._apply(headers) + + from google.oauth2 import _client_async + + return await _client_async._lookup_regional_access_boundary( + request, url, headers=headers, fail_fast=fail_fast + ) diff --git a/packages/google-auth/google/auth/_jwt_async.py b/packages/google-auth/google/auth/_jwt_async.py index 3a1abc5b85c9..a956fc05186a 100644 --- a/packages/google-auth/google/auth/_jwt_async.py +++ b/packages/google-auth/google/auth/_jwt_async.py @@ -91,7 +91,9 @@ def decode(token, certs=None, verify=True, audience=None): class Credentials( - jwt.Credentials, _credentials_async.Signing, _credentials_async.Credentials + jwt.Credentials, + _credentials_async.Signing, + _credentials_async.CredentialsWithRegionalAccessBoundary, ): """Credentials that use a JWT as the bearer token. @@ -142,6 +144,15 @@ class Credentials( new_credentials = credentials.with_claims(audience=new_audience) """ + def __setstate__(self, state): + """Restores the credential state and ensures the async refresh manager is attached.""" + super().__setstate__(state) + from google.auth import _regional_access_boundary_utils + + self._rab_manager.refresh_manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + class OnDemandCredentials( jwt.OnDemandCredentials, _credentials_async.Signing, _credentials_async.Credentials diff --git a/packages/google-auth/google/auth/_regional_access_boundary_utils.py b/packages/google-auth/google/auth/_regional_access_boundary_utils.py index 81011911df3d..3024d3719b8d 100644 --- a/packages/google-auth/google/auth/_regional_access_boundary_utils.py +++ b/packages/google-auth/google/auth/_regional_access_boundary_utils.py @@ -14,6 +14,7 @@ """Utilities for Regional Access Boundary management.""" +import asyncio import copy import datetime import functools @@ -170,12 +171,11 @@ def apply_headers(self, headers): else: headers.pop(_REGIONAL_ACCESS_BOUNDARY_HEADER, None) - def maybe_start_refresh(self, credentials, request): - """Starts a background thread to refresh the Regional Access Boundary if needed. + def _should_refresh(self): + """Checks if the Regional Access Boundary data needs a refresh and is not in cooldown. - Args: - credentials (google.auth.credentials.Credentials): The credentials to refresh. - request (google.auth.transport.Request): The object used to make HTTP requests. + Returns: + bool: True if a refresh is required, False otherwise. """ rab_data = self._data @@ -186,10 +186,22 @@ def maybe_start_refresh(self, credentials, request): and _helpers.utcnow() < (rab_data.expiry - REGIONAL_ACCESS_BOUNDARY_REFRESH_THRESHOLD) ): - return + return False # Don't start a new refresh if the cooldown is still in effect. if rab_data.cooldown_expiry and _helpers.utcnow() < rab_data.cooldown_expiry: + return False + + return True + + def maybe_start_refresh(self, credentials, request): + """Starts a background thread to refresh the Regional Access Boundary if needed. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to refresh. + request (google.auth.transport.Request): The object used to make HTTP requests. + """ + if not self._should_refresh(): return # If all checks pass, start the background refresh. @@ -198,6 +210,22 @@ def maybe_start_refresh(self, credentials, request): else: self.refresh_manager.start_refresh(credentials, request, self) + async def maybe_start_refresh_async(self, credentials, request): + """Starts a background refresh or performs a blocking refresh asynchronously. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to refresh. + request (google.auth.aio.transport.Request): The object used to make HTTP requests. + """ + if not self._should_refresh(): + return + + # If all checks pass, start the refresh. + if self._use_blocking_regional_access_boundary_lookup: + await self.start_blocking_refresh_async(credentials, request) + else: + self.refresh_manager.start_refresh(credentials, request, self) + def start_blocking_refresh(self, credentials, request): """Initiates a blocking lookup of the Regional Access Boundary. @@ -227,6 +255,37 @@ def start_blocking_refresh(self, credentials, request): self.process_regional_access_boundary_info(regional_access_boundary_info) + async def start_blocking_refresh_async(self, credentials, request): + """Initiates a blocking lookup of the Regional Access Boundary asynchronously. + + If the lookup raises an exception, it is caught and logged as a warning, + and the lookup is treated as a failure (entering cooldown). Exceptions + are not propagated to the caller. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to refresh. + request (google.auth.aio.transport.Request): The object used to make HTTP requests. + """ + try: + # The fail_fast parameter is set to True to ensure we don't block the calling + # thread for too long. This will do two things: 1) set a timeout to 3s + # instead of the default 120s and 2) ensure we do not retry at all + regional_access_boundary_info = ( + await credentials._lookup_regional_access_boundary( + request, fail_fast=True + ) + ) + except Exception as e: + if _helpers.is_logging_enabled(_LOGGER): + _LOGGER.warning( + "Regional Access Boundary lookup raised an exception: %s", + e, + exc_info=True, + ) + regional_access_boundary_info = None + + self.process_regional_access_boundary_info(regional_access_boundary_info) + def process_regional_access_boundary_info(self, regional_access_boundary_info): """Processes the regional access boundary info and updates the state. @@ -384,3 +443,61 @@ def start_refresh(self, credentials, request, rab_manager): credentials, copied_request, rab_manager ) self._worker.start() + + +class _AsyncRegionalAccessBoundaryRefreshManager(object): + """Manages a task for background refreshing of the Regional Access Boundary in async flows.""" + + def __init__(self): + self._lock = threading.Lock() + self._worker_task = None + + def __getstate__(self): + """Pickle helper that excludes the un-picklable _lock and _worker_task attributes from serialization.""" + state = self.__dict__.copy() + state["_lock"] = None + state["_worker_task"] = None + return state + + def __setstate__(self, state): + """Pickle helper that restores state and re-initializes the _lock and _worker_task attributes.""" + self.__dict__.update(state) + self._lock = threading.Lock() + self._worker_task = None + + def start_refresh(self, credentials, request, rab_manager): + """ + Starts a background task to refresh the Regional Access Boundary if one is not already running. + + Args: + credentials (CredentialsWithRegionalAccessBoundary): The credentials + to refresh. + request (google.auth.aio.transport.Request): The object used to make + HTTP requests. + rab_manager (_RegionalAccessBoundaryManager): The manager container to update. + """ + with self._lock: + if self._worker_task and not self._worker_task.done(): + # A refresh is already in progress. + return + + async def _worker(): + try: + # credentials._lookup_regional_access_boundary should be async in the async creds class + regional_access_boundary_info = ( + await credentials._lookup_regional_access_boundary(request) + ) + except Exception as e: + if _helpers.is_logging_enabled(_LOGGER): + _LOGGER.warning( + "Asynchronous Regional Access Boundary lookup raised an exception: %s", + e, + exc_info=True, + ) + regional_access_boundary_info = None + + rab_manager.process_regional_access_boundary_info( + regional_access_boundary_info + ) + + self._worker_task = asyncio.create_task(_worker()) diff --git a/packages/google-auth/google/auth/credentials.py b/packages/google-auth/google/auth/credentials.py index 4a686cb01907..19d4eb1a822d 100644 --- a/packages/google-auth/google/auth/credentials.py +++ b/packages/google-auth/google/auth/credentials.py @@ -239,9 +239,25 @@ def before_request(self, request, method, url, headers): else: self._blocking_refresh(request) + self._after_refresh(request, method, url, headers) + metrics.add_metric_header(headers, self._metric_header_for_usage()) self.apply(headers) + def _after_refresh(self, request, method, url, headers): + """Hook for subclasses to perform actions after refresh but before + applying credentials to headers. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping): The request's headers. + """ + pass + def with_non_blocking_refresh(self): self._use_non_blocking_refresh = True @@ -309,6 +325,22 @@ def __init__(self): _regional_access_boundary_utils._RegionalAccessBoundaryManager() ) + def __setstate__(self, state): + """Pickle helper that restores state, safely reconstructing RAB fields if missing.""" + self.__dict__.update(state) + if "_rab_manager" not in self.__dict__: + from google.auth import _regional_access_boundary_utils + + self._rab_manager = ( + _regional_access_boundary_utils._RegionalAccessBoundaryManager() + ) + if "_use_non_blocking_refresh" not in self.__dict__: + self._use_non_blocking_refresh = False + if "_refresh_worker" not in self.__dict__: + from google.auth._refresh_worker import RefreshThreadManager + + self._refresh_worker = RefreshThreadManager() + @property def regional_access_boundary(self): """Optional[str]: The encoded Regional Access Boundary locations.""" @@ -364,12 +396,11 @@ def with_trust_boundary(self, trust_boundary): ) def _copy_regional_access_boundary_manager(self, target): - """Copies the regional access boundary manager to another instance.""" - # Create a new manager for the clone to isolate background refresh locks and threads, - # but share the immutable data reference to avoid unnecessary initial lookups. - new_manager = _regional_access_boundary_utils._RegionalAccessBoundaryManager() - new_manager._data = self._rab_manager._data - target._rab_manager = new_manager + """Copies the regional access boundary manager state to another instance.""" + target._rab_manager._data = self._rab_manager._data + target._rab_manager._use_blocking_regional_access_boundary_lookup = ( + self._rab_manager._use_blocking_regional_access_boundary_lookup + ) def _set_regional_access_boundary(self, seed): """Applies the regional_access_boundary provided via the seed on these @@ -403,18 +434,14 @@ def _set_blocking_regional_access_boundary_lookup(self): self._rab_manager.enable_blocking_lookup() return self - def _maybe_start_regional_access_boundary_refresh(self, request, url): - """ - Starts a background thread to refresh the Regional Access Boundary if needed. - - This method checks if a refresh is necessary and if one is not already - in progress or in a cooldown period. If so, it starts a background - thread to perform the lookup. + def _is_regional_endpoint(self, url): + """Checks if the request URL is for a regional endpoint. Args: - request (google.auth.transport.Request): The object used to make - HTTP requests. url (str): The URL of the request. + + Returns: + bool: True if the URL is a regional endpoint, False otherwise. """ try: # Do not perform a lookup if the request is for a regional endpoint. @@ -423,16 +450,35 @@ def _maybe_start_regional_access_boundary_refresh(self, request, url): hostname.endswith(".rep.googleapis.com") or hostname.endswith(".rep.sandbox.googleapis.com") ): - return - except (ValueError, TypeError): + return True + except (ValueError, TypeError, AttributeError): # If the URL is malformed, proceed with the default lookup behavior. pass + return False + + def _maybe_start_regional_access_boundary_refresh(self, request, url): + """ + Starts a background thread to refresh the Regional Access Boundary if needed. + + This method checks if a refresh is necessary and if one is not already + in progress or in a cooldown period. If so, it starts a background + thread to perform the lookup. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + url (str): The URL of the request. + """ + # Do not perform a lookup if the request is for a regional endpoint. + if self._is_regional_endpoint(url): + return + # A refresh is only needed if the feature is enabled. if not self._is_regional_access_boundary_lookup_required(): return - # Start the background refresh if needed. + # Trigger background or blocking refresh if needed self._rab_manager.maybe_start_refresh(self, request) def _is_regional_access_boundary_lookup_required(self): @@ -444,11 +490,11 @@ def _is_regional_access_boundary_lookup_required(self): Returns: bool: True if a Regional Access Boundary lookup is required, False otherwise. """ - # 1. Check if the feature is enabled. + # Check if the feature is enabled. if not _regional_access_boundary_utils.is_regional_access_boundary_enabled(): return False - # 2. Skip for non-default universe domains. + # Skip for non-default universe domains. if self.universe_domain != DEFAULT_UNIVERSE_DOMAIN: return False @@ -459,20 +505,10 @@ def apply(self, headers, token=None): super().apply(headers, token) self._rab_manager.apply_headers(headers) - def before_request(self, request, method, url, headers): - """Refreshes the access token and triggers the Regional Access Boundary - lookup if necessary. - """ - if self._use_non_blocking_refresh: - self._non_blocking_refresh(request) - else: - self._blocking_refresh(request) - + def _after_refresh(self, request, method, url, headers): + """Triggers the Regional Access Boundary lookup if necessary.""" self._maybe_start_regional_access_boundary_refresh(request, url) - metrics.add_metric_header(headers, self._metric_header_for_usage()) - self.apply(headers) - def refresh(self, request): """Refreshes the access token. @@ -505,7 +541,6 @@ def _lookup_regional_access_boundary( headers: Dict[str, str] = {} self._apply(headers) - self._rab_manager.apply_headers(headers) return _client._lookup_regional_access_boundary( request, url, headers=headers, fail_fast=fail_fast ) diff --git a/packages/google-auth/google/auth/jwt.py b/packages/google-auth/google/auth/jwt.py index b6fe60736fa1..38a84bfd97aa 100644 --- a/packages/google-auth/google/auth/jwt.py +++ b/packages/google-auth/google/auth/jwt.py @@ -55,6 +55,7 @@ from google.auth import _service_account_info from google.auth import crypt from google.auth import exceptions +from google.auth import iam import google.auth.credentials try: @@ -317,7 +318,9 @@ def decode(token, certs=None, verify=True, audience=None, clock_skew_in_seconds= class Credentials( - google.auth.credentials.Signing, google.auth.credentials.CredentialsWithQuotaProject + google.auth.credentials.Signing, + google.auth.credentials.CredentialsWithQuotaProject, + google.auth.credentials.CredentialsWithRegionalAccessBoundary, ): """Credentials that use a JWT as the bearer token. @@ -490,7 +493,15 @@ def from_signing_credentials(cls, credentials, audience, **kwargs): """ kwargs.setdefault("issuer", credentials.signer_email) kwargs.setdefault("subject", credentials.signer_email) - return cls(credentials.signer, audience=audience, **kwargs) + jwt_creds = cls(credentials.signer, audience=audience, **kwargs) + + if isinstance( + credentials, + google.auth.credentials.CredentialsWithRegionalAccessBoundary, + ): + credentials._copy_regional_access_boundary_manager(jwt_creds) + + return jwt_creds def with_claims( self, issuer=None, subject=None, audience=None, additional_claims=None @@ -514,7 +525,7 @@ def with_claims( new_additional_claims = copy.deepcopy(self._additional_claims) new_additional_claims.update(additional_claims or {}) - return self.__class__( + cred = self.__class__( self._signer, issuer=issuer if issuer is not None else self._issuer, subject=subject if subject is not None else self._subject, @@ -522,10 +533,12 @@ def with_claims( additional_claims=new_additional_claims, quota_project_id=self._quota_project_id, ) + self._copy_regional_access_boundary_manager(cred) + return cred @_helpers.copy_docstring(google.auth.credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): - return self.__class__( + cred = self.__class__( self._signer, issuer=self._issuer, subject=self._subject, @@ -533,6 +546,8 @@ def with_quota_project(self, quota_project_id): additional_claims=self._additional_claims, quota_project_id=quota_project_id, ) + self._copy_regional_access_boundary_manager(cred) + return cred def _make_jwt(self): """Make a signed JWT. @@ -559,7 +574,7 @@ def _make_jwt(self): return jwt, expiry - def refresh(self, request): + def _perform_refresh_token(self, request): """Refreshes the access token. Args: @@ -569,6 +584,15 @@ def refresh(self, request): # (pylint doesn't correctly recognize overridden methods.) self.token, self.expiry = self._make_jwt() + def _build_regional_access_boundary_lookup_url(self, request=None): + """Builds the lookup URL using the service account's email address.""" + if not self.signer_email: + return None + + return iam._SERVICE_ACCOUNT_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT.format( + service_account_email=self.signer_email + ) + @_helpers.copy_docstring(google.auth.credentials.Signing) def sign_bytes(self, message): return self._signer.sign(message) diff --git a/packages/google-auth/google/oauth2/_client.py b/packages/google-auth/google/oauth2/_client.py index 1c7ba46b72e1..88083d022986 100644 --- a/packages/google-auth/google/oauth2/_client.py +++ b/packages/google-auth/google/oauth2/_client.py @@ -549,7 +549,7 @@ def _lookup_regional_access_boundary(request, url, headers=None, fail_fast=False # Error was already logged by _lookup_regional_access_boundary_request return None - if "encodedLocations" not in response_data: + if not isinstance(response_data, dict) or "encodedLocations" not in response_data: _LOGGER.error( "Regional Access Boundary response malformed: missing 'encodedLocations' key in %s", response_data, diff --git a/packages/google-auth/google/oauth2/_client_async.py b/packages/google-auth/google/oauth2/_client_async.py index a6201fbdcb94..a80531d96758 100644 --- a/packages/google-auth/google/oauth2/_client_async.py +++ b/packages/google-auth/google/oauth2/_client_async.py @@ -23,6 +23,7 @@ .. _Section 3.1 of rfc6749: https://tools.ietf.org/html/rfc6749#section-3.2 """ +import asyncio import http.client as http_client import json import urllib @@ -288,3 +289,159 @@ async def refresh_grant( request, token_uri, body, can_retry=can_retry ) return client._handle_refresh_grant_response(response_data, refresh_token) + + +async def _lookup_regional_access_boundary(request, url, headers=None, fail_fast=False): + """Implements the global lookup of a credential Regional Access Boundary. + For the lookup, we send a request to the global lookup endpoint and then + parse the response. Service account credentials, workload identity + pools and workforce pools implementation may have Regional Access Boundaries configured. + Args: + request (google.auth.aio.transport.Request): A callable used to make + HTTP requests. The returned response must support `await response.read()` + (standard async transport) or `await response.content()` (legacy/custom transport). + url (str): The Regional Access Boundary lookup url. + headers (Optional[Mapping[str, str]]): The headers for the request. + fail_fast (bool): Whether the lookup should fail fast (uses a short timeout and no retries). + Returns: + Optional[Mapping[str,list|str]]: A dictionary containing + "locations" as a list of allowed locations as strings and + "encodedLocations" as a hex string. + e.g: + { + "locations": [ + "us-central1", "us-east1", "europe-west1", "asia-east1" + ], + "encodedLocations": "0xA30" + } + """ + response_data = await _lookup_regional_access_boundary_request( + request, url, headers=headers, fail_fast=fail_fast + ) + if response_data is None: + # Error was already logged by _lookup_regional_access_boundary_request + return None + + if not isinstance(response_data, dict) or "encodedLocations" not in response_data: + client._LOGGER.error( + "Regional Access Boundary response malformed: missing 'encodedLocations' key in %s", + response_data, + ) + return None + return response_data + + +async def _lookup_regional_access_boundary_request( + request, url, can_retry=True, headers=None, fail_fast=False +): + """Makes a request to the Regional Access Boundary lookup endpoint. + + Args: + request (google.auth.aio.transport.Request): A callable used to make + HTTP requests. The returned response must support `await response.read()` + (standard async transport) or `await response.content()` (legacy/custom transport). + url (str): The Regional Access Boundary lookup url. + can_retry (bool): Enable or disable request retry behavior. Defaults to true. + headers (Optional[Mapping[str, str]]): The headers for the request. + fail_fast (bool): Whether the lookup should fail fast (uses a short timeout and no retries). + + Returns: + Optional[Mapping[str, str]]: The JSON-decoded response data on success, or None on failure. + """ + ( + response_status_ok, + response_data, + retryable_error, + ) = await _lookup_regional_access_boundary_request_no_throw( + request, url, can_retry=can_retry, headers=headers, fail_fast=fail_fast + ) + if not response_status_ok: + client._LOGGER.warning( + "Regional Access Boundary HTTP request failed after retries: response_data=%s, retryable_error=%s", + response_data, + retryable_error, + ) + return None + return response_data + + +async def _lookup_regional_access_boundary_request_no_throw( + request, url, can_retry=True, headers=None, fail_fast=False +): + """Makes a request to the Regional Access Boundary lookup endpoint. This + function doesn't throw on response errors. + + Args: + request (google.auth.aio.transport.Request): A callable used to make + HTTP requests. The returned response must support `await response.read()` + (standard async transport) or `await response.content()` (legacy/custom transport). + url (str): The Regional Access Boundary lookup url. + can_retry (bool): Enable or disable request retry behavior. Defaults to true. + headers (Optional[Mapping[str, str]]): The headers for the request. + fail_fast (bool): Whether the lookup should fail fast (uses a short timeout and no retries). + + Returns: + Tuple(bool, Mapping[str, str], Optional[bool]): A boolean indicating + if the request is successful, a mapping for the JSON-decoded response + data and in the case of an error a boolean indicating if the error + is retryable. + """ + + response_data = {} + retryable_error = False + + timeout = ( + client._BLOCKING_REGIONAL_ACCESS_BOUNDARY_LOOKUP_TIMEOUT if fail_fast else None + ) + total_attempts = 1 if fail_fast else 6 + retries = _exponential_backoff.AsyncExponentialBackoff( + total_attempts=total_attempts + ) + + async for _ in retries: + try: + if timeout: + response = await asyncio.wait_for( + request(method="GET", url=url, headers=headers, timeout=timeout), + timeout=timeout, + ) + else: + response = await request(method="GET", url=url, headers=headers) + + # Supports both modern google.auth.aio (exposing read()) and legacy transports (exposing content()) + if hasattr(response, "read"): + response_bytes = await response.read() + else: + response_bytes = await response.content() + except (asyncio.TimeoutError, exceptions.TransportError): + return False, {}, False + + try: + response_body = ( + response_bytes.decode("utf-8") + if hasattr(response_bytes, "decode") + else response_bytes + ) + response_data = json.loads(response_body) + except (UnicodeDecodeError, ValueError): + return False, {}, False + + status_code = ( + response.status_code + if hasattr(response, "status_code") + else response.status + ) + + if status_code == http_client.OK: + return True, response_data, None + + retryable_error = client._can_retry( + status_code=status_code, response_data=response_data + ) + if status_code == http_client.BAD_GATEWAY: + retryable_error = True + + if not can_retry or not retryable_error: + return False, response_data, retryable_error + + return False, response_data, retryable_error diff --git a/packages/google-auth/google/oauth2/_service_account_async.py b/packages/google-auth/google/oauth2/_service_account_async.py index fa6cfb7b7d7a..39af8cfea2d5 100644 --- a/packages/google-auth/google/oauth2/_service_account_async.py +++ b/packages/google-auth/google/oauth2/_service_account_async.py @@ -29,7 +29,9 @@ class Credentials( - service_account.Credentials, credentials_async.Scoped, credentials_async.Credentials + service_account.Credentials, + credentials_async.Scoped, + credentials_async.CredentialsWithRegionalAccessBoundary, ): """Service account credentials @@ -66,6 +68,15 @@ class Credentials( credentials = credentials.with_quota_project('myproject-123') """ + def __setstate__(self, state): + """Restores the credential state and ensures the async refresh manager is attached.""" + super().__setstate__(state) + from google.auth import _regional_access_boundary_utils + + self._rab_manager.refresh_manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + @_helpers.copy_docstring(credentials_async.Credentials) async def refresh(self, request): assertion = self._make_authorization_grant_assertion() @@ -75,13 +86,6 @@ async def refresh(self, request): self.token = access_token self.expiry = expiry - @_helpers.copy_docstring(credentials_async.Credentials) - async def before_request(self, request, method, url, headers): - # Explicit override to bypass synchronous CredentialsWithRegionalAccessBoundary. - await credentials_async.Credentials.before_request( - self, request, method, url, headers - ) - class IDTokenCredentials( service_account.IDTokenCredentials, @@ -137,11 +141,3 @@ async def refresh(self, request): ) self.token = access_token self.expiry = expiry - - @_helpers.copy_docstring(credentials_async.Credentials) - async def before_request(self, request, method, url, headers): - # Explicit override to bypass synchronous CredentialsWithRegionalAccessBoundary - # and disable Regional Access Boundary refresh for async credentials. - await credentials_async.Credentials.before_request( - self, request, method, url, headers - ) diff --git a/packages/google-auth/tests/compute_engine/test_credentials.py b/packages/google-auth/tests/compute_engine/test_credentials.py index 5a60ffd44145..864ddf6436df 100644 --- a/packages/google-auth/tests/compute_engine/test_credentials.py +++ b/packages/google-auth/tests/compute_engine/test_credentials.py @@ -306,8 +306,9 @@ def test_build_regional_access_boundary_lookup_url_default_email( url = creds._build_regional_access_boundary_lookup_url(request=mock_request) mock_get_service_account_info.assert_called_once_with(mock_request, "default") - expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" - assert url == expected_url + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + assert url in (expected_url_standard, expected_url_mtls) @mock.patch("google.auth.compute_engine._metadata.get", autospec=True) def test_build_regional_access_boundary_lookup_url_http_client_request( @@ -323,8 +324,9 @@ def test_build_regional_access_boundary_lookup_url_http_client_request( url = creds._build_regional_access_boundary_lookup_url(request=req) - expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" - assert url == expected_url + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + assert url in (expected_url_standard, expected_url_mtls) @mock.patch( "google.auth.compute_engine._metadata.get_service_account_info", autospec=True @@ -343,9 +345,9 @@ def test_build_regional_access_boundary_lookup_url_explicit_email( url = creds._build_regional_access_boundary_lookup_url() mock_get_service_account_info.assert_not_called() - assert url == ( - "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/foo@bar.com/allowedLocations" - ) + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/foo@bar.com/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/foo@bar.com/allowedLocations" + assert url in (expected_url_standard, expected_url_mtls) @mock.patch( "google.auth.compute_engine._metadata.get_universe_domain", autospec=True diff --git a/packages/google-auth/tests/oauth2/test_service_account.py b/packages/google-auth/tests/oauth2/test_service_account.py index f0d8f0759e50..4da25a65407a 100644 --- a/packages/google-auth/tests/oauth2/test_service_account.py +++ b/packages/google-auth/tests/oauth2/test_service_account.py @@ -230,13 +230,16 @@ def test_with_quota_project(self): def test_build_regional_access_boundary_lookup_url(self): credentials = self.make_credentials() - expected_url = ( - "https://iamcredentials.googleapis.com/v1/projects/-/" - "serviceAccounts/{}/allowedLocations".format( - credentials.service_account_email - ) + url = credentials._build_regional_access_boundary_lookup_url() + + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + credentials.service_account_email ) - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + credentials.service_account_email + ) + + assert url in (expected_url_standard, expected_url_mtls) def test_with_token_uri(self): credentials = self.make_credentials() diff --git a/packages/google-auth/tests/test__regional_access_boundary_utils.py b/packages/google-auth/tests/test__regional_access_boundary_utils.py index ab6ec75fd9b8..c345894deb60 100644 --- a/packages/google-auth/tests/test__regional_access_boundary_utils.py +++ b/packages/google-auth/tests/test__regional_access_boundary_utils.py @@ -301,6 +301,24 @@ def test_serialization(self): assert unpickled.refresh_manager._lock is not None assert unpickled.refresh_manager._worker is None + def test_unpickle_old_credentials_without_rab(self): + creds = CredentialsImpl() + old_state = creds.__dict__.copy() + if "_rab_manager" in old_state: + del old_state["_rab_manager"] + if "_use_non_blocking_refresh" in old_state: + del old_state["_use_non_blocking_refresh"] + if "_refresh_worker" in old_state: + del old_state["_refresh_worker"] + + new_instance = CredentialsImpl.__new__(CredentialsImpl) + new_instance.__setstate__(old_state) + + assert hasattr(new_instance, "_rab_manager") + assert new_instance._rab_manager is not None + assert new_instance._use_non_blocking_refresh is False + assert new_instance._refresh_worker is not None + @mock.patch( "google.auth._regional_access_boundary_utils._RegionalAccessBoundaryRefreshManager.start_refresh" ) @@ -552,3 +570,28 @@ def test_regional_access_boundary_refresh_manager_start_refresh_safety_lock(self mock_thread_class.assert_not_called() assert manager._worker == mock_worker + + @pytest.mark.asyncio + async def test_async_refresh_manager_session_closed_ignored(self): + credentials = mock.AsyncMock() + # Simulate a closed session RuntimeError when invoking the boundary lookup + credentials._lookup_regional_access_boundary.side_effect = RuntimeError( + "Session is closed" + ) + + request = mock.Mock() + rab_manager = mock.Mock() + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + # Trigger refresh, which starts a background task that should swallow the error + manager.start_refresh(credentials, request, rab_manager) + + # Wait for the background worker task to terminate + await manager._worker_task + + # Verify that the lookup was still triggered but failed open cleanly + credentials._lookup_regional_access_boundary.assert_called_once_with(request) + rab_manager.process_regional_access_boundary_info.assert_called_once_with(None) diff --git a/packages/google-auth/tests/test_credentials.py b/packages/google-auth/tests/test_credentials.py index e1528a3ce365..33c936e7f150 100644 --- a/packages/google-auth/tests/test_credentials.py +++ b/packages/google-auth/tests/test_credentials.py @@ -154,6 +154,21 @@ def test_before_request_with_regional_access_boundary(): assert headers["x-allowed-locations"] == DUMMY_BOUNDARY +def test_copy_regional_access_boundary_manager_state_and_config(): + creds = CredentialsImpl() + creds._rab_manager._data = mock.sentinel.rab_data + creds._rab_manager._use_blocking_regional_access_boundary_lookup = True + + new_creds = creds._make_copy() + + # Verify references to immutable boundary data are shared + assert new_creds._rab_manager._data == mock.sentinel.rab_data + # Verify blocking config flag is preserved + assert new_creds._rab_manager._use_blocking_regional_access_boundary_lookup is True + # Verify target manager object is isolated (kept from constructor, not replaced) + assert new_creds._rab_manager is not creds._rab_manager + + def test_before_request_metrics(): credentials = CredentialsImplWithMetrics() request = "token" diff --git a/packages/google-auth/tests/test_external_account.py b/packages/google-auth/tests/test_external_account.py index dc296f7a52ae..8469a2912fef 100644 --- a/packages/google-auth/tests/test_external_account.py +++ b/packages/google-auth/tests/test_external_account.py @@ -1729,13 +1729,21 @@ def test_before_request_expired(self, utcnow): def test_build_regional_access_boundary_lookup_url_workload(self): credentials = self.make_credentials() - expected_url = "https://iamcredentials.googleapis.com/v1/projects/123456/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + url = credentials._build_regional_access_boundary_lookup_url() + + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/123456/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/123456/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" + + assert url in (expected_url_standard, expected_url_mtls) def test_build_regional_access_boundary_lookup_url_workforce(self): credentials = self.make_workforce_pool_credentials() - expected_url = "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + url = credentials._build_regional_access_boundary_lookup_url() + + expected_url_standard = "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + + assert url in (expected_url_standard, expected_url_mtls) @pytest.mark.parametrize( "audience", diff --git a/packages/google-auth/tests/test_external_account_authorized_user.py b/packages/google-auth/tests/test_external_account_authorized_user.py index 648966d924bf..83176d5bbf23 100644 --- a/packages/google-auth/tests/test_external_account_authorized_user.py +++ b/packages/google-auth/tests/test_external_account_authorized_user.py @@ -603,8 +603,12 @@ def test_from_file_full_options(self, tmpdir): def test_build_regional_access_boundary_lookup_url(self): credentials = self.make_credentials() - expected_url = "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + url = credentials._build_regional_access_boundary_lookup_url() + + expected_url_standard = "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + + assert url in (expected_url_standard, expected_url_mtls) @pytest.mark.parametrize( "audience", diff --git a/packages/google-auth/tests/test_iam.py b/packages/google-auth/tests/test_iam.py index 26a4c825a7b3..29949b926e34 100644 --- a/packages/google-auth/tests/test_iam.py +++ b/packages/google-auth/tests/test_iam.py @@ -15,6 +15,7 @@ import base64 import datetime import http.client as http_client +import importlib import json from unittest import mock @@ -113,3 +114,37 @@ def test_sign_bytes_retryable_failure(self, mock_time): with pytest.raises(exceptions.TransportError): signer.sign("123") request.call_count == 3 + + +def test_endpoint_constants_mtls(monkeypatch): + from google.auth.transport import _mtls_helper + + # Mock check_use_client_cert to return True (simulating mTLS environment) + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + + # Force a reload of the iam module to trigger the top-level domain computation + importlib.reload(iam) + + try: + # Verify it constructed the mTLS domain for ALL endpoints + assert ( + "iamcredentials.mtls.googleapis.com" + in iam._SERVICE_ACCOUNT_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT + ) + assert ( + "iamcredentials.mtls.googleapis.com" + in iam._WORKFORCE_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT + ) + assert ( + "iamcredentials.mtls.googleapis.com" + in iam._WORKLOAD_IDENTITY_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT + ) + assert "iamcredentials.mtls.googleapis.com" in iam._IAM_ENDPOINT + assert "iamcredentials.mtls.googleapis.com" in iam._IAM_SIGN_ENDPOINT + assert "iamcredentials.mtls.googleapis.com" in iam._IAM_SIGNJWT_ENDPOINT + assert "iamcredentials.mtls.googleapis.com" in iam._IAM_IDTOKEN_ENDPOINT + + finally: + # Restore the original state for other tests by undoing the patch and reloading again + monkeypatch.undo() + importlib.reload(iam) diff --git a/packages/google-auth/tests/test_impersonated_credentials.py b/packages/google-auth/tests/test_impersonated_credentials.py index 500209f663d7..572d961cc3a9 100644 --- a/packages/google-auth/tests/test_impersonated_credentials.py +++ b/packages/google-auth/tests/test_impersonated_credentials.py @@ -719,11 +719,16 @@ def test_build_regional_access_boundary_lookup_url_no_email(self): def test_build_regional_access_boundary_lookup_url_success(self): credentials = self.make_credentials() - # Ensure service_account_email is properly set by default mock - expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + url = credentials._build_regional_access_boundary_lookup_url() + + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + credentials.service_account_email + ) + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( credentials.service_account_email ) - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + + assert url in (expected_url_standard, expected_url_mtls) def test_with_scopes_provide_default_scopes(self): credentials = self.make_credentials() diff --git a/packages/google-auth/tests/test_jwt.py b/packages/google-auth/tests/test_jwt.py index 4c5988469494..9ed90cdf12b8 100644 --- a/packages/google-auth/tests/test_jwt.py +++ b/packages/google-auth/tests/test_jwt.py @@ -553,6 +553,44 @@ def test_before_request_refreshes(self): self.credentials.before_request(None, "GET", "http://example.com?a=1#3", {}) assert self.credentials.valid + def test_build_regional_access_boundary_lookup_url(self): + url = self.credentials._build_regional_access_boundary_lookup_url() + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + self.SERVICE_ACCOUNT_EMAIL + ) + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + self.SERVICE_ACCOUNT_EMAIL + ) + + assert url in (expected_url_standard, expected_url_mtls) + + def test_cloning_retains_rab_manager_data(self): + self.credentials._rab_manager._data = mock.sentinel.rab_data + + cloned_claims = self.credentials.with_claims(audience="new-audience") + cloned_quota = self.credentials.with_quota_project("new-quota") + + # Verify references to immutable boundary data are shared + assert cloned_claims._rab_manager._data == mock.sentinel.rab_data + assert cloned_quota._rab_manager._data == mock.sentinel.rab_data + + # Verify manager objects and lock properties are isolated to prevent race conditions + assert cloned_claims._rab_manager is not self.credentials._rab_manager + assert cloned_quota._rab_manager is not self.credentials._rab_manager + + def test_from_signing_credentials_copies_rab_state(self): + from google.oauth2 import service_account + + sa_creds = service_account.Credentials.from_service_account_info( + SERVICE_ACCOUNT_INFO + ) + sa_creds._rab_manager._data = mock.sentinel.rab_data + + jwt_creds = jwt.Credentials.from_signing_credentials(sa_creds, audience="aud") + + assert jwt_creds._rab_manager._data == mock.sentinel.rab_data + assert jwt_creds._rab_manager is not sa_creds._rab_manager + class TestOnDemandCredentials(object): SERVICE_ACCOUNT_EMAIL = "service-account@example.com" diff --git a/packages/google-auth/tests_async/oauth2/test__client_async.py b/packages/google-auth/tests_async/oauth2/test__client_async.py index 5ad9596cf85c..e5a1bdbced7f 100644 --- a/packages/google-auth/tests_async/oauth2/test__client_async.py +++ b/packages/google-auth/tests_async/oauth2/test__client_async.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import datetime import http.client as http_client import json @@ -25,6 +26,7 @@ from google.auth import exceptions from google.oauth2 import _client as sync_client from google.oauth2 import _client_async as _client +from google.auth.aio import transport as aio_transport from tests.oauth2 import test__client as test_client @@ -40,6 +42,17 @@ def make_request(response_data, status=http_client.OK, text=False): return request +def make_aio_request(response_data, status_code=http_client.OK, text=False): + """Creates a mock request/response conforming to the google.auth.aio.transport interface (exposing .status_code and .read()).""" + response = mock.AsyncMock(spec=aio_transport.Response) + response.status_code = status_code + data = response_data if text else json.dumps(response_data).encode("utf-8") + response.read = mock.AsyncMock(return_value=data) + request = mock.AsyncMock(spec=aio_transport.Request) + request.return_value = response + return request + + @pytest.mark.asyncio async def test__token_endpoint_request(): request = make_request({"test": "response"}) @@ -492,3 +505,125 @@ async def test__token_endpoint_request_no_throw_with_retry(can_retry): assert mock_request.call_count == 3 else: assert mock_request.call_count == 1 + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_success(): + request = make_aio_request({"encodedLocations": "0xA30", "locations": ["us-central1"]}) + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result == {"encodedLocations": "0xA30", "locations": ["us-central1"]} + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_legacy_transport(): + # Create a legacy mock response that has .status and .content() + response = mock.AsyncMock(spec=["transport.Response"]) + response.status = http_client.OK + + data = json.dumps({"encodedLocations": "0xA30", "locations": ["us-central1"]}).encode("utf-8") + response.content = mock.AsyncMock(return_value=data) + + request = mock.AsyncMock(spec=["transport.Request"]) + request.return_value = response + + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result == {"encodedLocations": "0xA30", "locations": ["us-central1"]} + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_malformed(): + request = make_aio_request({"locations": ["us-central1"]}) + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result is None + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_invalid_json(): + request = make_aio_request("Service Unavailable", text=True) + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result is None + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_non_dict_response(): + request = make_aio_request(123) + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result is None + + +@pytest.mark.asyncio +@mock.patch("asyncio.wait_for", side_effect=asyncio.TimeoutError) +async def test__lookup_regional_access_boundary_request_no_throw_timeout(mock_wait_for): + request = mock.AsyncMock(spec=["transport.Request"]) + + ( + success, + data, + retryable, + ) = await _client._lookup_regional_access_boundary_request_no_throw( + request, "http://example.com", fail_fast=True + ) + + assert success is False + assert data == {} + assert retryable is False + + +@pytest.mark.asyncio +@mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) +async def test__lookup_regional_access_boundary_request_no_throw_bad_gateway_retry( + mock_sleep, +): + bad_gateway_response = mock.AsyncMock(spec=["transport.Response"]) + bad_gateway_response.status = http_client.BAD_GATEWAY + bad_gateway_response.content = mock.AsyncMock(return_value=b"{}") + + ok_response = mock.AsyncMock(spec=["transport.Response"]) + ok_response.status = http_client.OK + ok_response.content = mock.AsyncMock(return_value=b'{"encodedLocations": "0xA30"}') + + request = mock.AsyncMock(spec=["transport.Request"]) + request.side_effect = [bad_gateway_response, ok_response] + + ( + success, + data, + retryable, + ) = await _client._lookup_regional_access_boundary_request_no_throw( + request, "http://example.com" + ) + + assert success is True + assert data == {"encodedLocations": "0xA30"} + assert request.call_count == 2 + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_request_no_throw_transport_error(): + request = mock.AsyncMock(spec=["transport.Request"]) + request.side_effect = exceptions.TransportError("Socket connection failed") + + ( + success, + data, + retryable, + ) = await _client._lookup_regional_access_boundary_request_no_throw( + request, "http://example.com" + ) + + assert success is False + assert data == {} + assert retryable is False + + + diff --git a/packages/google-auth/tests_async/oauth2/test_service_account_async.py b/packages/google-auth/tests_async/oauth2/test_service_account_async.py index 5a9a89fcaac2..ba4efa91cf24 100644 --- a/packages/google-auth/tests_async/oauth2/test_service_account_async.py +++ b/packages/google-auth/tests_async/oauth2/test_service_account_async.py @@ -229,6 +229,96 @@ async def test_before_request_refreshes(self, jwt_grant): # Credentials should now be valid. assert credentials.valid + @mock.patch( + "google.oauth2._client_async._lookup_regional_access_boundary", autospec=True + ) + @pytest.mark.asyncio + async def test_before_request_triggers_rab_refresh(self, mock_lookup): + credentials = self.make_credentials() + credentials.token = "tok" + + mock_lookup.return_value = { + "locations": ["us-central1", "europe-west1"], + "encodedLocations": "0xA30", + } + + request = mock.AsyncMock(spec=["transport.Request"]) + headers1 = {} + + with mock.patch.object( + credentials, + "_is_regional_access_boundary_lookup_required", + return_value=True, + ): + # First request triggers background refresh, but proceeds without the header + await credentials.before_request( + request, "GET", "https://storage.googleapis.com/bucket", headers1 + ) + assert "x-allowed-locations" not in headers1 + + # Wait for the background task to finish and update the cache + await credentials._rab_manager.refresh_manager._worker_task + assert mock_lookup.called + + # Second request should now find the data in the cache and attach the header + headers2 = {} + await credentials.before_request( + request, "GET", "https://storage.googleapis.com/bucket", headers2 + ) + assert headers2["x-allowed-locations"] == "0xA30" + + @mock.patch( + "google.oauth2._client_async._lookup_regional_access_boundary", autospec=True + ) + @pytest.mark.asyncio + async def test_before_request_rab_refresh_failure_ignored(self, mock_lookup): + credentials = self.make_credentials() + credentials.token = "tok" + + mock_lookup.side_effect = Exception("Transport failed") + + request = mock.AsyncMock(spec=["transport.Request"]) + headers = {} + + with mock.patch.object( + credentials, + "_is_regional_access_boundary_lookup_required", + return_value=True, + ): + # The exception must be caught gracefully and not bubble up + await credentials.before_request( + request, "GET", "https://storage.googleapis.com/bucket", headers + ) + + # Wait for the background task to finish + await credentials._rab_manager.refresh_manager._worker_task + + assert mock_lookup.called + assert "x-allowed-locations" not in headers + + def test_unpickle_old_credentials_without_rab(self): + import pickle + from google.auth import _regional_access_boundary_utils + + credentials = self.make_credentials() + old_state = credentials.__dict__.copy() + if "_rab_manager" in old_state: + del old_state["_rab_manager"] + if "_use_non_blocking_refresh" in old_state: + del old_state["_use_non_blocking_refresh"] + if "_refresh_worker" in old_state: + del old_state["_refresh_worker"] + + new_instance = type(credentials).__new__(type(credentials)) + new_instance.__setstate__(old_state) + + # Verify the manager was correctly restored with the async refresh manager! + assert hasattr(new_instance, "_rab_manager") + assert isinstance( + new_instance._rab_manager.refresh_manager, + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager, + ) + class TestIDTokenCredentials(object): SERVICE_ACCOUNT_EMAIL = "service-account@example.com" diff --git a/packages/google-auth/tests_async/test__regional_access_boundary_utils.py b/packages/google-auth/tests_async/test__regional_access_boundary_utils.py new file mode 100644 index 000000000000..268ee37261c8 --- /dev/null +++ b/packages/google-auth/tests_async/test__regional_access_boundary_utils.py @@ -0,0 +1,84 @@ +# Copyright 2026 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest import mock + +import pytest # type: ignore + +from google.auth import _regional_access_boundary_utils + + +@pytest.mark.asyncio +async def test_async_refresh_manager_start_refresh(): + credentials = mock.AsyncMock() + credentials._lookup_regional_access_boundary.return_value = { + "encodedLocations": "0xA30" + } + + request = mock.Mock() + rab_manager = mock.Mock() + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + manager.start_refresh(credentials, request, rab_manager) + + # Wait for the background task to finish + await manager._worker_task + + credentials._lookup_regional_access_boundary.assert_called_once_with(request) + rab_manager.process_regional_access_boundary_info.assert_called_once_with( + {"encodedLocations": "0xA30"} + ) + + +@pytest.mark.asyncio +async def test_async_refresh_manager_duplicate_refresh_prevented(): + credentials = mock.AsyncMock() + + # Use events to control the concurrency timing + lookup_started = asyncio.Event() + lookup_finish = asyncio.Event() + + async def controlled_lookup(*args, **kwargs): + lookup_started.set() # Signal that the background lookup has started. + await lookup_finish.wait() # Block until the test allows the lookup to complete. + return {"encodedLocations": "0xA30"} + + credentials._lookup_regional_access_boundary.side_effect = controlled_lookup + + request = mock.Mock() + rab_manager = mock.Mock() + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + # Start the initial refresh task in the background. + manager.start_refresh(credentials, request, rab_manager) + + # Wait until the background task has begun executing the lookup. + await lookup_started.wait() + + # Attempt a second refresh while the initial task is still in progress. + manager.start_refresh(credentials, request, rab_manager) + + # Unblock the initial task and wait for it to complete. + lookup_finish.set() + await manager._worker_task + + # Verify that the second refresh request was ignored and only one lookup occurred. + assert credentials._lookup_regional_access_boundary.call_count == 1