From ac24ef71e6976d25eefcad3d80c15bd477ba65a3 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Mon, 2 Feb 2026 23:35:32 +0530 Subject: [PATCH 01/16] feat: add Multiple Custom Domains (MCD) support and fix JWT verification --- .gitignore | 4 +- examples/MCD.md | 139 +++ .../auth_server/server_client.py | 599 ++++++++-- .../auth_types/__init__.py | 23 + src/auth0_server_python/error/__init__.py | 27 + .../tests/test_server_client.py | 1028 ++++++++++++++++- src/auth0_server_python/utils/helpers.py | 68 ++ 7 files changed, 1778 insertions(+), 110 deletions(-) create mode 100644 examples/MCD.md diff --git a/.gitignore b/.gitignore index fe90143..3d5c66a 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,6 @@ test.py test-script.py .coverage coverage.xml - +examples/mcd-poc +IMPLEMENTATION_NOTES.md +examples/MCD_DEVELOPER_GUIDE.md \ No newline at end of file diff --git a/examples/MCD.md b/examples/MCD.md new file mode 100644 index 0000000..85d188d --- /dev/null +++ b/examples/MCD.md @@ -0,0 +1,139 @@ +# Multiple Custom Domains (MCD) Guide + +This guide explains how to implement Multiple Custom Domain (MCD) support using the Auth0 Python SDKs. + +## What is MCD? + +Multiple Custom Domains (MCD) allows your application to serve different organizations or tenants from different hostnames, each mapping to a different Auth0 tenant/domain. + +**Example:** +- `https://acme.yourapp.com` → Auth0 tenant: `acme.auth0.com` +- `https://globex.yourapp.com` → Auth0 tenant: `globex.auth0.com` + +Each tenant gets its own branded login experience while using a single application codebase. + +## Configuration Methods + +### Method 1: Static Domain (Single Tenant) + +For applications with a single Auth0 domain: + +```python +from auth0_server_python import ServerClient + +client = ServerClient( + domain="your-tenant.auth0.com", # Static string + client_id="your_client_id", + client_secret="your_client_secret", + secret="your_encryption_secret" +) +``` + +### Method 2: Dynamic Domain Resolver (MCD) + +For MCD support, provide a domain resolver function that receives a `DomainResolverContext`: + +```python +from auth0_server_python import ServerClient +from auth0_server_python.auth_types import DomainResolverContext + +# Map your app hostnames to Auth0 domains +DOMAIN_MAP = { + "acme.yourapp.com": "acme.auth0.com", + "globex.yourapp.com": "globex.auth0.com", +} +DEFAULT_DOMAIN = "default.auth0.com" + +async def domain_resolver(context: DomainResolverContext) -> str: + """ + Resolve Auth0 domain based on request hostname. + + Args: + context: Contains request_url and request_headers + + Returns: + Auth0 domain string (e.g., "acme.auth0.com") + """ + # Extract hostname from request headers + if not context.request_headers: + return DEFAULT_DOMAIN + + host = context.request_headers.get('host', DEFAULT_DOMAIN) + host_without_port = host.split(':')[0] + + # Look up Auth0 domain + return DOMAIN_MAP.get(host_without_port, DEFAULT_DOMAIN) + +client = ServerClient( + domain=domain_resolver, # Callable function + client_id="your_client_id", + client_secret="your_client_secret", + secret="your_encryption_secret" +) +``` + +## DomainResolverContext + +The `DomainResolverContext` object provides request information to your resolver: + +| Property | Type | Description | +|----------|------|-------------| +| `request_url` | `Optional[str]` | Full request URL (e.g., "https://acme.yourapp.com/auth/login") | +| `request_headers` | `Optional[dict[str, str]]` | Request headers dictionary | + +**Common headers:** +- `host`: Request hostname (e.g., "acme.yourapp.com") +- `x-forwarded-host`: Original host when behind proxy/load balancer + +**Example usage:** + +```python +async def domain_resolver(context: DomainResolverContext) -> str: + # Check if we have request headers + if not context.request_headers: + return DEFAULT_DOMAIN + + # Use x-forwarded-host if behind proxy, otherwise use host + host = (context.request_headers.get('x-forwarded-host') or + context.request_headers.get('host', '')) + + # Remove port number if present + hostname = host.split(':')[0].lower() + + # Look up in mapping + return DOMAIN_MAP.get(hostname, DEFAULT_DOMAIN) +``` + +## Error Handling + +### DomainResolverError + +The domain resolver should return a valid Auth0 domain string. Invalid returns will raise `DomainResolverError`: + +```python +from auth0_server_python.error import DomainResolverError + +async def domain_resolver(context: DomainResolverContext) -> str: + try: + domain = lookup_domain_from_db(context) + + if not domain: + # Return default instead of None + return DEFAULT_DOMAIN + + return domain # Must be a non-empty string + + except Exception as e: + # Log error and return default + logger.error(f"Domain resolution failed: {e}") + return DEFAULT_DOMAIN +``` + +**Invalid return values that raise `DomainResolverError`:** +- `None` +- Empty string `""` +- Non-string types (int, list, dict, etc.) + +**Exceptions raised by your resolver:** +- Automatically wrapped in `DomainResolverError` +- Original exception accessible via `.original_error` \ No newline at end of file diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index c968120..bee5541 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -6,7 +6,7 @@ import asyncio import json import time -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Callable, Generic, Optional, TypeVar, Union from urllib.parse import parse_qs, urlencode, urlparse, urlunparse import httpx @@ -32,19 +32,25 @@ AccessTokenForConnectionErrorCode, ApiError, BackchannelLogoutError, + ConfigurationError, + DomainResolverError, MissingRequiredArgumentError, MissingTransactionError, PollingApiError, StartLinkUserError, ) from auth0_server_python.utils import PKCE, URL, State +from auth0_server_python.utils.helpers import ( + build_domain_resolver_context, + validate_resolved_domain_value, +) from authlib.integrations.base_client.errors import OAuthError from authlib.integrations.httpx_client import AsyncOAuth2Client from pydantic import ValidationError # Generic type for store options TStoreOptions = TypeVar('TStoreOptions') -INTERNAL_AUTHORIZE_PARAMS = ["client_id", "redirect_uri", "response_type", +INTERNAL_AUTHORIZE_PARAMS = ["client_id", "response_type", "code_challenge", "code_challenge_method", "state", "nonce", "scope"] @@ -55,11 +61,15 @@ class ServerClient(Generic[TStoreOptions]): """ DEFAULT_AUDIENCE_STATE_KEY = "default" + # ========================================== + # Initialization + # ========================================== + def __init__( self, - domain: str, - client_id: str, - client_secret: str, + domain: Union[str, Callable[[Optional[dict[str, Any]]], str]] = None, + client_id: str = None, + client_secret: str = None, redirect_uri: Optional[str] = None, secret: str = None, transaction_store=None, @@ -67,13 +77,13 @@ def __init__( transaction_identifier: str = "_a0_tx", state_identifier: str = "_a0_session", authorization_params: Optional[dict[str, Any]] = None, - pushed_authorization_requests: bool = False + pushed_authorization_requests: bool = False, ): """ Initialize the Auth0 server client. Args: - domain: Auth0 domain (e.g., 'your-tenant.auth0.com') + domain: Auth0 domain - either a static string (e.g., 'tenant.auth0.com') or a callable that resolves domain dynamically. client_id: Auth0 client ID client_secret: Auth0 client secret redirect_uri: Default redirect URI for authentication @@ -83,12 +93,34 @@ def __init__( transaction_identifier: Identifier for transaction data state_identifier: Identifier for state data authorization_params: Default parameters for authorization requests + pushed_authorization_requests: Whether to use Pushed Authorization Requests """ if not secret: raise MissingRequiredArgumentError("secret") - # Store configuration - self._domain = domain + if domain is None: + raise ConfigurationError( + "Domain is required" + ) + + # Validate domain type + if not isinstance(domain, str) and not callable(domain): + raise ConfigurationError( + f"Domain must be either a string or a callable function. " + f"Got {type(domain).__name__} instead." + ) + + # Determine if domain is static string or dynamic callable + if callable(domain): + self._domain = None + self._domain_resolver = domain + else: + # Validate static domain string + domain_str = str(domain) + if not domain_str or domain_str.strip() == "": + raise ConfigurationError("Domain cannot be empty.") + self._domain = domain_str + self._domain_resolver = None self._client_id = client_id self._client_secret = client_secret self._redirect_uri = redirect_uri @@ -109,12 +141,15 @@ def __init__( self._my_account_client = MyAccountClient(domain=domain) - async def _fetch_oidc_metadata(self, domain: str) -> dict: - metadata_url = f"https://{domain}/.well-known/openid-configuration" - async with httpx.AsyncClient() as client: - response = await client.get(metadata_url) - response.raise_for_status() - return response.json() + # Cache for OIDC metadata and JWKS (Requirement 3: MCD Support) + self._metadata_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} + self._jwks_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} + self._cache_ttl = 3600 # 1 hour TTL + self._cache_max_size = 100 # Max 100 domains to prevent memory bloat + + # ========================================== + # Interactive Login Flow + # ========================================== async def start_interactive_login( self, @@ -126,12 +161,38 @@ async def start_interactive_login( Args: options: Configuration options for the login process + store_options: Store options containing request/response Returns: Authorization URL to redirect the user to """ options = options or StartInteractiveLoginOptions() + # Resolve domain (static or dynamic) + if self._domain_resolver: + # Build context and call developer's resolver + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + origin_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + else: + origin_domain = self._domain + + # Fetch OIDC metadata from resolved domain + try: + metadata = await self._get_oidc_metadata_cached(origin_domain) + origin_issuer = metadata.get('issuer') + except Exception as e: + raise ApiError("metadata_error", + "Failed to fetch OIDC metadata", e) + # Get effective authorization params (merge defaults with provided ones) auth_params = dict(self._default_authorization_params) if options.authorization_params: @@ -160,17 +221,20 @@ async def start_interactive_login( state = PKCE.generate_random_string(32) auth_params["state"] = state - #merge any requested scope with defaults + # Merge any requested scope with defaults requested_scope = options.authorization_params.get("scope", None) if options.authorization_params else None audience = auth_params.get("audience", None) merged_scope = self._merge_scope_with_defaults(requested_scope, audience) auth_params["scope"] = merged_scope - # Build the transaction data to store + # Build the transaction data to store with origin domain and issuer transaction_data = TransactionData( code_verifier=code_verifier, app_state=options.app_state, audience=audience, + origin_domain=origin_domain, + origin_issuer=origin_issuer, + redirect_uri=auth_params.get("redirect_uri"), ) # Store the transaction data @@ -179,11 +243,9 @@ async def start_interactive_login( transaction_data, options=store_options ) - try: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) - except Exception as e: - raise ApiError("metadata_error", - "Failed to fetch OIDC metadata", e) + + # Set metadata for OAuth client + self._oauth.metadata = metadata # If PAR is enabled, use the PAR endpoint if self._pushed_authorization_requests: par_endpoint = self._oauth.metadata.get( @@ -274,34 +336,101 @@ async def complete_interactive_login( if not code: raise MissingRequiredArgumentError("code") - if not self._oauth.metadata or "token_endpoint" not in self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + # Get origin domain and issuer from transaction + origin_domain = transaction_data.origin_domain + origin_issuer = transaction_data.origin_issuer + + # Fetch metadata from the origin domain + metadata = await self._get_oidc_metadata_cached(origin_domain) + self._oauth.metadata = metadata # Exchange the code for tokens + # Use redirect_uri from transaction if available, otherwise fall back to default + token_redirect_uri = transaction_data.redirect_uri or self._redirect_uri try: token_endpoint = self._oauth.metadata["token_endpoint"] token_response = await self._oauth.fetch_token( token_endpoint, code=code, code_verifier=transaction_data.code_verifier, - redirect_uri=self._redirect_uri, + redirect_uri=token_redirect_uri, ) except OAuthError as e: # Raise a custom error (or handle it as appropriate) raise ApiError( "token_error", f"Token exchange failed: {str(e)}", e) + print(f"Token Response : {token_response}") - # Use the userinfo field from the token_response for user claims + # Use the userinfo field from the token_response for user claims user_info = token_response.get("userinfo") user_claims = None + id_token = token_response.get("id_token") + if user_info: user_claims = UserClaims.parse_obj(user_info) - else: - id_token = token_response.get("id_token") - if id_token: - claims = jwt.decode(id_token, options={ - "verify_signature": False}) + elif id_token: + # Fetch JWKS for signature verification (Requirement 3) + jwks = await self._get_jwks_cached(origin_domain, metadata) + + # Decode and verify ID token with signature verification enabled + try: + # Get the signing key from JWKS + unverified_header = jwt.get_unverified_header(id_token) + kid = unverified_header.get("kid") + + # Find the key with matching kid + signing_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + signing_key = jwt.PyJWK.from_dict(key) + break + + if not signing_key: + raise ApiError( + "jwks_key_not_found", + f"No matching key found in JWKS for kid: {kid}" + ) + + claims = jwt.decode( + id_token, + signing_key.key, + algorithms=["RS256"], + audience=self._client_id, + issuer=origin_issuer, + options={"verify_signature": True} + ) user_claims = UserClaims.parse_obj(claims) + except jwt.InvalidSignatureError as e: + raise ApiError( + "invalid_signature", + f"ID token signature verification failed. The token may have been tampered with or is from an untrusted source: {str(e)}", + e + ) + except jwt.InvalidAudienceError as e: + raise ApiError( + "invalid_audience", + f"ID token audience mismatch. Expected: {self._client_id}. Ensure your client_id is configured correctly: {str(e)}", + e + ) + except jwt.InvalidIssuerError as e: + raise ApiError( + "invalid_issuer", + f"ID token issuer mismatch. Expected: {origin_issuer}. Ensure your Auth0 domain is configured correctly: {str(e)}", + e + ) + except jwt.ExpiredSignatureError as e: + raise ApiError( + "token_expired", + f"ID token has expired: {str(e)}", + e + ) + except jwt.InvalidTokenError as e: + raise ApiError( + "invalid_token", + f"ID token verification failed: {str(e)}", + e + ) + # Build a token set using the token response data token_set = TokenSet( @@ -323,6 +452,7 @@ async def complete_interactive_login( # might be None if not provided refresh_token=token_response.get("refresh_token"), token_sets=[token_set], + domain=origin_domain, internal={ "sid": sid, "created_at": int(time.time()) @@ -346,6 +476,10 @@ async def complete_interactive_login( return result + # ========================================== + # User Account Linking + # ========================================== + async def start_link_user( self, options, @@ -493,6 +627,10 @@ async def complete_unlink_user( "app_state": result.get("app_state") } + # ========================================== + # Backchannel Authentication (CIBA) + # ========================================== + async def login_backchannel( self, options: dict[str, Any], @@ -539,6 +677,10 @@ async def login_backchannel( } return result + # ========================================== + # Session & Token Management + # ========================================== + async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: """ Retrieves the user from the store, or None if no user found. @@ -552,6 +694,25 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti state_data = await self._state_store.get(self._state_identifier, store_options) if state_data: + # Validate session domain matches current request domain + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + current_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + session_domain = getattr(state_data, 'domain', None) + + if session_domain and session_domain != current_domain: + # Session created with different domain - reject for security + return None + if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() return state_data.get("user") @@ -570,6 +731,25 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O state_data = await self._state_store.get(self._state_identifier, store_options) if state_data: + # Validate session domain matches current request domain + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + current_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + session_domain = getattr(state_data, 'domain', None) + + if session_domain and session_domain != current_domain: + # Session created with different domain - reject for security + return None + if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() session_data = {k: v for k, v in state_data.items() @@ -599,6 +779,28 @@ async def get_access_token( """ state_data = await self._state_store.get(self._state_identifier, store_options) + # Validate session domain matches current request domain + if state_data and self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + current_domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + session_domain = getattr(state_data, 'domain', None) + + if session_domain and session_domain != current_domain: + # Session created with different domain - reject for security + raise AccessTokenError( + AccessTokenErrorCode.MISSING_REFRESH_TOKEN, + "Session domain mismatch. User needs to re-authenticate with the current domain." + ) + auth_params = self._default_authorization_params or {} # Get audience passed in on options or use defaults @@ -630,7 +832,12 @@ async def get_access_token( # Get new token with refresh token try: - get_refresh_token_options = {"refresh_token": state_data_dict["refresh_token"]} + # Use session's domain for token refresh + session_domain = state_data_dict.get("domain") or self._domain + get_refresh_token_options = { + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain + } if audience: get_refresh_token_options["audience"] = audience @@ -656,50 +863,7 @@ async def get_access_token( f"Failed to get token with refresh token: {str(e)}" ) - def _merge_scope_with_defaults( - self, - request_scope: Optional[str], - audience: Optional[str] - ) -> Optional[str]: - audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY - default_scopes = "" - if self._default_authorization_params and "scope" in self._default_authorization_params: - auth_param_scope = self._default_authorization_params.get("scope") - # For backwards compatibility, allow scope to be a single string - # or dictionary by audience for MRRT - if isinstance(auth_param_scope, dict) and audience in auth_param_scope: - default_scopes = auth_param_scope[audience] - elif isinstance(auth_param_scope, str): - default_scopes = auth_param_scope - - default_scopes_list = default_scopes.split() - request_scopes_list = (request_scope or "").split() - - merged_scopes = list(dict.fromkeys(default_scopes_list + request_scopes_list)) - return " ".join(merged_scopes) if merged_scopes else None - - - def _find_matching_token_set( - self, - token_sets: list[dict[str, Any]], - audience: Optional[str], - scope: Optional[str] - ) -> Optional[dict[str, Any]]: - audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY - requested_scopes = set(scope.split()) if scope else set() - matches: list[tuple[int, dict]] = [] - for token_set in token_sets: - token_set_audience = token_set.get("audience") - token_set_scopes = set(token_set.get("scope", "").split()) - if token_set_audience == audience and token_set_scopes == requested_scopes: - # short-circuit if exact match - return token_set - if token_set_audience == audience and token_set_scopes.issuperset(requested_scopes): - # consider stored tokens with more scopes than requested by number of scopes - matches.append((len(token_set_scopes), token_set)) - - # Return the token set with the smallest superset of scopes that matches the requested audience and scopes - return min(matches, key=lambda t: t[0])[1] if matches else None + async def get_access_token_for_connection( self, @@ -751,10 +915,13 @@ async def get_access_token_for_connection( "A refresh token was not found but is required to be able to retrieve an access token for a connection." ) # Get new token for connection + # Use session's domain for token exchange + session_domain = state_data_dict.get("domain") or self._domain token_endpoint_response = await self.get_token_for_connection({ "connection": options.get("connection"), "login_hint": options.get("login_hint"), - "refresh_token": state_data_dict["refresh_token"] + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain }) # Update state data with new token @@ -766,6 +933,10 @@ async def get_access_token_for_connection( return token_endpoint_response["access_token"] + # ========================================== + # Logout + # ========================================== + async def logout( self, options: Optional[LogoutOptions] = None, @@ -776,9 +947,25 @@ async def logout( # Delete the session from the state store await self._state_store.delete(self._state_identifier, store_options) + # Resolve domain dynamically for MCD support + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + domain = validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + else: + domain = self._domain + # Use the URL helper to create the logout URL. logout_url = URL.create_logout_url( - self._domain, self._client_id, options.return_to) + domain, self._client_id, options.return_to) return logout_url @@ -798,9 +985,41 @@ async def handle_backchannel_logout( raise BackchannelLogoutError("Missing logout token") try: - # Decode the token without verification - claims = jwt.decode(logout_token, options={ - "verify_signature": False}) + # Fetch JWKS for signature verification (Requirement 3) + jwks = await self._get_jwks_cached(self._domain) + + # Decode and verify logout token with signature verification enabled + try: + # Get the signing key from JWKS + unverified_header = jwt.get_unverified_header(logout_token) + kid = unverified_header.get("kid") + + # Find the key with matching kid + signing_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + signing_key = jwt.PyJWK.from_dict(key) + break + + if not signing_key: + raise BackchannelLogoutError( + f"No matching key found in JWKS for kid: {kid}" + ) + + claims = jwt.decode( + logout_token, + signing_key.key, + algorithms=["RS256"], + options={"verify_signature": True} + ) + except jwt.InvalidSignatureError as e: + raise BackchannelLogoutError( + f"Logout token signature verification failed: {str(e)}" + ) + except jwt.InvalidTokenError as e: + raise BackchannelLogoutError( + f"Logout token verification failed: {str(e)}" + ) # Validate the token is a logout token events = claims.get("events", {}) @@ -816,11 +1035,195 @@ async def handle_backchannel_logout( await self._state_store.delete_by_logout_token(logout_claims.dict(), store_options) - except (jwt.JoseError, ValidationError) as e: + except (jwt.PyJWTError, ValidationError) as e: raise BackchannelLogoutError( f"Error processing logout token: {str(e)}") - # Authlib Helpers + # ========================================== + # Internal Helpers + # ========================================== + + # ------------------------------------------ + # OIDC Discovery & Metadata + # ------------------------------------------ + + def _normalize_domain(self, domain: str) -> str: + """ + Normalize domain for comparison and URL construction. + Handles cases with/without https:// scheme. + """ + if domain.startswith('https://'): + return domain + elif domain.startswith('http://'): + return domain.replace('http://', 'https://') + else: + return f'https://{domain}' + + async def _fetch_oidc_metadata(self, domain: str) -> dict: + """Fetch OIDC metadata from domain.""" + normalized_domain = self._normalize_domain(domain) + metadata_url = f"{normalized_domain}/.well-known/openid-configuration" + async with httpx.AsyncClient() as client: + response = await client.get(metadata_url) + response.raise_for_status() + return response.json() + + async def _get_oidc_metadata_cached(self, domain: str) -> dict: + """ + Get OIDC metadata with caching. + + Args: + domain: Auth0 domain + + Returns: + OIDC metadata document + """ + now = time.time() + + # Check cache + if domain in self._metadata_cache: + cached = self._metadata_cache[domain] + if cached["expires_at"] > now: + return cached["data"] + + # Cache miss/expired - fetch fresh + metadata = await self._fetch_oidc_metadata(domain) + + # Enforce cache size limit (FIFO eviction) + if len(self._metadata_cache) >= self._cache_max_size: + oldest_key = next(iter(self._metadata_cache)) + del self._metadata_cache[oldest_key] + + # Store in cache + self._metadata_cache[domain] = { + "data": metadata, + "expires_at": now + self._cache_ttl + } + + return metadata + + async def _fetch_jwks(self, jwks_uri: str) -> dict: + """ + Fetch JWKS (JSON Web Key Set) from jwks_uri. + + Args: + jwks_uri: The JWKS endpoint URL + + Returns: + JWKS document containing public keys + + Raises: + ApiError: If JWKS fetch fails + """ + try: + async with httpx.AsyncClient() as client: + response = await client.get(jwks_uri) + response.raise_for_status() + return response.json() + except Exception as e: + raise ApiError("jwks_fetch_error", f"Failed to fetch JWKS from {jwks_uri}", e) + + async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: + """ + Get JWKS with caching usingOIDC discovery. + + Args: + domain: Auth0 domain + metadata: Optional OIDC metadata (if already fetched) + + Returns: + JWKS document + + Raises: + ApiError: If JWKS fetch fails or jwks_uri missing from metadata + """ + now = time.time() + + # Check cache + if domain in self._jwks_cache: + cached = self._jwks_cache[domain] + if cached["expires_at"] > now: + return cached["data"] + + # Get jwks_uri from OIDC metadata + if not metadata: + metadata = await self._get_oidc_metadata_cached(domain) + + jwks_uri = metadata.get('jwks_uri') + if not jwks_uri: + raise ApiError( + "missing_jwks_uri", + f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant." + ) + + # Fetch JWKS + jwks = await self._fetch_jwks(jwks_uri) + + # Enforce cache size limit (FIFO eviction) + if len(self._jwks_cache) >= self._cache_max_size: + oldest_key = next(iter(self._jwks_cache)) + del self._jwks_cache[oldest_key] + + # Store in cache + self._jwks_cache[domain] = { + "data": jwks, + "expires_at": now + self._cache_ttl + } + + return jwks + + # ------------------------------------------ + # Token & Scope Management - MRRT + # ------------------------------------------ + + def _merge_scope_with_defaults( + self, + request_scope: Optional[str], + audience: Optional[str] + ) -> Optional[str]: + audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY + default_scopes = "" + if self._default_authorization_params and "scope" in self._default_authorization_params: + auth_param_scope = self._default_authorization_params.get("scope") + # For backwards compatibility, allow scope to be a single string + # or dictionary by audience for MRRT + if isinstance(auth_param_scope, dict) and audience in auth_param_scope: + default_scopes = auth_param_scope[audience] + elif isinstance(auth_param_scope, str): + default_scopes = auth_param_scope + + default_scopes_list = default_scopes.split() + request_scopes_list = (request_scope or "").split() + + merged_scopes = list(dict.fromkeys(default_scopes_list + request_scopes_list)) + return " ".join(merged_scopes) if merged_scopes else None + + + def _find_matching_token_set( + self, + token_sets: list[dict[str, Any]], + audience: Optional[str], + scope: Optional[str] + ) -> Optional[dict[str, Any]]: + audience = audience or self.DEFAULT_AUDIENCE_STATE_KEY + requested_scopes = set(scope.split()) if scope else set() + matches: list[tuple[int, dict]] = [] + for token_set in token_sets: + token_set_audience = token_set.get("audience") + token_set_scopes = set(token_set.get("scope", "").split()) + if token_set_audience == audience and token_set_scopes == requested_scopes: + # short-circuit if exact match + return token_set + if token_set_audience == audience and token_set_scopes.issuperset(requested_scopes): + # consider stored tokens with more scopes than requested by number of scopes + matches.append((len(token_set_scopes), token_set)) + + # Return the token set with the smallest superset of scopes that matches the requested audience and scopes + return min(matches, key=lambda t: t[0])[1] if matches else None + + # ------------------------------------------ + # URL Builders + # ------------------------------------------ async def _build_link_user_url( self, @@ -837,7 +1240,7 @@ async def _build_link_user_url( # Get metadata if not already fetched if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) # Get authorization endpoint auth_endpoint = self._oauth_metadata.get("authorization_endpoint", @@ -880,7 +1283,7 @@ async def _build_unlink_user_url( # Get metadata if not already fetched if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) # Get authorization endpoint auth_endpoint = self._oauth_metadata.get("authorization_endpoint", @@ -1025,7 +1428,7 @@ async def initiate_backchannel_authentication( try: # Fetch OpenID Connect metadata if not already fetched if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) # Get the issuer from metadata issuer = self._oauth_metadata.get( @@ -1120,7 +1523,7 @@ async def backchannel_authentication_grant(self, auth_req_id: str) -> dict[str, try: # Ensure we have the OIDC metadata if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth.metadata = await self._get_oidc_metadata_cached(self._domain) token_endpoint = self._oauth.metadata.get("token_endpoint") if not token_endpoint: @@ -1178,6 +1581,10 @@ async def backchannel_authentication_grant(self, auth_req_id: str) -> dict[str, e ) + # ========================================== + # Token Exchange Operations + # ========================================== + async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, Any]: """ Retrieves a token by exchanging a refresh token. @@ -1197,9 +1604,12 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise MissingRequiredArgumentError("refresh_token") try: - # Ensure we have the OIDC metadata + # Use session domain if provided, otherwise fallback to static domain + domain = options.get("domain") or self._domain + + # Ensure we have the OIDC metadata from the correct domain if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth.metadata = await self._get_oidc_metadata_cached(domain) token_endpoint = self._oauth.metadata.get("token_endpoint") if not token_endpoint: @@ -1280,9 +1690,12 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "http://auth0.com/oauth/token-type/federated-connection-access-token" GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" try: - # Ensure we have OIDC metadata + # Use session domain if provided, otherwise fallback to static domain + domain = options.get("domain") or self._domain + + # Ensure we have OIDC metadata from the correct domain if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + self._oauth.metadata = await self._get_oidc_metadata_cached(domain) token_endpoint = self._oauth.metadata.get("token_endpoint") if not token_endpoint: @@ -1340,6 +1753,10 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A e ) + # ========================================== + # Account Connection + # ========================================== + async def start_connect_account( self, options: ConnectAccountOptions, diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 677a7da..7c27d36 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -66,6 +66,7 @@ class SessionData(BaseModel): refresh_token: Optional[str] = None token_sets: list[TokenSet] = Field(default_factory=list) connection_token_sets: list[ConnectionTokenSet] = Field(default_factory=list) + domain: Optional[str] = None class Config: extra = "allow" # Allow additional fields not defined in the model @@ -89,6 +90,8 @@ class TransactionData(BaseModel): app_state: Optional[Any] = None auth_session: Optional[str] = None redirect_uri: Optional[str] = None + origin_domain: Optional[str] = None + origin_issuer: Optional[str] = None class Config: extra = "allow" # Allow additional fields not defined in the model @@ -252,3 +255,23 @@ class CompleteConnectAccountResponse(BaseModel): created_at: str expires_at: Optional[str] = None app_state: Optional[Any] = None + + +class DomainResolverContext(BaseModel): + """ + Context passed to domain resolver function for MCD support. + + Contains request information needed to determine the correct Auth0 domain + based on the incoming request's hostname or headers. + + Attributes: + request_url: The full request URL (e.g., "https://a.my-app.com/auth/login") + request_headers: Dictionary of request headers (e.g., {"host": "a.my-app.com", "x-forwarded-host": "..."}) + + Example: + async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get('host', '').split(':')[0] + return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) + """ + request_url: Optional[str] = None + request_headers: Optional[dict[str, str]] = None diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index ef181ce..93fcba2 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -101,6 +101,18 @@ def __init__(self, argument: str): self.argument = argument +class ConfigurationError(Auth0Error): + """ + Error raised when SDK configuration is invalid. + This includes invalid combinations of parameters or incorrect configuration values. + """ + code = "configuration_error" + + def __init__(self, message: str): + super().__init__(message) + self.name = "ConfigurationError" + + class BackchannelLogoutError(Auth0Error): """ Error raised during backchannel logout processing. @@ -113,6 +125,21 @@ def __init__(self, message: str): self.name = "BackchannelLogoutError" +class DomainResolverError(Auth0Error): + """ + Error raised when domain resolver function fails or returns invalid value. + + This error indicates an issue with the custom domain resolver function + provided for MCD (Multiple Custom Domains) support. + """ + code = "domain_resolver_error" + + def __init__(self, message: str, original_error: Exception = None): + super().__init__(message) + self.name = "DomainResolverError" + self.original_error = original_error + + class AccessTokenForConnectionError(Auth0Error): """Error when retrieving access tokens for a specific connection fails.""" diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 9f4f2cd..97fcc77 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -1,8 +1,9 @@ import json import time -from unittest.mock import ANY, AsyncMock, MagicMock +from unittest.mock import ANY, AsyncMock, MagicMock, patch from urllib.parse import parse_qs, urlparse +import jwt import pytest from auth0_server_python.auth_server.my_account_client import MyAccountClient from auth0_server_python.auth_server.server_client import ServerClient @@ -12,13 +13,17 @@ ConnectAccountRequest, ConnectAccountResponse, ConnectParams, + DomainResolverContext, LogoutOptions, + StateData, TransactionData, ) from auth0_server_python.error import ( AccessTokenForConnectionError, ApiError, BackchannelLogoutError, + ConfigurationError, + DomainResolverError, MissingRequiredArgumentError, MissingTransactionError, PollingApiError, @@ -42,7 +47,7 @@ async def test_init_no_secret_raises(): @pytest.mark.asyncio -async def test_start_interactive_login_no_redirect_uri(): +async def test_start_interactive_login_no_redirect_uri(mocker): client = ServerClient( domain="auth0.local", client_id="", @@ -51,6 +56,14 @@ async def test_start_interactive_login_no_redirect_uri(): transaction_store=AsyncMock(), secret="some-secret" ) + + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://auth0.local/", "authorization_endpoint": "https://auth0.local/authorize"} + ) + with pytest.raises(MissingRequiredArgumentError) as exc: await client.start_interactive_login() # Check the error message @@ -74,7 +87,7 @@ async def test_start_interactive_login_builds_auth_url(mocker): # Mock out HTTP calls or the internal methods that create the auth URL mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) mock_oauth = mocker.patch.object( @@ -115,8 +128,13 @@ async def test_complete_interactive_login_no_transaction(): @pytest.mark.asyncio async def test_complete_interactive_login_returns_app_state(mocker): mock_tx_store = AsyncMock() - # The stored transaction includes an appState - mock_tx_store.get.return_value = TransactionData(code_verifier="123", app_state={"foo": "bar"}) + # The stored transaction includes an appState with origin_domain and origin_issuer + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + app_state={"foo": "bar"}, + origin_domain="auth0.local", + origin_issuer="https://auth0.local/" + ) mock_state_store = AsyncMock() @@ -129,6 +147,13 @@ async def test_complete_interactive_login_returns_app_state(mocker): secret="some-secret", ) + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://auth0.local/", "token_endpoint": "https://auth0.local/token"} + ) + # Patch token exchange mocker.patch.object(client._oauth, "metadata", {"token_endpoint": "https://auth0.local/token"}) @@ -204,7 +229,7 @@ async def test_complete_link_user_returns_app_state(mocker): ) # Patch token exchange - mocker.patch.object(client, "_fetch_oidc_metadata", return_value={"token_endpoint": "https://auth0.local/token"}) + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={"token_endpoint": "https://auth0.local/token"}) async_fetch_token = AsyncMock() async_fetch_token.return_value = { "access_token": "token123", @@ -400,7 +425,8 @@ async def test_get_access_token_refresh_expired(mocker): assert token == "new_token" mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ - "refresh_token": "refresh_xyz" + "refresh_token": "refresh_xyz", + "domain": "auth0.local" }) @pytest.mark.asyncio @@ -441,6 +467,7 @@ async def test_get_access_token_refresh_merging_default_scope(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "default", "scope": "openid profile email foo:bar" }) @@ -482,6 +509,7 @@ async def test_get_access_token_refresh_with_auth_params_scope(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "scope": "openid profile email" }) @@ -522,6 +550,7 @@ async def test_get_access_token_refresh_with_auth_params_audience(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "my_audience" }) @@ -568,6 +597,7 @@ async def test_get_access_token_mrrt(mocker): assert len(stored_state["token_sets"]) == 2 get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "some_audience", "scope": "foo:bar", }) @@ -621,6 +651,7 @@ async def test_get_access_token_mrrt_with_auth_params_scope(mocker): assert len(stored_state["token_sets"]) == 2 get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "some_audience", "scope": "foo:bar", }) @@ -848,6 +879,18 @@ async def test_handle_backchannel_logout_ok(mocker): secret="some-secret" ) + # Mock JWKS fetch to prevent network call + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) mocker.patch("jwt.decode", return_value={ "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, "sub": "user_sub", @@ -874,7 +917,7 @@ async def test_build_link_user_url_success(mocker): # Patch _fetch_oidc_metadata to return an authorization_endpoint mock_fetch = mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -932,7 +975,7 @@ async def test_build_link_user_url_fallback_authorize(mocker): # Patch _fetch_oidc_metadata to NOT have an authorization_endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={} # empty dict, triggers fallback ) @@ -969,7 +1012,7 @@ async def test_build_unlink_user_url_success(mocker): # Patch out metadata mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -1002,7 +1045,7 @@ async def test_build_unlink_user_url_fallback_authorize(mocker): ) # No 'authorization_endpoint' - mocker.patch.object(client, "_fetch_oidc_metadata", return_value={}) + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={}) result_url = await client._build_unlink_user_url( connection="", @@ -1033,7 +1076,7 @@ async def test_build_unlink_user_url_with_metadata(mocker): # Patch the metadata fetch to include a valid authorization endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -1086,7 +1129,7 @@ async def test_build_unlink_user_url_no_authorization_endpoint(mocker): # Patch _fetch_oidc_metadata to return no authorization_endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={} ) result_url = await client._build_unlink_user_url( @@ -1117,7 +1160,7 @@ async def test_backchannel_auth_with_audience_and_binding_message(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1166,7 +1209,7 @@ async def test_backchannel_auth_rar(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1217,7 +1260,7 @@ async def test_backchannel_auth_token_exchange_failed(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1267,7 +1310,7 @@ async def test_initiate_backchannel_authentication_success(mocker): # Mock OIDC metadata mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/backchannel" @@ -1315,7 +1358,7 @@ async def test_initiate_backchannel_authentication_error_response(mocker): ) mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/backchannel" @@ -1932,3 +1975,952 @@ async def test_complete_connect_account_no_transactions(mocker): # Assert assert "transaction" in str(exc.value) mock_my_account_client.complete_connect_account.assert_not_awaited() + + +# ============================================================================= +# Requirement 1: Multiple Issuer Configuration Methods Tests +# ============================================================================= + +@pytest.mark.asyncio +async def test_domain_as_static_string(): + """Test Method 1: Static domain string configuration.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + assert client._domain == "tenant.auth0.com" + assert client._domain_resolver is None + + +@pytest.mark.asyncio +async def test_domain_as_callable_function(): + """Test Method 2: Domain resolver function configuration.""" + async def domain_resolver(store_options): + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + assert client._domain is None + assert client._domain_resolver == domain_resolver + + +@pytest.mark.asyncio +async def test_missing_domain_raises_configuration_error(): + """Test that missing domain parameter raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="Domain is required"): + ServerClient( + domain=None, + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +@pytest.mark.asyncio +async def test_invalid_domain_type_list(): + """Test that list domain raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="must be either a string or a callable"): + ServerClient( + domain=["tenant.auth0.com"], + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +@pytest.mark.asyncio +async def test_empty_domain_string(): + """Test that empty domain string raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="Domain cannot be empty"): + ServerClient( + domain="", + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +# ============================================================================= +# Requirement 2: Domain Resolver Context Tests +# ============================================================================= + +@pytest.mark.asyncio +async def test_domain_resolver_receives_context(mocker): + """Test that domain resolver receives DomainResolverContext with request data.""" + received_context = None + + async def domain_resolver(context): + nonlocal received_context + received_context = context + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Mock request with headers + mock_request = MagicMock() + mock_request.url = "https://a.my-app.com/auth/login" + mock_request.headers = {"host": "a.my-app.com", "x-forwarded-host": "a.my-app.com"} + + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} + ) + + try: + await client.start_interactive_login(store_options={"request": mock_request}) + except: + pass # We only care about context being passed + + assert received_context is not None + assert isinstance(received_context, DomainResolverContext) + assert received_context.request_url == "https://a.my-app.com/auth/login" + assert received_context.request_headers.get("host") == "a.my-app.com" + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_none(): + """Test that domain resolver returning None raises DomainResolverError.""" + async def bad_resolver(context): + return None + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="returned None"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_empty_string(): + """Test that domain resolver returning empty string raises DomainResolverError.""" + async def bad_resolver(context): + return "" + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="empty string"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_exception(): + """Test that domain resolver exceptions are wrapped in DomainResolverError.""" + async def bad_resolver(context): + raise ValueError("Something went wrong") + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="raised an exception"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_with_no_request(mocker): + """Test that domain resolver works with empty context when no request.""" + received_context = None + + async def domain_resolver(context): + nonlocal received_context + received_context = context + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} + ) + + try: + await client.start_interactive_login(store_options=None) + except: + pass + + assert received_context is not None + assert received_context.request_url is None + assert received_context.request_headers is None + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_non_string_type(): + """Test that domain resolver returning non-string raises DomainResolverError.""" + async def bad_resolver(context): + return 12345 + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="must return a string"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +# ============================================================================= +# Requirement 3: OIDC Metadata and JWKS Fetching Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_fetch_jwks_success(): + """Test successful JWKS fetch from URI.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_jwks = { + "keys": [ + { + "kty": "RSA", + "use": "sig", + "kid": "test-key-id", + "n": "test-modulus", + "e": "AQAB" + } + ] + } + + # Mock httpx client + mock_response = MagicMock() + mock_response.json.return_value = mock_jwks + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = mock_response + + with patch('httpx.AsyncClient', return_value=mock_client): + jwks = await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") + + assert jwks == mock_jwks + assert "keys" in jwks + mock_client.get.assert_awaited_once_with("https://tenant.auth0.com/.well-known/jwks.json") + + +@pytest.mark.asyncio +async def test_fetch_jwks_failure(): + """Test JWKS fetch failure raises ApiError.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Mock httpx client to raise exception + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.side_effect = Exception("Network error") + + with patch('httpx.AsyncClient', return_value=mock_client): + with pytest.raises(ApiError, match="Failed to fetch JWKS"): + await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") + + +@pytest.mark.asyncio +async def test_oidc_metadata_caching(): + """Test OIDC metadata is cached and reused.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "authorization_endpoint": "https://tenant.auth0.com/authorize", + "token_endpoint": "https://tenant.auth0.com/oauth/token", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + # Mock _fetch_oidc_metadata to track calls + fetch_count = 0 + async def mock_fetch(domain): + nonlocal fetch_count + fetch_count += 1 + return mock_metadata + + client._fetch_oidc_metadata = mock_fetch + + # First call - should fetch + result1 = await client._get_oidc_metadata_cached("tenant.auth0.com") + assert result1 == mock_metadata + assert fetch_count == 1 + + # Second call - should use cache + result2 = await client._get_oidc_metadata_cached("tenant.auth0.com") + assert result2 == mock_metadata + assert fetch_count == 1 # Should NOT increment + + # Verify cache contains data + assert "tenant.auth0.com" in client._metadata_cache + assert client._metadata_cache["tenant.auth0.com"]["data"] == mock_metadata + + +@pytest.mark.asyncio +async def test_oidc_metadata_cache_expiration(): + """Test OIDC metadata cache expires after TTL.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Set short TTL for testing + client._cache_ttl = 1 # 1 second + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + fetch_count = 0 + async def mock_fetch(domain): + nonlocal fetch_count + fetch_count += 1 + return mock_metadata + + client._fetch_oidc_metadata = mock_fetch + + # First call + await client._get_oidc_metadata_cached("tenant.auth0.com") + assert fetch_count == 1 + + # Wait for cache to expire + time.sleep(1.1) + + # Second call after expiration - should fetch again + await client._get_oidc_metadata_cached("tenant.auth0.com") + assert fetch_count == 2 + + +@pytest.mark.asyncio +async def test_jwks_caching(): + """Test JWKS is cached and reused.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + mock_jwks = { + "keys": [{"kty": "RSA", "kid": "key1"}] + } + + # Mock the fetch methods + client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata) + + fetch_count = 0 + async def mock_fetch_jwks(uri): + nonlocal fetch_count + fetch_count += 1 + return mock_jwks + + client._fetch_jwks = mock_fetch_jwks + + # First call - should fetch + result1 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) + assert result1 == mock_jwks + assert fetch_count == 1 + + # Second call - should use cache + result2 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) + assert result2 == mock_jwks + assert fetch_count == 1 # Should NOT increment + + +@pytest.mark.asyncio +async def test_jwks_cache_size_limit(): + """Test JWKS cache enforces max size limit with FIFO eviction.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Set small cache size for testing + client._cache_max_size = 3 + + mock_jwks = {"keys": [{"kty": "RSA"}]} + + # Mock methods + async def mock_fetch_metadata(domain): + return {"jwks_uri": f"https://{domain}/.well-known/jwks.json"} + + async def mock_fetch_jwks(uri): + return mock_jwks + + client._fetch_oidc_metadata = mock_fetch_metadata + client._fetch_jwks = mock_fetch_jwks + + # Fill cache to limit + await client._get_jwks_cached("domain1.auth0.com") + await client._get_jwks_cached("domain2.auth0.com") + await client._get_jwks_cached("domain3.auth0.com") + + assert len(client._jwks_cache) == 3 + assert "domain1.auth0.com" in client._jwks_cache + + # Add one more - should evict oldest (domain1) + await client._get_jwks_cached("domain4.auth0.com") + + assert len(client._jwks_cache) == 3 + assert "domain1.auth0.com" not in client._jwks_cache # Evicted + assert "domain4.auth0.com" in client._jwks_cache + + +@pytest.mark.asyncio +async def test_jwks_missing_uri_raises_error(): + """Test that missing jwks_uri in metadata raises ApiError.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Metadata WITHOUT jwks_uri + mock_metadata_no_jwks_uri = { + "issuer": "https://tenant.auth0.com/", + "authorization_endpoint": "https://tenant.auth0.com/authorize" + # No jwks_uri + } + + client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata_no_jwks_uri) + + # Should raise ApiError when jwks_uri is missing + with pytest.raises(ApiError) as exc_info: + await client._get_jwks_cached("tenant.auth0.com") + + assert exc_info.value.code == "missing_jwks_uri" + assert "non-RFC-compliant" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_metadata_cache_size_limit(): + """Test OIDC metadata cache enforces max size limit.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + client._cache_max_size = 2 + + async def mock_fetch(domain): + return {"issuer": f"https://{domain}/"} + + client._fetch_oidc_metadata = mock_fetch + + # Fill cache + await client._get_oidc_metadata_cached("domain1.auth0.com") + await client._get_oidc_metadata_cached("domain2.auth0.com") + + assert len(client._metadata_cache) == 2 + + # Add third - should evict first + await client._get_oidc_metadata_cached("domain3.auth0.com") + + assert len(client._metadata_cache) == 2 + assert "domain1.auth0.com" not in client._metadata_cache + assert "domain3.auth0.com" in client._metadata_cache + + +# ============================================================================= +# Requirement 4: Issuer Validation Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_complete_login_issuer_validation_success(mocker): + """Test complete login with valid issuer in ID token.""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant.auth0.com", + origin_issuer="https://tenant.auth0.com/" + ) + + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode with valid issuer + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://tenant.auth0.com/", # Matches origin_issuer + "aud": "test_client" + }) + + # Should succeed without raising error + result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert result is not None + assert "state_data" in result + + +@pytest.mark.asyncio +async def test_complete_login_issuer_mismatch_raises_error(mocker): + """Test that issuer mismatch in ID token raises ApiError.""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant.auth0.com", + origin_issuer="https://tenant.auth0.com/" + ) + + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode to raise InvalidIssuerError + mocker.patch("jwt.decode", side_effect=jwt.InvalidIssuerError("Invalid issuer")) + + # Should raise ApiError with invalid_issuer code + with pytest.raises(ApiError) as exc_info: + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert exc_info.value.code == "invalid_issuer" + assert "issuer mismatch" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_normalize_domain_handles_different_schemes(): + """Test that _normalize_domain handles various URL schemes correctly.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Test domain without scheme + assert client._normalize_domain("auth0.com") == "https://auth0.com" + + # Test domain with https scheme (should remain unchanged) + assert client._normalize_domain("https://auth0.com") == "https://auth0.com" + + # Test domain with http scheme (should convert to https) + assert client._normalize_domain("http://auth0.com") == "https://auth0.com" + + # Test domain with trailing slash + assert client._normalize_domain("https://auth0.com/") == "https://auth0.com/" + + +# ============================================================================= +# Requirements 5-8: Domain-specific Session Management Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_session_stores_origin_domain(mocker): + """Test that session stores origin domain from login (Requirement 5).""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant1.auth0.com", + origin_issuer="https://tenant1.auth0.com/" + ) + + captured_state = None + async def capture_state(identifier, state_data, options=None): + nonlocal captured_state + captured_state = state_data + + mock_state_store = AsyncMock() + mock_state_store.set = AsyncMock(side_effect=capture_state) + + client = ServerClient( + domain="tenant1.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": "https://tenant1.auth0.com/", + "token_endpoint": "https://tenant1.auth0.com/token" + }) + mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) + + async_fetch_token = AsyncMock(return_value={ + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid" + }) + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": "https://tenant1.auth0.com/"}) + + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + # Verify session has domain field set + assert captured_state.domain == "tenant1.auth0.com" + + +@pytest.mark.asyncio +async def test_cross_domain_session_rejected(): + """Test that session from domain1 cannot be used with domain2 (Requirement 5).""" + # Create session with domain1 + session_data = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns domain2 (different from session) + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # get_user should return None (session rejected) + user = await client.get_user(store_options={"request": {}}) + assert user is None + + +@pytest.mark.asyncio +async def test_logout_uses_current_domain(mocker): + """Test that logout uses current resolved domain (Requirement 7).""" + current_domain = "tenant2.auth0.com" + + async def domain_resolver(context): + return current_domain + + mock_state_store = AsyncMock() + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + logout_url = await client.logout(store_options={"request": {}}) + + # Verify logout URL uses current domain + assert current_domain in logout_url + assert logout_url.startswith(f"https://{current_domain}") + + +@pytest.mark.asyncio +async def test_logout_clears_session_for_current_domain(): + """Test that logout clears session (Requirement 7).""" + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + await client.logout() + + # Verify session was deleted + mock_state_store.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_domain_migration_old_sessions_remain_valid(): + """Test that old sessions remain valid with old domain requests (Requirement 8).""" + old_domain = "old-tenant.auth0.com" + + # Session from old domain + session_data = StateData( + user={"sub": "user123"}, + domain=old_domain, + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns old domain + async def domain_resolver(context): + return old_domain + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Should successfully retrieve user + user = await client.get_user(store_options={"request": {}}) + assert user is not None + assert user["sub"] == "user123" + + +@pytest.mark.asyncio +async def test_domain_migration_new_sessions_use_new_domain(mocker): + """Test that new logins create sessions with new domain (Requirement 8).""" + new_domain = "new-tenant.auth0.com" + + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain=new_domain, + origin_issuer=f"https://{new_domain}/" + ) + + captured_state = None + async def capture_state(identifier, state_data, options=None): + nonlocal captured_state + captured_state = state_data + + mock_state_store = AsyncMock() + mock_state_store.set = AsyncMock(side_effect=capture_state) + + client = ServerClient( + domain=new_domain, + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": f"https://{new_domain}/", + "token_endpoint": f"https://{new_domain}/token" + }) + mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) + + async_fetch_token = AsyncMock(return_value={ + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid" + }) + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": f"https://{new_domain}/"}) + + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + # Verify new session has new domain + assert captured_state.domain == new_domain + + +@pytest.mark.asyncio +async def test_domain_migration_sessions_isolated(): + """Test that old domain sessions cannot be used with new domain (Requirement 8).""" + old_domain = "old-tenant.auth0.com" + new_domain = "new-tenant.auth0.com" + + # Session from old domain + session_data = StateData( + user={"sub": "user123"}, + domain=old_domain, + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns NEW domain (migration happened) + async def domain_resolver(context): + return new_domain + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Should reject old session + user = await client.get_user(store_options={"request": {}}) + assert user is None \ No newline at end of file diff --git a/src/auth0_server_python/utils/helpers.py b/src/auth0_server_python/utils/helpers.py index c57ab18..1a49835 100644 --- a/src/auth0_server_python/utils/helpers.py +++ b/src/auth0_server_python/utils/helpers.py @@ -5,6 +5,8 @@ import time from typing import Any, Optional from urllib.parse import parse_qs, urlencode, urlparse +from auth0_server_python.auth_types import DomainResolverContext +from auth0_server_python.error import DomainResolverError class PKCE: @@ -224,3 +226,69 @@ def create_logout_url(domain: str, client_id: str, return_to: Optional[str] = No if return_to: params["returnTo"] = return_to return URL.build_url(base_url, params) + + +# ============================================================================= +# Domain Resolver Utilities +# ============================================================================= + +def build_domain_resolver_context(store_options: Optional[dict[str, Any]]) -> 'DomainResolverContext': + """ + Build DomainResolverContext from store_options. + + Extracts request information in a framework-agnostic way using duck typing. + + Args: + store_options: Dictionary containing 'request' and 'response' objects + + Returns: + DomainResolverContext with extracted request data + """ + + if not store_options: + return DomainResolverContext() + + request = store_options.get('request') + if not request: + return DomainResolverContext() + + # Framework-agnostic extraction using duck typing + request_url = str(request.url) if hasattr(request, 'url') else None + request_headers = dict(request.headers) if hasattr(request, 'headers') else None + + return DomainResolverContext( + request_url=request_url, + request_headers=request_headers + ) + + +def validate_resolved_domain_value(domain_value: Any) -> str: + """ + Validate the value returned by domain resolver. + + Args: + domain_value: The value returned by the domain resolver + + Returns: + The validated domain string + + Raises: + DomainResolverError: If the returned value is invalid + """ + + if domain_value is None: + raise DomainResolverError( + "Domain resolver returned None. Must return a valid domain string." + ) + + if not isinstance(domain_value, str): + raise DomainResolverError( + f"Domain resolver must return a string. Got {type(domain_value).__name__} instead." + ) + + if not domain_value.strip(): + raise DomainResolverError( + "Domain resolver returned an empty string. Must return a valid domain." + ) + + return domain_value From 116484868b468235dbe219f996e76c625a0538d1 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Mon, 2 Feb 2026 23:45:32 +0530 Subject: [PATCH 02/16] Bump poetry version from latest to 2.2.1 in test workflow --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d6e025f..9c87ea2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,7 +36,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: - version: latest + version: 2.2.1 virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true From 26466821e66c76b9a09f3648b532e848807f0b54 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Mon, 2 Feb 2026 23:51:56 +0530 Subject: [PATCH 03/16] Fix linting errors --- .../auth_server/server_client.py | 86 +++---- .../auth_types/__init__.py | 6 +- src/auth0_server_python/error/__init__.py | 2 +- .../tests/test_server_client.py | 222 +++++++++--------- src/auth0_server_python/utils/helpers.py | 29 +-- 5 files changed, 173 insertions(+), 172 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index bee5541..56cd636 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -184,7 +184,7 @@ async def start_interactive_login( ) else: origin_domain = self._domain - + # Fetch OIDC metadata from resolved domain try: metadata = await self._get_oidc_metadata_cached(origin_domain) @@ -243,7 +243,7 @@ async def start_interactive_login( transaction_data, options=store_options ) - + # Set metadata for OAuth client self._oauth.metadata = metadata # If PAR is enabled, use the PAR endpoint @@ -339,7 +339,7 @@ async def complete_interactive_login( # Get origin domain and issuer from transaction origin_domain = transaction_data.origin_domain origin_issuer = transaction_data.origin_issuer - + # Fetch metadata from the origin domain metadata = await self._get_oidc_metadata_cached(origin_domain) self._oauth.metadata = metadata @@ -365,32 +365,32 @@ async def complete_interactive_login( user_info = token_response.get("userinfo") user_claims = None id_token = token_response.get("id_token") - + if user_info: user_claims = UserClaims.parse_obj(user_info) elif id_token: # Fetch JWKS for signature verification (Requirement 3) jwks = await self._get_jwks_cached(origin_domain, metadata) - + # Decode and verify ID token with signature verification enabled try: # Get the signing key from JWKS unverified_header = jwt.get_unverified_header(id_token) kid = unverified_header.get("kid") - + # Find the key with matching kid signing_key = None for key in jwks.get("keys", []): if key.get("kid") == kid: signing_key = jwt.PyJWK.from_dict(key) break - + if not signing_key: raise ApiError( "jwks_key_not_found", f"No matching key found in JWKS for kid: {kid}" ) - + claims = jwt.decode( id_token, signing_key.key, @@ -430,7 +430,7 @@ async def complete_interactive_login( f"ID token verification failed: {str(e)}", e ) - + # Build a token set using the token response data token_set = TokenSet( @@ -708,11 +708,11 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti original_error=e ) session_domain = getattr(state_data, 'domain', None) - + if session_domain and session_domain != current_domain: # Session created with different domain - reject for security return None - + if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() return state_data.get("user") @@ -745,11 +745,11 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O original_error=e ) session_domain = getattr(state_data, 'domain', None) - + if session_domain and session_domain != current_domain: # Session created with different domain - reject for security return None - + if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() session_data = {k: v for k, v in state_data.items() @@ -793,7 +793,7 @@ async def get_access_token( original_error=e ) session_domain = getattr(state_data, 'domain', None) - + if session_domain and session_domain != current_domain: # Session created with different domain - reject for security raise AccessTokenError( @@ -832,7 +832,7 @@ async def get_access_token( # Get new token with refresh token try: - # Use session's domain for token refresh + # Use session's domain for token refresh session_domain = state_data_dict.get("domain") or self._domain get_refresh_token_options = { "refresh_token": state_data_dict["refresh_token"], @@ -863,7 +863,7 @@ async def get_access_token( f"Failed to get token with refresh token: {str(e)}" ) - + async def get_access_token_for_connection( self, @@ -987,25 +987,25 @@ async def handle_backchannel_logout( try: # Fetch JWKS for signature verification (Requirement 3) jwks = await self._get_jwks_cached(self._domain) - + # Decode and verify logout token with signature verification enabled try: # Get the signing key from JWKS unverified_header = jwt.get_unverified_header(logout_token) kid = unverified_header.get("kid") - + # Find the key with matching kid signing_key = None for key in jwks.get("keys", []): if key.get("kid") == kid: signing_key = jwt.PyJWK.from_dict(key) break - + if not signing_key: raise BackchannelLogoutError( f"No matching key found in JWKS for kid: {kid}" ) - + claims = jwt.decode( logout_token, signing_key.key, @@ -1071,47 +1071,47 @@ async def _fetch_oidc_metadata(self, domain: str) -> dict: async def _get_oidc_metadata_cached(self, domain: str) -> dict: """ Get OIDC metadata with caching. - + Args: domain: Auth0 domain - + Returns: OIDC metadata document """ now = time.time() - + # Check cache if domain in self._metadata_cache: cached = self._metadata_cache[domain] if cached["expires_at"] > now: return cached["data"] - + # Cache miss/expired - fetch fresh metadata = await self._fetch_oidc_metadata(domain) - + # Enforce cache size limit (FIFO eviction) if len(self._metadata_cache) >= self._cache_max_size: oldest_key = next(iter(self._metadata_cache)) del self._metadata_cache[oldest_key] - + # Store in cache self._metadata_cache[domain] = { "data": metadata, "expires_at": now + self._cache_ttl } - + return metadata async def _fetch_jwks(self, jwks_uri: str) -> dict: """ Fetch JWKS (JSON Web Key Set) from jwks_uri. - + Args: jwks_uri: The JWKS endpoint URL - + Returns: JWKS document containing public keys - + Raises: ApiError: If JWKS fetch fails """ @@ -1126,54 +1126,54 @@ async def _fetch_jwks(self, jwks_uri: str) -> dict: async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: """ Get JWKS with caching usingOIDC discovery. - + Args: domain: Auth0 domain metadata: Optional OIDC metadata (if already fetched) - + Returns: JWKS document - + Raises: ApiError: If JWKS fetch fails or jwks_uri missing from metadata """ now = time.time() - + # Check cache if domain in self._jwks_cache: cached = self._jwks_cache[domain] if cached["expires_at"] > now: return cached["data"] - + # Get jwks_uri from OIDC metadata if not metadata: metadata = await self._get_oidc_metadata_cached(domain) - + jwks_uri = metadata.get('jwks_uri') if not jwks_uri: raise ApiError( "missing_jwks_uri", f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant." ) - + # Fetch JWKS jwks = await self._fetch_jwks(jwks_uri) - + # Enforce cache size limit (FIFO eviction) if len(self._jwks_cache) >= self._cache_max_size: oldest_key = next(iter(self._jwks_cache)) del self._jwks_cache[oldest_key] - + # Store in cache self._jwks_cache[domain] = { "data": jwks, "expires_at": now + self._cache_ttl } - + return jwks # ------------------------------------------ - # Token & Scope Management - MRRT + # Token & Scope Management - MRRT # ------------------------------------------ def _merge_scope_with_defaults( @@ -1606,7 +1606,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, try: # Use session domain if provided, otherwise fallback to static domain domain = options.get("domain") or self._domain - + # Ensure we have the OIDC metadata from the correct domain if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: self._oauth.metadata = await self._get_oidc_metadata_cached(domain) @@ -1692,7 +1692,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A try: # Use session domain if provided, otherwise fallback to static domain domain = options.get("domain") or self._domain - + # Ensure we have OIDC metadata from the correct domain if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: self._oauth.metadata = await self._get_oidc_metadata_cached(domain) diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 7c27d36..6686a3f 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -260,14 +260,14 @@ class CompleteConnectAccountResponse(BaseModel): class DomainResolverContext(BaseModel): """ Context passed to domain resolver function for MCD support. - + Contains request information needed to determine the correct Auth0 domain based on the incoming request's hostname or headers. - + Attributes: request_url: The full request URL (e.g., "https://a.my-app.com/auth/login") request_headers: Dictionary of request headers (e.g., {"host": "a.my-app.com", "x-forwarded-host": "..."}) - + Example: async def domain_resolver(context: DomainResolverContext) -> str: host = context.request_headers.get('host', '').split(':')[0] diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index 93fcba2..6b863e1 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -128,7 +128,7 @@ def __init__(self, message: str): class DomainResolverError(Auth0Error): """ Error raised when domain resolver function fails or returns invalid value. - + This error indicates an issue with the custom domain resolver function provided for MCD (Multiple Custom Domains) support. """ diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 97fcc77..7a9665c 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -56,14 +56,14 @@ async def test_start_interactive_login_no_redirect_uri(mocker): transaction_store=AsyncMock(), secret="some-secret" ) - + # Mock OIDC metadata fetch mocker.patch.object( client, "_get_oidc_metadata_cached", return_value={"issuer": "https://auth0.local/", "authorization_endpoint": "https://auth0.local/authorize"} ) - + with pytest.raises(MissingRequiredArgumentError) as exc: await client.start_interactive_login() # Check the error message @@ -130,7 +130,7 @@ async def test_complete_interactive_login_returns_app_state(mocker): mock_tx_store = AsyncMock() # The stored transaction includes an appState with origin_domain and origin_issuer mock_tx_store.get.return_value = TransactionData( - code_verifier="123", + code_verifier="123", app_state={"foo": "bar"}, origin_domain="auth0.local", origin_issuer="https://auth0.local/" @@ -1990,7 +1990,7 @@ async def test_domain_as_static_string(): client_secret="test_client_secret", secret="test_secret_key_32_chars_long!!" ) - + assert client._domain == "tenant.auth0.com" assert client._domain_resolver is None @@ -2000,14 +2000,14 @@ async def test_domain_as_callable_function(): """Test Method 2: Domain resolver function configuration.""" async def domain_resolver(store_options): return "tenant.auth0.com" - + client = ServerClient( domain=domain_resolver, client_id="test_client_id", client_secret="test_client_secret", secret="test_secret_key_32_chars_long!!" ) - + assert client._domain is None assert client._domain_resolver == domain_resolver @@ -2056,12 +2056,12 @@ async def test_empty_domain_string(): async def test_domain_resolver_receives_context(mocker): """Test that domain resolver receives DomainResolverContext with request data.""" received_context = None - + async def domain_resolver(context): nonlocal received_context received_context = context return "tenant.auth0.com" - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2070,24 +2070,24 @@ async def domain_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + # Mock request with headers mock_request = MagicMock() mock_request.url = "https://a.my-app.com/auth/login" mock_request.headers = {"host": "a.my-app.com", "x-forwarded-host": "a.my-app.com"} - + # Mock OIDC metadata fetch mocker.patch.object( client, "_get_oidc_metadata_cached", return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} ) - + try: await client.start_interactive_login(store_options={"request": mock_request}) - except: + except Exception: # noqa: S110 pass # We only care about context being passed - + assert received_context is not None assert isinstance(received_context, DomainResolverContext) assert received_context.request_url == "https://a.my-app.com/auth/login" @@ -2099,7 +2099,7 @@ async def test_domain_resolver_error_on_none(): """Test that domain resolver returning None raises DomainResolverError.""" async def bad_resolver(context): return None - + client = ServerClient( domain=bad_resolver, client_id="test_client", @@ -2108,7 +2108,7 @@ async def bad_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + with pytest.raises(DomainResolverError, match="returned None"): await client.start_interactive_login(store_options={"request": MagicMock()}) @@ -2118,7 +2118,7 @@ async def test_domain_resolver_error_on_empty_string(): """Test that domain resolver returning empty string raises DomainResolverError.""" async def bad_resolver(context): return "" - + client = ServerClient( domain=bad_resolver, client_id="test_client", @@ -2127,7 +2127,7 @@ async def bad_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + with pytest.raises(DomainResolverError, match="empty string"): await client.start_interactive_login(store_options={"request": MagicMock()}) @@ -2137,7 +2137,7 @@ async def test_domain_resolver_error_on_exception(): """Test that domain resolver exceptions are wrapped in DomainResolverError.""" async def bad_resolver(context): raise ValueError("Something went wrong") - + client = ServerClient( domain=bad_resolver, client_id="test_client", @@ -2146,7 +2146,7 @@ async def bad_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + with pytest.raises(DomainResolverError, match="raised an exception"): await client.start_interactive_login(store_options={"request": MagicMock()}) @@ -2155,12 +2155,12 @@ async def bad_resolver(context): async def test_domain_resolver_with_no_request(mocker): """Test that domain resolver works with empty context when no request.""" received_context = None - + async def domain_resolver(context): nonlocal received_context received_context = context return "tenant.auth0.com" - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2169,18 +2169,18 @@ async def domain_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + mocker.patch.object( client, "_get_oidc_metadata_cached", return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} ) - + try: await client.start_interactive_login(store_options=None) - except: - pass - + except Exception: # noqa: S110 + pass # Intentionally ignore - testing context only + assert received_context is not None assert received_context.request_url is None assert received_context.request_headers is None @@ -2191,7 +2191,7 @@ async def test_domain_resolver_error_on_non_string_type(): """Test that domain resolver returning non-string raises DomainResolverError.""" async def bad_resolver(context): return 12345 - + client = ServerClient( domain=bad_resolver, client_id="test_client", @@ -2200,7 +2200,7 @@ async def bad_resolver(context): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + with pytest.raises(DomainResolverError, match="must return a string"): await client.start_interactive_login(store_options={"request": MagicMock()}) @@ -2221,7 +2221,7 @@ async def test_fetch_jwks_success(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + mock_jwks = { "keys": [ { @@ -2233,20 +2233,20 @@ async def test_fetch_jwks_success(): } ] } - + # Mock httpx client mock_response = MagicMock() mock_response.json.return_value = mock_jwks mock_response.raise_for_status = MagicMock() - + mock_client = AsyncMock() mock_client.__aenter__.return_value = mock_client mock_client.__aexit__.return_value = None mock_client.get.return_value = mock_response - + with patch('httpx.AsyncClient', return_value=mock_client): jwks = await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") - + assert jwks == mock_jwks assert "keys" in jwks mock_client.get.assert_awaited_once_with("https://tenant.auth0.com/.well-known/jwks.json") @@ -2263,13 +2263,13 @@ async def test_fetch_jwks_failure(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + # Mock httpx client to raise exception mock_client = AsyncMock() mock_client.__aenter__.return_value = mock_client mock_client.__aexit__.return_value = None mock_client.get.side_effect = Exception("Network error") - + with patch('httpx.AsyncClient', return_value=mock_client): with pytest.raises(ApiError, match="Failed to fetch JWKS"): await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") @@ -2286,33 +2286,33 @@ async def test_oidc_metadata_caching(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + mock_metadata = { "issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize", "token_endpoint": "https://tenant.auth0.com/oauth/token", "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" } - + # Mock _fetch_oidc_metadata to track calls fetch_count = 0 async def mock_fetch(domain): nonlocal fetch_count fetch_count += 1 return mock_metadata - + client._fetch_oidc_metadata = mock_fetch - + # First call - should fetch result1 = await client._get_oidc_metadata_cached("tenant.auth0.com") assert result1 == mock_metadata assert fetch_count == 1 - + # Second call - should use cache result2 = await client._get_oidc_metadata_cached("tenant.auth0.com") assert result2 == mock_metadata assert fetch_count == 1 # Should NOT increment - + # Verify cache contains data assert "tenant.auth0.com" in client._metadata_cache assert client._metadata_cache["tenant.auth0.com"]["data"] == mock_metadata @@ -2329,30 +2329,30 @@ async def test_oidc_metadata_cache_expiration(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + # Set short TTL for testing client._cache_ttl = 1 # 1 second - + mock_metadata = { "issuer": "https://tenant.auth0.com/", "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" } - + fetch_count = 0 async def mock_fetch(domain): nonlocal fetch_count fetch_count += 1 return mock_metadata - + client._fetch_oidc_metadata = mock_fetch - + # First call await client._get_oidc_metadata_cached("tenant.auth0.com") assert fetch_count == 1 - + # Wait for cache to expire time.sleep(1.1) - + # Second call after expiration - should fetch again await client._get_oidc_metadata_cached("tenant.auth0.com") assert fetch_count == 2 @@ -2369,32 +2369,32 @@ async def test_jwks_caching(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + mock_metadata = { "issuer": "https://tenant.auth0.com/", "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" } - + mock_jwks = { "keys": [{"kty": "RSA", "kid": "key1"}] } - + # Mock the fetch methods client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata) - + fetch_count = 0 async def mock_fetch_jwks(uri): nonlocal fetch_count fetch_count += 1 return mock_jwks - + client._fetch_jwks = mock_fetch_jwks - + # First call - should fetch result1 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) assert result1 == mock_jwks assert fetch_count == 1 - + # Second call - should use cache result2 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) assert result2 == mock_jwks @@ -2412,33 +2412,33 @@ async def test_jwks_cache_size_limit(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + # Set small cache size for testing client._cache_max_size = 3 - + mock_jwks = {"keys": [{"kty": "RSA"}]} - + # Mock methods async def mock_fetch_metadata(domain): return {"jwks_uri": f"https://{domain}/.well-known/jwks.json"} - + async def mock_fetch_jwks(uri): return mock_jwks - + client._fetch_oidc_metadata = mock_fetch_metadata client._fetch_jwks = mock_fetch_jwks - + # Fill cache to limit await client._get_jwks_cached("domain1.auth0.com") await client._get_jwks_cached("domain2.auth0.com") await client._get_jwks_cached("domain3.auth0.com") - + assert len(client._jwks_cache) == 3 assert "domain1.auth0.com" in client._jwks_cache - + # Add one more - should evict oldest (domain1) await client._get_jwks_cached("domain4.auth0.com") - + assert len(client._jwks_cache) == 3 assert "domain1.auth0.com" not in client._jwks_cache # Evicted assert "domain4.auth0.com" in client._jwks_cache @@ -2455,20 +2455,20 @@ async def test_jwks_missing_uri_raises_error(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + # Metadata WITHOUT jwks_uri mock_metadata_no_jwks_uri = { "issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize" # No jwks_uri } - + client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata_no_jwks_uri) - + # Should raise ApiError when jwks_uri is missing with pytest.raises(ApiError) as exc_info: await client._get_jwks_cached("tenant.auth0.com") - + assert exc_info.value.code == "missing_jwks_uri" assert "non-RFC-compliant" in str(exc_info.value) @@ -2484,23 +2484,23 @@ async def test_metadata_cache_size_limit(): transaction_store=AsyncMock(), state_store=AsyncMock() ) - + client._cache_max_size = 2 - + async def mock_fetch(domain): return {"issuer": f"https://{domain}/"} - + client._fetch_oidc_metadata = mock_fetch - + # Fill cache await client._get_oidc_metadata_cached("domain1.auth0.com") await client._get_oidc_metadata_cached("domain2.auth0.com") - + assert len(client._metadata_cache) == 2 - + # Add third - should evict first await client._get_oidc_metadata_cached("domain3.auth0.com") - + assert len(client._metadata_cache) == 2 assert "domain1.auth0.com" not in client._metadata_cache assert "domain3.auth0.com" in client._metadata_cache @@ -2557,12 +2557,12 @@ async def test_complete_login_issuer_validation_success(mocker): # Mock jwt.get_unverified_header mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) - + # Mock PyJWK.from_dict mock_signing_key = mocker.MagicMock() mock_signing_key.key = "mock_pem_key" mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) - + # Mock jwt.decode with valid issuer mocker.patch("jwt.decode", return_value={ "sub": "user123", @@ -2572,7 +2572,7 @@ async def test_complete_login_issuer_validation_success(mocker): # Should succeed without raising error result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") - + assert result is not None assert "state_data" in result @@ -2623,19 +2623,19 @@ async def test_complete_login_issuer_mismatch_raises_error(mocker): # Mock jwt.get_unverified_header mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) - + # Mock PyJWK.from_dict mock_signing_key = mocker.MagicMock() mock_signing_key.key = "mock_pem_key" mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) - + # Mock jwt.decode to raise InvalidIssuerError mocker.patch("jwt.decode", side_effect=jwt.InvalidIssuerError("Invalid issuer")) # Should raise ApiError with invalid_issuer code with pytest.raises(ApiError) as exc_info: await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") - + assert exc_info.value.code == "invalid_issuer" assert "issuer mismatch" in str(exc_info.value).lower() @@ -2654,13 +2654,13 @@ async def test_normalize_domain_handles_different_schemes(): # Test domain without scheme assert client._normalize_domain("auth0.com") == "https://auth0.com" - + # Test domain with https scheme (should remain unchanged) assert client._normalize_domain("https://auth0.com") == "https://auth0.com" - + # Test domain with http scheme (should convert to https) assert client._normalize_domain("http://auth0.com") == "https://auth0.com" - + # Test domain with trailing slash assert client._normalize_domain("https://auth0.com/") == "https://auth0.com/" @@ -2684,7 +2684,7 @@ async def test_session_stores_origin_domain(mocker): async def capture_state(identifier, state_data, options=None): nonlocal captured_state captured_state = state_data - + mock_state_store = AsyncMock() mock_state_store.set = AsyncMock(side_effect=capture_state) @@ -2702,14 +2702,14 @@ async def capture_state(identifier, state_data, options=None): "token_endpoint": "https://tenant1.auth0.com/token" }) mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) - + async_fetch_token = AsyncMock(return_value={ "access_token": "token123", "id_token": "id_token_jwt", "scope": "openid" }) mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) - + # Mock JWT verification mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) mock_signing_key = mocker.MagicMock() @@ -2718,7 +2718,7 @@ async def capture_state(identifier, state_data, options=None): mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": "https://tenant1.auth0.com/"}) await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") - + # Verify session has domain field set assert captured_state.domain == "tenant1.auth0.com" @@ -2733,14 +2733,14 @@ async def test_cross_domain_session_rejected(): token_sets=[], internal={"sid": "123", "created_at": int(time.time())} ) - + mock_state_store = AsyncMock() mock_state_store.get.return_value = session_data - + # Domain resolver returns domain2 (different from session) async def domain_resolver(context): return "tenant2.auth0.com" - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2759,12 +2759,12 @@ async def domain_resolver(context): async def test_logout_uses_current_domain(mocker): """Test that logout uses current resolved domain (Requirement 7).""" current_domain = "tenant2.auth0.com" - + async def domain_resolver(context): return current_domain - + mock_state_store = AsyncMock() - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2775,7 +2775,7 @@ async def domain_resolver(context): ) logout_url = await client.logout(store_options={"request": {}}) - + # Verify logout URL uses current domain assert current_domain in logout_url assert logout_url.startswith(f"https://{current_domain}") @@ -2785,7 +2785,7 @@ async def domain_resolver(context): async def test_logout_clears_session_for_current_domain(): """Test that logout clears session (Requirement 7).""" mock_state_store = AsyncMock() - + client = ServerClient( domain="tenant.auth0.com", client_id="test_client", @@ -2796,7 +2796,7 @@ async def test_logout_clears_session_for_current_domain(): ) await client.logout() - + # Verify session was deleted mock_state_store.delete.assert_called_once() @@ -2805,7 +2805,7 @@ async def test_logout_clears_session_for_current_domain(): async def test_domain_migration_old_sessions_remain_valid(): """Test that old sessions remain valid with old domain requests (Requirement 8).""" old_domain = "old-tenant.auth0.com" - + # Session from old domain session_data = StateData( user={"sub": "user123"}, @@ -2813,14 +2813,14 @@ async def test_domain_migration_old_sessions_remain_valid(): token_sets=[], internal={"sid": "123", "created_at": int(time.time())} ) - + mock_state_store = AsyncMock() mock_state_store.get.return_value = session_data - + # Domain resolver returns old domain async def domain_resolver(context): return old_domain - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2840,7 +2840,7 @@ async def domain_resolver(context): async def test_domain_migration_new_sessions_use_new_domain(mocker): """Test that new logins create sessions with new domain (Requirement 8).""" new_domain = "new-tenant.auth0.com" - + mock_tx_store = AsyncMock() mock_tx_store.get.return_value = TransactionData( code_verifier="123", @@ -2852,7 +2852,7 @@ async def test_domain_migration_new_sessions_use_new_domain(mocker): async def capture_state(identifier, state_data, options=None): nonlocal captured_state captured_state = state_data - + mock_state_store = AsyncMock() mock_state_store.set = AsyncMock(side_effect=capture_state) @@ -2870,14 +2870,14 @@ async def capture_state(identifier, state_data, options=None): "token_endpoint": f"https://{new_domain}/token" }) mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) - + async_fetch_token = AsyncMock(return_value={ "access_token": "token123", "id_token": "id_token_jwt", "scope": "openid" }) mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) - + # Mock JWT verification mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) mock_signing_key = mocker.MagicMock() @@ -2886,7 +2886,7 @@ async def capture_state(identifier, state_data, options=None): mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": f"https://{new_domain}/"}) await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") - + # Verify new session has new domain assert captured_state.domain == new_domain @@ -2896,7 +2896,7 @@ async def test_domain_migration_sessions_isolated(): """Test that old domain sessions cannot be used with new domain (Requirement 8).""" old_domain = "old-tenant.auth0.com" new_domain = "new-tenant.auth0.com" - + # Session from old domain session_data = StateData( user={"sub": "user123"}, @@ -2904,14 +2904,14 @@ async def test_domain_migration_sessions_isolated(): token_sets=[], internal={"sid": "123", "created_at": int(time.time())} ) - + mock_state_store = AsyncMock() mock_state_store.get.return_value = session_data - + # Domain resolver returns NEW domain (migration happened) async def domain_resolver(context): return new_domain - + client = ServerClient( domain=domain_resolver, client_id="test_client", @@ -2923,4 +2923,4 @@ async def domain_resolver(context): # Should reject old session user = await client.get_user(store_options={"request": {}}) - assert user is None \ No newline at end of file + assert user is None diff --git a/src/auth0_server_python/utils/helpers.py b/src/auth0_server_python/utils/helpers.py index 1a49835..05cb0f8 100644 --- a/src/auth0_server_python/utils/helpers.py +++ b/src/auth0_server_python/utils/helpers.py @@ -5,6 +5,7 @@ import time from typing import Any, Optional from urllib.parse import parse_qs, urlencode, urlparse + from auth0_server_python.auth_types import DomainResolverContext from auth0_server_python.error import DomainResolverError @@ -235,27 +236,27 @@ def create_logout_url(domain: str, client_id: str, return_to: Optional[str] = No def build_domain_resolver_context(store_options: Optional[dict[str, Any]]) -> 'DomainResolverContext': """ Build DomainResolverContext from store_options. - + Extracts request information in a framework-agnostic way using duck typing. - + Args: store_options: Dictionary containing 'request' and 'response' objects - + Returns: DomainResolverContext with extracted request data """ - + if not store_options: return DomainResolverContext() - + request = store_options.get('request') if not request: return DomainResolverContext() - + # Framework-agnostic extraction using duck typing request_url = str(request.url) if hasattr(request, 'url') else None request_headers = dict(request.headers) if hasattr(request, 'headers') else None - + return DomainResolverContext( request_url=request_url, request_headers=request_headers @@ -265,30 +266,30 @@ def build_domain_resolver_context(store_options: Optional[dict[str, Any]]) -> 'D def validate_resolved_domain_value(domain_value: Any) -> str: """ Validate the value returned by domain resolver. - + Args: domain_value: The value returned by the domain resolver - + Returns: The validated domain string - + Raises: DomainResolverError: If the returned value is invalid """ - + if domain_value is None: raise DomainResolverError( "Domain resolver returned None. Must return a valid domain string." ) - + if not isinstance(domain_value, str): raise DomainResolverError( f"Domain resolver must return a string. Got {type(domain_value).__name__} instead." ) - + if not domain_value.strip(): raise DomainResolverError( "Domain resolver returned an empty string. Must return a valid domain." ) - + return domain_value From 307000b8ae54220190f9b8df3f24ab15e961f1c7 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Tue, 3 Feb 2026 00:03:18 +0530 Subject: [PATCH 04/16] test: improve cache verification in OIDC metadata and JWKS tests --- src/auth0_server_python/tests/test_server_client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 7a9665c..62532fb 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -2307,11 +2307,12 @@ async def mock_fetch(domain): result1 = await client._get_oidc_metadata_cached("tenant.auth0.com") assert result1 == mock_metadata assert fetch_count == 1 + first_fetch_count = fetch_count # Second call - should use cache result2 = await client._get_oidc_metadata_cached("tenant.auth0.com") assert result2 == mock_metadata - assert fetch_count == 1 # Should NOT increment + assert fetch_count == first_fetch_count # Should NOT increment # Verify cache contains data assert "tenant.auth0.com" in client._metadata_cache @@ -2394,11 +2395,12 @@ async def mock_fetch_jwks(uri): result1 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) assert result1 == mock_jwks assert fetch_count == 1 + first_fetch_count = fetch_count # Second call - should use cache result2 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) assert result2 == mock_jwks - assert fetch_count == 1 # Should NOT increment + assert fetch_count == first_fetch_count # Should NOT increment @pytest.mark.asyncio From 3f68c3c6cfdb1176c295ce6fe89f4e841f885494 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Tue, 3 Feb 2026 23:24:48 +0530 Subject: [PATCH 05/16] refactor: rename cache size variable and reorganize test comments --- examples/{MCD.md => MultipleCustomDomains.md} | 0 .../auth_server/server_client.py | 6 +++--- .../tests/test_server_client.py | 17 ++++++++--------- 3 files changed, 11 insertions(+), 12 deletions(-) rename examples/{MCD.md => MultipleCustomDomains.md} (100%) diff --git a/examples/MCD.md b/examples/MultipleCustomDomains.md similarity index 100% rename from examples/MCD.md rename to examples/MultipleCustomDomains.md diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 56cd636..f3940ab 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -145,7 +145,7 @@ def __init__( self._metadata_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} self._jwks_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} self._cache_ttl = 3600 # 1 hour TTL - self._cache_max_size = 100 # Max 100 domains to prevent memory bloat + self._cache_max_entries = 100 # Max 100 domains to prevent memory bloat # ========================================== # Interactive Login Flow @@ -1090,7 +1090,7 @@ async def _get_oidc_metadata_cached(self, domain: str) -> dict: metadata = await self._fetch_oidc_metadata(domain) # Enforce cache size limit (FIFO eviction) - if len(self._metadata_cache) >= self._cache_max_size: + if len(self._metadata_cache) >= self._cache_max_entries: oldest_key = next(iter(self._metadata_cache)) del self._metadata_cache[oldest_key] @@ -1160,7 +1160,7 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: jwks = await self._fetch_jwks(jwks_uri) # Enforce cache size limit (FIFO eviction) - if len(self._jwks_cache) >= self._cache_max_size: + if len(self._jwks_cache) >= self._cache_max_entries: oldest_key = next(iter(self._jwks_cache)) del self._jwks_cache[oldest_key] diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 62532fb..6f14b52 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -1978,7 +1978,7 @@ async def test_complete_connect_account_no_transactions(mocker): # ============================================================================= -# Requirement 1: Multiple Issuer Configuration Methods Tests +# MCD Tests : Multiple Issuer Configuration Methods Tests # ============================================================================= @pytest.mark.asyncio @@ -2049,7 +2049,7 @@ async def test_empty_domain_string(): # ============================================================================= -# Requirement 2: Domain Resolver Context Tests +# MCD Tests : Domain Resolver Context Tests # ============================================================================= @pytest.mark.asyncio @@ -2179,8 +2179,7 @@ async def domain_resolver(context): try: await client.start_interactive_login(store_options=None) except Exception: # noqa: S110 - pass # Intentionally ignore - testing context only - + pass # We only care about context being passed assert received_context is not None assert received_context.request_url is None assert received_context.request_headers is None @@ -2206,7 +2205,7 @@ async def bad_resolver(context): # ============================================================================= -# Requirement 3: OIDC Metadata and JWKS Fetching Tests +# OIDC Metadata and JWKS Fetching Tests # ============================================================================= @@ -2416,7 +2415,7 @@ async def test_jwks_cache_size_limit(): ) # Set small cache size for testing - client._cache_max_size = 3 + client._cache_max_entries = 3 mock_jwks = {"keys": [{"kty": "RSA"}]} @@ -2487,7 +2486,7 @@ async def test_metadata_cache_size_limit(): state_store=AsyncMock() ) - client._cache_max_size = 2 + client._cache_max_entries = 2 async def mock_fetch(domain): return {"issuer": f"https://{domain}/"} @@ -2509,7 +2508,7 @@ async def mock_fetch(domain): # ============================================================================= -# Requirement 4: Issuer Validation Tests +# Issuer Validation Tests # ============================================================================= @@ -2668,7 +2667,7 @@ async def test_normalize_domain_handles_different_schemes(): # ============================================================================= -# Requirements 5-8: Domain-specific Session Management Tests +# MCD Tests : Domain-specific Session Management Tests # ============================================================================= From 2e0ae17b711fb7dec48f773a51ad69639148cf95 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Tue, 3 Feb 2026 23:27:58 +0530 Subject: [PATCH 06/16] chore: add cryptography package to Snyk license ignore list --- .snyk | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.snyk b/.snyk index 4eaa56f..7d0fc1c 100644 --- a/.snyk +++ b/.snyk @@ -21,4 +21,8 @@ ignore: - '*': reason: "Accepting the Unknown license for now" expires: "2030-12-31T23:59:59Z" + "snyk:lic:pip:cryptography:Unknown": + - '*': + reason: "Accepting the Unknown license for now" + expires: "2030-12-31T23:59:59Z" patch: {} From daa6d350f10b3a75cc3ce2c5d9c781b00ebf49d0 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Mon, 9 Feb 2026 02:15:36 +0530 Subject: [PATCH 07/16] Implement normalized issuer validation and add related tests --- .gitignore | 3 +- .../auth_server/server_client.py | 56 ++++++++++++++++--- .../tests/test_server_client.py | 54 +++++++++++++++++- 3 files changed, 101 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 3d5c66a..9814158 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,5 @@ test-script.py coverage.xml examples/mcd-poc IMPLEMENTATION_NOTES.md -examples/MCD_DEVELOPER_GUIDE.md \ No newline at end of file +examples/MCD_DEVELOPER_GUIDE.md +DESIGN_DOC_REVISED.md \ No newline at end of file diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index ddc4543..627549e 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -169,6 +169,31 @@ def _normalize_domain(self, domain: str) -> str: else: return f'https://{domain}' + def _normalize_issuer(self, issuer: str) -> str: + """ + Normalize issuer URL for comparison. + + Args: + issuer: The issuer URL to normalize + + Returns: + Normalized issuer URL (lowercase) + """ + if not issuer: + return issuer + + # Lowercase first for case-insensitive comparison and scheme detection + issuer = issuer.lower() + + # Ensure https:// prefix + if issuer.startswith('http://'): + issuer = issuer.replace('http://', 'https://', 1) + elif not issuer.startswith('https://'): + issuer = f'https://{issuer}' + + # Remove trailing slash + return issuer.rstrip('/') + async def _fetch_oidc_metadata(self, domain: str) -> dict: """Fetch OIDC metadata from domain.""" normalized_domain = self._normalize_domain(domain) @@ -528,14 +553,24 @@ async def complete_interactive_login( f"No matching key found in JWKS for kid: {kid}" ) + # Decode with signature claims = jwt.decode( id_token, signing_key.key, algorithms=["RS256"], audience=self._client_id, - issuer=origin_issuer, - options={"verify_signature": True} + options={"verify_signature": True, "verify_iss": False} ) + + # Custom normalized issuer validation + token_issuer = claims.get("iss", "") + if self._normalize_issuer(token_issuer) != self._normalize_issuer(origin_issuer): + raise ApiError( + "invalid_issuer", + f"ID token issuer mismatch. Token issuer: {token_issuer}, Expected: {origin_issuer}. " + f"Ensure your Auth0 domain is configured correctly." + ) + user_claims = UserClaims.parse_obj(claims) except jwt.InvalidSignatureError as e: raise ApiError( @@ -549,12 +584,6 @@ async def complete_interactive_login( f"ID token audience mismatch. Expected: {self._client_id}. Ensure your client_id is configured correctly: {str(e)}", e ) - except jwt.InvalidIssuerError as e: - raise ApiError( - "invalid_issuer", - f"ID token issuer mismatch. Expected: {origin_issuer}. Ensure your Auth0 domain is configured correctly: {str(e)}", - e - ) except jwt.ExpiredSignatureError as e: raise ApiError( "token_expired", @@ -767,8 +796,17 @@ async def handle_backchannel_logout( logout_token, signing_key.key, algorithms=["RS256"], - options={"verify_signature": True} + options={"verify_signature": True, "verify_iss": False} ) + + # Normalized issuer validation + token_issuer = claims.get("iss", "") + expected_issuer = self._normalize_domain(self._domain) + if self._normalize_issuer(token_issuer) != self._normalize_issuer(expected_issuer): + raise BackchannelLogoutError( + f"Logout token issuer mismatch. Token issuer: {token_issuer}, Expected: {expected_issuer}. " + f"Ensure your Auth0 domain is configured correctly." + ) except jwt.InvalidSignatureError as e: raise BackchannelLogoutError( f"Logout token signature verification failed: {str(e)}" diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index e296cc5..b41cb6a 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -903,6 +903,7 @@ async def test_handle_backchannel_logout_ok(mocker): mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) mocker.patch("jwt.decode", return_value={ "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, + "iss": "https://auth0.local", "sub": "user_sub", "sid": "session_id_123" }) @@ -3292,8 +3293,14 @@ async def test_complete_login_issuer_mismatch_raises_error(mocker): mock_signing_key.key = "mock_pem_key" mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) - # Mock jwt.decode to raise InvalidIssuerError - mocker.patch("jwt.decode", side_effect=jwt.InvalidIssuerError("Invalid issuer")) + # Mock jwt.decode to return claims with a WRONG issuer + # Our custom normalized issuer validation should catch this mismatch + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://wrong-issuer.auth0.com/", # Different from expected: https://tenant.auth0.com/ + "aud": "test_client", + "exp": 9999999999 + }) # Should raise ApiError with invalid_issuer code with pytest.raises(ApiError) as exc_info: @@ -3328,6 +3335,49 @@ async def test_normalize_domain_handles_different_schemes(): assert client._normalize_domain("https://auth0.com/") == "https://auth0.com/" +@pytest.mark.asyncio +async def test_normalize_issuer_handles_edge_cases(): + """Test that _normalize_issuer handles edge cases for robust issuer comparison. + + This test documents the edge cases that could cause issuer validation failures + with PyJWT's strict string comparison: + - Trailing slash differences + - Case sensitivity + - HTTP vs HTTPS schemes + - Missing scheme + """ + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Test trailing slash normalization + assert client._normalize_issuer("https://auth0.com/") == "https://auth0.com" + assert client._normalize_issuer("https://auth0.com") == "https://auth0.com" + assert client._normalize_issuer("https://auth0.com/") == client._normalize_issuer("https://auth0.com") + + # Test case insensitivity + assert client._normalize_issuer("HTTPS://AUTH0.COM/") == "https://auth0.com" + assert client._normalize_issuer("Https://Auth0.Com") == "https://auth0.com" + assert client._normalize_issuer("HTTPS://AUTH0.COM/") == client._normalize_issuer("https://auth0.com") + + # Test HTTP to HTTPS conversion + assert client._normalize_issuer("http://auth0.com") == "https://auth0.com" + assert client._normalize_issuer("HTTP://AUTH0.COM/") == "https://auth0.com" + + # Test missing scheme + assert client._normalize_issuer("auth0.com") == "https://auth0.com" + assert client._normalize_issuer("AUTH0.COM/") == "https://auth0.com" + + # Test empty/None handling + assert client._normalize_issuer("") == "" + assert client._normalize_issuer(None) is None + + # ============================================================================= # MCD Tests : Multiple Issuer Configuration Methods Tests # ============================================================================= From 74329ab92b4d23a27a6601b7124f5d47b3b70453 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Mon, 9 Feb 2026 14:14:33 +0530 Subject: [PATCH 08/16] refactor: remove unused jwt import from test_server_client.py --- src/auth0_server_python/tests/test_server_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index b41cb6a..dd29d34 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -3,7 +3,6 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch from urllib.parse import parse_qs, urlparse -import jwt import pytest from auth0_server_python.auth_server.my_account_client import MyAccountClient from auth0_server_python.auth_server.server_client import ServerClient From 073c38beaead5f0bbf5b5f547644c72510858b40 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Wed, 11 Feb 2026 19:29:52 +0530 Subject: [PATCH 09/16] refactor: remove requirement comments for OIDC metadata and JWKS caching --- src/auth0_server_python/auth_server/server_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 627549e..0dc64a9 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -151,7 +151,7 @@ def __init__( self._my_account_client = MyAccountClient(domain=domain) - # Cache for OIDC metadata and JWKS (Requirement 3: MCD Support) + # Cache for OIDC metadata and JWKS self._metadata_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} self._jwks_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} self._cache_ttl = 3600 # 1 hour TTL @@ -531,7 +531,7 @@ async def complete_interactive_login( if user_info: user_claims = UserClaims.parse_obj(user_info) elif id_token: - # Fetch JWKS for signature verification (Requirement 3) + # Fetch JWKS for signature verification jwks = await self._get_jwks_cached(origin_domain, metadata) # Decode and verify ID token with signature verification enabled @@ -771,7 +771,7 @@ async def handle_backchannel_logout( raise BackchannelLogoutError("Missing logout token") try: - # Fetch JWKS for signature verification (Requirement 3) + # Fetch JWKS for signature verification jwks = await self._get_jwks_cached(self._domain) # Decode and verify logout token with signature verification enabled From 75a8531e9a79db6118025b719fa63ba2b63d38a6 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Sat, 14 Feb 2026 01:26:35 +0530 Subject: [PATCH 10/16] feat: improve caching mechanism and add more examples to the example doc --- .gitignore | 6 +- README.md | 28 ++++ examples/MultipleCustomDomains.md | 151 ++++++++++++++++-- .../auth_server/server_client.py | 97 +++++++---- .../tests/test_server_client.py | 22 +-- 5 files changed, 246 insertions(+), 58 deletions(-) diff --git a/.gitignore b/.gitignore index 9814158..c30eceb 100644 --- a/.gitignore +++ b/.gitignore @@ -23,8 +23,4 @@ setup.py test.py test-script.py .coverage -coverage.xml -examples/mcd-poc -IMPLEMENTATION_NOTES.md -examples/MCD_DEVELOPER_GUIDE.md -DESIGN_DOC_REVISED.md \ No newline at end of file +coverage.xml \ No newline at end of file diff --git a/README.md b/README.md index f81b537..ff5e46c 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,34 @@ print(response.access_token) For more details and examples, see [examples/CustomTokenExchange.md](examples/CustomTokenExchange.md). +### 5. Multiple Custom Domains (MCD) + +For applications that use multiple custom domains on the same Auth0 tenant, pass a domain resolver function instead of a static domain string: + +```python +from auth0_server_python.auth_server.server_client import ServerClient +from auth0_server_python.auth_types import DomainResolverContext + +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get('host', '').split(':')[0] + domain_map = { + "acme.yourapp.com": "acme.auth0.com", + "globex.yourapp.com": "globex.auth0.com", + } + return domain_map.get(host, "default.auth0.com") + +auth0 = ServerClient( + domain=domain_resolver, # Callable enables MCD mode + client_id='', + client_secret='', + secret='', +) +``` + +The SDK handles per-domain OIDC discovery, JWKS fetching, issuer validation, and session isolation automatically. Static string domains continue to work unchanged. + +For more details and examples, see [examples/MultipleCustomDomains.md](examples/MultipleCustomDomains.md). + ## Feedback ### Contributing diff --git a/examples/MultipleCustomDomains.md b/examples/MultipleCustomDomains.md index 85d188d..a33e896 100644 --- a/examples/MultipleCustomDomains.md +++ b/examples/MultipleCustomDomains.md @@ -4,17 +4,15 @@ This guide explains how to implement Multiple Custom Domain (MCD) support using ## What is MCD? -Multiple Custom Domains (MCD) allows your application to serve different organizations or tenants from different hostnames, each mapping to a different Auth0 tenant/domain. +Multiple Custom Domains (MCD) allows your application to use multiple custom domains configured on the same Auth0 tenant, each serving a different branded experience from a single application codebase. **Example:** -- `https://acme.yourapp.com` → Auth0 tenant: `acme.auth0.com` -- `https://globex.yourapp.com` → Auth0 tenant: `globex.auth0.com` - -Each tenant gets its own branded login experience while using a single application codebase. +- `https://acme.yourapp.com` → Custom domain: `auth.acme.com` +- `https://globex.yourapp.com` → Custom domain: `auth.globex.com` ## Configuration Methods -### Method 1: Static Domain (Single Tenant) +### Method 1: Static Domain (Single Domain) For applications with a single Auth0 domain: @@ -104,6 +102,141 @@ async def domain_resolver(context: DomainResolverContext) -> str: return DOMAIN_MAP.get(hostname, DEFAULT_DOMAIN) ``` +## Resolver Patterns + +### Database Lookup (SQLAlchemy) + +Resolve domains from a database using async SQLAlchemy: + +```python +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy import text + +engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/mydb") + +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + tenant = host.split(".")[0] + + async with AsyncSession(engine) as session: + result = await session.execute( + text("SELECT auth0_domain FROM tenants WHERE slug = :slug"), + {"slug": tenant} + ) + row = result.fetchone() + if row: + return row[0] + + return DEFAULT_DOMAIN +``` + +### Database Lookup with In-Memory Cache + +Avoid hitting the database on every request by caching the tenant map: + +```python +import time + +_tenant_cache = {} +_cache_ttl = 300 # 5 minutes + +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + tenant = host.split(".")[0] + + now = time.time() + cached = _tenant_cache.get(tenant) + if cached and cached["expires_at"] > now: + return cached["domain"] + + # Cache miss - fetch from database + async with AsyncSession(engine) as session: + result = await session.execute( + text("SELECT auth0_domain FROM tenants WHERE slug = :slug"), + {"slug": tenant} + ) + row = result.fetchone() + domain = row[0] if row else DEFAULT_DOMAIN + + _tenant_cache[tenant] = {"domain": domain, "expires_at": now + _cache_ttl} + return domain +``` + +### Redis Lookup + +Use Redis for shared tenant configuration across multiple app instances: + +```python +import redis.asyncio as redis + +redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True) + +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + tenant = host.split(".")[0] + + # Key format: "tenant:acme" -> "acme.auth0.com" + domain = await redis_client.get(f"tenant:{tenant}") + if domain: + return domain + + return DEFAULT_DOMAIN +``` + +### Redis with Hash Map + +Store all tenant mappings in a single Redis hash: + +```python +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + tenant = host.split(".")[0] + + # All tenants in one hash: HGET tenant_domains acme -> "acme.auth0.com" + domain = await redis_client.hget("tenant_domains", tenant) + if domain: + return domain + + return DEFAULT_DOMAIN +``` + +### Path-Based Resolution + +Resolve tenant from URL path instead of hostname: + +```python +from urllib.parse import urlparse + +async def domain_resolver(context: DomainResolverContext) -> str: + if context.request_url: + path = urlparse(context.request_url).path + # URL pattern: /tenant/acme/auth/login + parts = path.strip("/").split("/") + if len(parts) >= 2 and parts[0] == "tenant": + tenant = parts[1] + return DOMAIN_MAP.get(tenant, DEFAULT_DOMAIN) + + return DEFAULT_DOMAIN +``` + +### Custom Header Resolution + +Use a custom header set by your API gateway or load balancer: + +```python +async def domain_resolver(context: DomainResolverContext) -> str: + headers = context.request_headers or {} + + # API gateway sets X-Tenant-Id header + tenant_id = headers.get("x-tenant-id") + if tenant_id: + return DOMAIN_MAP.get(tenant_id, DEFAULT_DOMAIN) + + # Fallback to host header + host = headers.get("host", "").split(":")[0] + return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) +``` + ## Error Handling ### DomainResolverError @@ -116,13 +249,13 @@ from auth0_server_python.error import DomainResolverError async def domain_resolver(context: DomainResolverContext) -> str: try: domain = lookup_domain_from_db(context) - + if not domain: # Return default instead of None return DEFAULT_DOMAIN - + return domain # Must be a non-empty string - + except Exception as e: # Log error and return default logger.error(f"Domain resolution failed: {e}") diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 0dc64a9..32be1da 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -6,6 +6,7 @@ import asyncio import json import time +from collections import OrderedDict from typing import Any, Callable, Generic, Optional, TypeVar, Union from urllib.parse import parse_qs, urlencode, urlparse, urlunparse @@ -151,11 +152,10 @@ def __init__( self._my_account_client = MyAccountClient(domain=domain) - # Cache for OIDC metadata and JWKS - self._metadata_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} - self._jwks_cache = {} # {domain: {"data": {...}, "expires_at": timestamp}} - self._cache_ttl = 3600 # 1 hour TTL - self._cache_max_entries = 100 # Max 100 domains to prevent memory bloat + # Unified cache for OIDC metadata and JWKS per domain (LRU eviction + TTL) + self._discovery_cache: OrderedDict[str, dict] = OrderedDict() + self._cache_ttl = 600 # 10 mins. TTL + self._cache_max_entries = 100 # Max 100 domains def _normalize_domain(self, domain: str) -> str: """ @@ -203,9 +203,24 @@ async def _fetch_oidc_metadata(self, domain: str) -> dict: response.raise_for_status() return response.json() + def _purge_expired_cache_entries(self): + """Purge all expired entries from the discovery cache.""" + now = time.time() + expired = [k for k, v in self._discovery_cache.items() if v["expires_at"] <= now] + for k in expired: + del self._discovery_cache[k] + + def _ensure_cache_capacity(self): + """Evict least recently used entry if cache is at capacity.""" + if len(self._discovery_cache) >= self._cache_max_entries: + self._discovery_cache.popitem(last=False) + async def _get_oidc_metadata_cached(self, domain: str) -> dict: """ - Get OIDC metadata with caching. + Get OIDC metadata with caching (LRU eviction + TTL). + + Uses a unified cache shared with JWKS when metadata expires, + the corresponding JWKS is also invalidated. Args: domain: Auth0 domain @@ -216,22 +231,25 @@ async def _get_oidc_metadata_cached(self, domain: str) -> dict: now = time.time() # Check cache - if domain in self._metadata_cache: - cached = self._metadata_cache[domain] + if domain in self._discovery_cache: + cached = self._discovery_cache[domain] if cached["expires_at"] > now: - return cached["data"] + self._discovery_cache.move_to_end(domain) + return cached["metadata"] + # Expired — remove entire entry (metadata + jwks) + del self._discovery_cache[domain] # Cache miss/expired - fetch fresh metadata = await self._fetch_oidc_metadata(domain) - # Enforce cache size limit (FIFO eviction) - if len(self._metadata_cache) >= self._cache_max_entries: - oldest_key = next(iter(self._metadata_cache)) - del self._metadata_cache[oldest_key] + # Purge expired entries and ensure capacity + self._purge_expired_cache_entries() + self._ensure_cache_capacity() - # Store in cache - self._metadata_cache[domain] = { - "data": metadata, + # Store in cache with jwks=None (lazily populated when needed) + self._discovery_cache[domain] = { + "metadata": metadata, + "jwks": None, "expires_at": now + self._cache_ttl } @@ -260,7 +278,10 @@ async def _fetch_jwks(self, jwks_uri: str) -> dict: async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: """ - Get JWKS with caching usingOIDC discovery. + Get JWKS with caching using OIDC discovery (LRU eviction + TTL). + + Uses a unified cache shared with metadata — JWKS is lazily populated + on first access and invalidated when the cache entry expires. Args: domain: Auth0 domain @@ -274,13 +295,20 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: """ now = time.time() - # Check cache - if domain in self._jwks_cache: - cached = self._jwks_cache[domain] + # Check cache — entry exists, not expired, and jwks already fetched + if domain in self._discovery_cache: + cached = self._discovery_cache[domain] if cached["expires_at"] > now: - return cached["data"] - - # Get jwks_uri from OIDC metadata + if cached["jwks"] is not None: + self._discovery_cache.move_to_end(domain) + return cached["jwks"] + # Entry valid but jwks not yet fetched — use cached metadata + metadata = cached["metadata"] + else: + # Expired — remove entire entry + del self._discovery_cache[domain] + + # Get metadata if not available from cache or parameter if not metadata: metadata = await self._get_oidc_metadata_cached(domain) @@ -294,16 +322,19 @@ async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: # Fetch JWKS jwks = await self._fetch_jwks(jwks_uri) - # Enforce cache size limit (FIFO eviction) - if len(self._jwks_cache) >= self._cache_max_entries: - oldest_key = next(iter(self._jwks_cache)) - del self._jwks_cache[oldest_key] - - # Store in cache - self._jwks_cache[domain] = { - "data": jwks, - "expires_at": now + self._cache_ttl - } + # Update existing cache entry with jwks (entry created by _get_oidc_metadata_cached) + if domain in self._discovery_cache: + self._discovery_cache[domain]["jwks"] = jwks + self._discovery_cache.move_to_end(domain) + else: + # Edge case: entry was evicted between metadata and jwks fetch + self._purge_expired_cache_entries() + self._ensure_cache_capacity() + self._discovery_cache[domain] = { + "metadata": metadata, + "jwks": jwks, + "expires_at": now + self._cache_ttl + } return jwks diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index dd29d34..f63c232 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -2976,8 +2976,8 @@ async def mock_fetch(domain): assert fetch_count == first_fetch_count # Should NOT increment # Verify cache contains data - assert "tenant.auth0.com" in client._metadata_cache - assert client._metadata_cache["tenant.auth0.com"]["data"] == mock_metadata + assert "tenant.auth0.com" in client._discovery_cache + assert client._discovery_cache["tenant.auth0.com"]["metadata"] == mock_metadata @pytest.mark.asyncio @@ -3096,15 +3096,15 @@ async def mock_fetch_jwks(uri): await client._get_jwks_cached("domain2.auth0.com") await client._get_jwks_cached("domain3.auth0.com") - assert len(client._jwks_cache) == 3 - assert "domain1.auth0.com" in client._jwks_cache + assert len(client._discovery_cache) == 3 + assert "domain1.auth0.com" in client._discovery_cache # Add one more - should evict oldest (domain1) await client._get_jwks_cached("domain4.auth0.com") - assert len(client._jwks_cache) == 3 - assert "domain1.auth0.com" not in client._jwks_cache # Evicted - assert "domain4.auth0.com" in client._jwks_cache + assert len(client._discovery_cache) == 3 + assert "domain1.auth0.com" not in client._discovery_cache # Evicted + assert "domain4.auth0.com" in client._discovery_cache @pytest.mark.asyncio @@ -3159,14 +3159,14 @@ async def mock_fetch(domain): await client._get_oidc_metadata_cached("domain1.auth0.com") await client._get_oidc_metadata_cached("domain2.auth0.com") - assert len(client._metadata_cache) == 2 + assert len(client._discovery_cache) == 2 # Add third - should evict first await client._get_oidc_metadata_cached("domain3.auth0.com") - assert len(client._metadata_cache) == 2 - assert "domain1.auth0.com" not in client._metadata_cache - assert "domain3.auth0.com" in client._metadata_cache + assert len(client._discovery_cache) == 2 + assert "domain1.auth0.com" not in client._discovery_cache + assert "domain3.auth0.com" in client._discovery_cache # ============================================================================= From 0eb64b44ed879e8db84636b54d4198261a3c6beb Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Wed, 25 Feb 2026 22:37:13 +0530 Subject: [PATCH 11/16] feat: enhance Multiple Custom Domains (MCD) support with dynamic domain resolution and improved session handling --- examples/MultipleCustomDomains.md | 78 +++++- .../auth_server/server_client.py | 235 ++++++++---------- .../tests/test_server_client.py | 44 ++-- 3 files changed, 206 insertions(+), 151 deletions(-) diff --git a/examples/MultipleCustomDomains.md b/examples/MultipleCustomDomains.md index a33e896..5afafda 100644 --- a/examples/MultipleCustomDomains.md +++ b/examples/MultipleCustomDomains.md @@ -1,15 +1,13 @@ -# Multiple Custom Domains (MCD) Guide +# Multiple Custom Domains (MCD) -This guide explains how to implement Multiple Custom Domain (MCD) support using the Auth0 Python SDKs. - -## What is MCD? - -Multiple Custom Domains (MCD) allows your application to use multiple custom domains configured on the same Auth0 tenant, each serving a different branded experience from a single application codebase. +MCD lets you resolve the Auth0 domain per request while keeping a single `ServerClient` instance. This is useful when your application uses multiple custom domains configured on the same Auth0 tenant. **Example:** - `https://acme.yourapp.com` → Custom domain: `auth.acme.com` - `https://globex.yourapp.com` → Custom domain: `auth.globex.com` +MCD is enabled by providing a **domain resolver function** instead of a static domain string. + ## Configuration Methods ### Method 1: Static Domain (Single Domain) @@ -102,6 +100,8 @@ async def domain_resolver(context: DomainResolverContext) -> str: return DOMAIN_MAP.get(hostname, DEFAULT_DOMAIN) ``` +> **Note:** In resolver mode, the SDK builds the `redirect_uri` dynamically from the resolved domain. You do not need to set it per request. If you override `redirect_uri` in `authorization_params`, the SDK uses your value as-is. + ## Resolver Patterns ### Database Lookup (SQLAlchemy) @@ -269,4 +269,68 @@ async def domain_resolver(context: DomainResolverContext) -> str: **Exceptions raised by your resolver:** - Automatically wrapped in `DomainResolverError` -- Original exception accessible via `.original_error` \ No newline at end of file +- Original exception accessible via `.original_error` + +## Session Behavior in Resolver Mode + +In resolver mode, sessions are bound to the domain that created them. On each request, the SDK compares the session's stored domain against the current resolved domain: + +- `get_user()` and `get_session()` return `None` on domain mismatch. +- `get_access_token()` raises `AccessTokenError` on domain mismatch. +- Token refresh uses the session's stored domain, not the current request domain. + +> **Warning:** If you switch from a static domain string to a resolver function, existing sessions that do not include a stored domain continue to work — the SDK treats the absent domain field as valid. New sessions will store the resolved domain automatically. Once old sessions expire, all sessions will be domain-aware. + +## Discovery Cache + +The SDK caches OIDC metadata and JWKS per domain in memory (LRU eviction, 600-second TTL, up to 100 domains). This avoids repeated network calls when serving multiple domains. The cache is shared across all requests to the same `ServerClient` instance. + +## Security Best Practices + +### Use an Allowlist in Your Resolver + +The SDK passes request headers to your domain resolver via `DomainResolverContext`. These headers come directly from the HTTP request and can be spoofed by an attacker (e.g., `Host: evil.com` or `X-Forwarded-Host: evil.com`). + +The SDK uses the resolved domain to fetch OIDC metadata and JWKS. If an attacker can influence the resolved domain, they could point the SDK at an OIDC provider they control. + +**Always use a mapping or allowlist — never construct domains from raw header values:** + +```python +# Safe: allowlist lookup — unknown hosts fall back to default +DOMAIN_MAP = { + "acme.myapp.com": "auth.acme.com", + "globex.myapp.com": "auth.globex.com", +} + +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) +``` + +```python +# Risky: constructs domain from raw input — attacker can influence resolved domain +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + tenant = host.split(".")[0] + return f"{tenant}.auth0.com" # attacker sends Host: evil.myapp.com → evil.auth0.com +``` + +### Trust Forwarded Headers Only Behind a Proxy + +If your application is directly exposed to the internet (not behind a reverse proxy), do not trust `x-forwarded-host` or `x-forwarded-proto` — any client can set these headers. + +Only use forwarded headers when your application runs behind a trusted reverse proxy (nginx, AWS ALB, Cloudflare, etc.) that sets these headers and strips any client-provided values. + +```python +# Only trust x-forwarded-host if behind a trusted proxy +async def domain_resolver(context: DomainResolverContext) -> str: + headers = context.request_headers or {} + + if BEHIND_TRUSTED_PROXY: + host = headers.get("x-forwarded-host") or headers.get("host", "") + else: + host = headers.get("host", "") + + host = host.split(":")[0] + return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) +``` \ No newline at end of file diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 32be1da..da1a302 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -60,6 +60,8 @@ # Generic type for store options TStoreOptions = TypeVar('TStoreOptions') +# redirect_uri is intentionally excluded — in MCD mode it is built +# dynamically from the resolved domain at login time. INTERNAL_AUTHORIZE_PARAMS = ["client_id", "response_type", "code_challenge", "code_challenge_method", "state", "nonce", "scope"] @@ -162,6 +164,7 @@ def _normalize_domain(self, domain: str) -> str: Normalize domain for comparison and URL construction. Handles cases with/without https:// scheme. """ + domain = domain.lower() if domain.startswith('https://'): return domain elif domain.startswith('http://'): @@ -194,6 +197,61 @@ def _normalize_issuer(self, issuer: str) -> str: # Remove trailing slash return issuer.rstrip('/') + async def _resolve_current_domain(self, store_options=None) -> str: + """Resolve domain from resolver function or return static domain.""" + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + return validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + return self._domain + + async def _verify_and_decode_jwt( + self, token: str, jwks: dict, audience: str | None = None + ) -> dict: + """ + Find signing key in JWKS by kid and decode JWT. + + Verifies signature but disables built-in issuer validation + (callers perform normalized issuer comparison separately). + + Args: + token: The JWT to verify and decode + jwks: The JWKS dict containing signing keys + audience: Optional expected audience claim + + Returns: + Decoded claims dictionary + + Raises: + ValueError: If no matching key found in JWKS for the token's kid + """ + unverified_header = jwt.get_unverified_header(token) + kid = unverified_header.get("kid") + + signing_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + signing_key = jwt.PyJWK.from_dict(key) + break + + if not signing_key: + raise ValueError(f"No matching key found in JWKS for kid: {kid}") + + decode_options = {"verify_signature": True, "verify_iss": False} + kwargs = {"algorithms": ["RS256"], "options": decode_options} + if audience: + kwargs["audience"] = audience + + return jwt.decode(token, signing_key.key, **kwargs) + async def _fetch_oidc_metadata(self, domain: str) -> dict: """Fetch OIDC metadata from domain.""" normalized_domain = self._normalize_domain(domain) @@ -362,21 +420,7 @@ async def start_interactive_login( options = options or StartInteractiveLoginOptions() # Resolve domain (static or dynamic) - if self._domain_resolver: - # Build context and call developer's resolver - context = build_domain_resolver_context(store_options) - try: - resolved = await self._domain_resolver(context) - origin_domain = validate_resolved_domain_value(resolved) - except DomainResolverError: - raise - except Exception as e: - raise DomainResolverError( - f"Domain resolver function raised an exception: {str(e)}", - original_error=e - ) - else: - origin_domain = self._domain + origin_domain = await self._resolve_current_domain(store_options) # Fetch OIDC metadata from resolved domain try: @@ -552,7 +596,6 @@ async def complete_interactive_login( # Raise a custom error (or handle it as appropriate) raise ApiError( "token_error", f"Token exchange failed: {str(e)}", e) - print(f"Token Response : {token_response}") # Use the userinfo field from the token_response for user claims user_info = token_response.get("userinfo") @@ -567,30 +610,8 @@ async def complete_interactive_login( # Decode and verify ID token with signature verification enabled try: - # Get the signing key from JWKS - unverified_header = jwt.get_unverified_header(id_token) - kid = unverified_header.get("kid") - - # Find the key with matching kid - signing_key = None - for key in jwks.get("keys", []): - if key.get("kid") == kid: - signing_key = jwt.PyJWK.from_dict(key) - break - - if not signing_key: - raise ApiError( - "jwks_key_not_found", - f"No matching key found in JWKS for kid: {kid}" - ) - - # Decode with signature - claims = jwt.decode( - id_token, - signing_key.key, - algorithms=["RS256"], - audience=self._client_id, - options={"verify_signature": True, "verify_iss": False} + claims = await self._verify_and_decode_jwt( + id_token, jwks, audience=self._client_id ) # Custom normalized issuer validation @@ -603,6 +624,8 @@ async def complete_interactive_login( ) user_claims = UserClaims.parse_obj(claims) + except ValueError as e: + raise ApiError("jwks_key_not_found", str(e)) except jwt.InvalidSignatureError as e: raise ApiError( "invalid_signature", @@ -693,17 +716,7 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti if state_data: # Validate session domain matches current request domain if self._domain_resolver: - context = build_domain_resolver_context(store_options) - try: - resolved = await self._domain_resolver(context) - current_domain = validate_resolved_domain_value(resolved) - except DomainResolverError: - raise - except Exception as e: - raise DomainResolverError( - f"Domain resolver function raised an exception: {str(e)}", - original_error=e - ) + current_domain = await self._resolve_current_domain(store_options) session_domain = getattr(state_data, 'domain', None) if session_domain and session_domain != current_domain: @@ -730,17 +743,7 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O if state_data: # Validate session domain matches current request domain if self._domain_resolver: - context = build_domain_resolver_context(store_options) - try: - resolved = await self._domain_resolver(context) - current_domain = validate_resolved_domain_value(resolved) - except DomainResolverError: - raise - except Exception as e: - raise DomainResolverError( - f"Domain resolver function raised an exception: {str(e)}", - original_error=e - ) + current_domain = await self._resolve_current_domain(store_options) session_domain = getattr(state_data, 'domain', None) if session_domain and session_domain != current_domain: @@ -765,20 +768,7 @@ async def logout( await self._state_store.delete(self._state_identifier, store_options) # Resolve domain dynamically for MCD support - if self._domain_resolver: - context = build_domain_resolver_context(store_options) - try: - resolved = await self._domain_resolver(context) - domain = validate_resolved_domain_value(resolved) - except DomainResolverError: - raise - except Exception as e: - raise DomainResolverError( - f"Domain resolver function raised an exception: {str(e)}", - original_error=e - ) - else: - domain = self._domain + domain = await self._resolve_current_domain(store_options) # Use the URL helper to create the logout URL. logout_url = URL.create_logout_url( @@ -802,42 +792,45 @@ async def handle_backchannel_logout( raise BackchannelLogoutError("Missing logout token") try: - # Fetch JWKS for signature verification - jwks = await self._get_jwks_cached(self._domain) - - # Decode and verify logout token with signature verification enabled - try: - # Get the signing key from JWKS - unverified_header = jwt.get_unverified_header(logout_token) - kid = unverified_header.get("kid") - - # Find the key with matching kid - signing_key = None - for key in jwks.get("keys", []): - if key.get("kid") == kid: - signing_key = jwt.PyJWK.from_dict(key) - break - - if not signing_key: + # Determine domain for JWKS fetch and issuer validation + if self._domain_resolver: + try: + unverified = jwt.decode( + logout_token, algorithms=["RS256"], + options={"verify_signature": False} + ) + token_issuer = unverified.get("iss", "") + parsed = urlparse(token_issuer) + domain = parsed.hostname + if not domain: + raise BackchannelLogoutError( + "Cannot determine domain: logout token has no valid issuer" + ) + except BackchannelLogoutError: + raise + except Exception as e: raise BackchannelLogoutError( - f"No matching key found in JWKS for kid: {kid}" + f"Failed to extract domain from logout token: {str(e)}" ) + else: + domain = self._domain - claims = jwt.decode( - logout_token, - signing_key.key, - algorithms=["RS256"], - options={"verify_signature": True, "verify_iss": False} - ) + # Fetch JWKS and verify logout token + jwks = await self._get_jwks_cached(domain) + + try: + claims = await self._verify_and_decode_jwt(logout_token, jwks) # Normalized issuer validation token_issuer = claims.get("iss", "") - expected_issuer = self._normalize_domain(self._domain) + expected_issuer = self._normalize_domain(domain) if self._normalize_issuer(token_issuer) != self._normalize_issuer(expected_issuer): raise BackchannelLogoutError( f"Logout token issuer mismatch. Token issuer: {token_issuer}, Expected: {expected_issuer}. " f"Ensure your Auth0 domain is configured correctly." ) + except ValueError as e: + raise BackchannelLogoutError(str(e)) except jwt.InvalidSignatureError as e: raise BackchannelLogoutError( f"Logout token signature verification failed: {str(e)}" @@ -894,17 +887,7 @@ async def get_access_token( # Validate session domain matches current request domain if state_data and self._domain_resolver: - context = build_domain_resolver_context(store_options) - try: - resolved = await self._domain_resolver(context) - current_domain = validate_resolved_domain_value(resolved) - except DomainResolverError: - raise - except Exception as e: - raise DomainResolverError( - f"Domain resolver function raised an exception: {str(e)}", - original_error=e - ) + current_domain = await self._resolve_current_domain(store_options) session_domain = getattr(state_data, 'domain', None) if session_domain and session_domain != current_domain: @@ -999,11 +982,10 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, # Use session domain if provided, otherwise fallback to static domain domain = options.get("domain") or self._domain - # Ensure we have the OIDC metadata from the correct domain - if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._get_oidc_metadata_cached(domain) + # Fetch OIDC metadata from the correct domain + metadata = await self._get_oidc_metadata_cached(domain) - token_endpoint = self._oauth.metadata.get("token_endpoint") + token_endpoint = metadata.get("token_endpoint") if not token_endpoint: raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") @@ -1374,11 +1356,10 @@ async def backchannel_authentication_grant(self, auth_req_id: str) -> dict[str, raise MissingRequiredArgumentError("auth_req_id") try: - # Ensure we have the OIDC metadata - if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._get_oidc_metadata_cached(self._domain) + # Fetch OIDC metadata + metadata = await self._get_oidc_metadata_cached(self._domain) - token_endpoint = self._oauth.metadata.get("token_endpoint") + token_endpoint = metadata.get("token_endpoint") if not token_endpoint: raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") @@ -1766,11 +1747,10 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A # Use session domain if provided, otherwise fallback to static domain domain = options.get("domain") or self._domain - # Ensure we have OIDC metadata from the correct domain - if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._get_oidc_metadata_cached(domain) + # Fetch OIDC metadata from the correct domain + metadata = await self._get_oidc_metadata_cached(domain) - token_endpoint = self._oauth.metadata.get("token_endpoint") + token_endpoint = metadata.get("token_endpoint") if not token_endpoint: raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") @@ -2103,11 +2083,10 @@ async def custom_token_exchange( if not isinstance(options, CustomTokenExchangeOptions): options = CustomTokenExchangeOptions(**options) - # Ensure we have OIDC metadata - if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + # Fetch OIDC metadata + metadata = await self._get_oidc_metadata_cached(self._domain) - token_endpoint = self._oauth.metadata.get("token_endpoint") + token_endpoint = metadata.get("token_endpoint") if not token_endpoint: raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index f63c232..6ee89d3 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -1416,7 +1416,11 @@ async def test_backchannel_authentication_grant_success(mocker): secret="some-secret" ) # Mock OIDC metadata - client._oauth.metadata = {"token_endpoint": "https://auth0.local/token"} + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} + ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() @@ -1450,7 +1454,11 @@ async def test_backchannel_authentication_grant_error_response(mocker): client_secret="client_secret", secret="some-secret" ) - client._oauth.metadata = {"token_endpoint": "https://auth0.local/token"} + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} + ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() @@ -1477,7 +1485,11 @@ async def test_backchannel_authentication_grant_json_decode_error(mocker): client_secret="client_secret", secret="some-secret" ) - client._oauth.metadata = {"token_endpoint": "https://auth0.local/token"} + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} + ) # Mock httpx.AsyncClient.post to return a response whose .json() raises JSONDecodeError mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) @@ -1502,9 +1514,9 @@ async def test_get_token_for_connection_success(mocker): ) mocker.patch.object( - client._oauth, - "metadata", - {"token_endpoint": "https://auth0.local/token"} + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) @@ -1549,9 +1561,9 @@ async def test_get_token_for_connection_exchange_failed(mocker): ) mocker.patch.object( - client._oauth, - "metadata", - {"token_endpoint": "https://auth0.local/token"} + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) @@ -1587,9 +1599,9 @@ async def test_get_token_by_refresh_token_success(mocker): ) mocker.patch.object( - client._oauth, - "metadata", - {"token_endpoint": "https://auth0.local/token"} + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) @@ -1631,9 +1643,9 @@ async def test_get_token_by_refresh_token_exchange_failed(mocker): ) mocker.patch.object( - client._oauth, - "metadata", - {"token_endpoint": "https://auth0.local/token"} + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) @@ -3066,7 +3078,7 @@ async def mock_fetch_jwks(uri): @pytest.mark.asyncio async def test_jwks_cache_size_limit(): - """Test JWKS cache enforces max size limit with FIFO eviction.""" + """Test JWKS cache enforces max size limit with LRU eviction.""" client = ServerClient( domain="tenant.auth0.com", client_id="test_client", From f8d847005584e8dd29dd23e60237ecec43f8cd00 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Wed, 25 Feb 2026 22:44:01 +0530 Subject: [PATCH 12/16] fix: update type hint for audience parameter in _verify_and_decode_jwt method --- src/auth0_server_python/auth_server/server_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index da1a302..e31a095 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -214,7 +214,7 @@ async def _resolve_current_domain(self, store_options=None) -> str: return self._domain async def _verify_and_decode_jwt( - self, token: str, jwks: dict, audience: str | None = None + self, token: str, jwks: dict, audience: Optional[str] = None ) -> dict: """ Find signing key in JWKS by kid and decode JWT. From 283377f92f8231f1afb3907ad2c5fca0408c30f6 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Thu, 26 Feb 2026 20:59:45 +0530 Subject: [PATCH 13/16] feat: add DOMAIN_MISMATCH error code for session domain validation in MCD support --- .../auth_server/server_client.py | 30 +++++++++++++++++-- src/auth0_server_python/error/__init__.py | 2 ++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index e31a095..2e78990 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -893,7 +893,7 @@ async def get_access_token( if session_domain and session_domain != current_domain: # Session created with different domain - reject for security raise AccessTokenError( - AccessTokenErrorCode.MISSING_REFRESH_TOKEN, + AccessTokenErrorCode.DOMAIN_MISMATCH, "Session domain mismatch. User needs to re-authenticate with the current domain." ) @@ -1443,6 +1443,11 @@ async def start_link_user( "Unable to start the user linking process without a logged in user. Ensure to login using the SDK before starting the user linking process." ) + # Resolve domain for MCD + origin_domain = await self._resolve_current_domain(store_options) + metadata = await self._get_oidc_metadata_cached(origin_domain) + origin_issuer = metadata.get('issuer') + # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() state = PKCE.generate_random_string(32) @@ -1460,7 +1465,9 @@ async def start_link_user( # Store transaction data transaction_data = TransactionData( code_verifier=code_verifier, - app_state=options.get("app_state") + app_state=options.get("app_state"), + origin_domain=origin_domain, + origin_issuer=origin_issuer, ) await self._transaction_store.set( @@ -1517,6 +1524,11 @@ async def start_unlink_user( "Unable to start the user linking process without a logged in user. Ensure to login using the SDK before starting the user linking process." ) + # Resolve domain for MCD + origin_domain = await self._resolve_current_domain(store_options) + metadata = await self._get_oidc_metadata_cached(origin_domain) + origin_issuer = metadata.get('issuer') + # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() state = PKCE.generate_random_string(32) @@ -1533,7 +1545,9 @@ async def start_unlink_user( # Store transaction data transaction_data = TransactionData( code_verifier=code_verifier, - app_state=options.get("app_state") + app_state=options.get("app_state"), + origin_domain=origin_domain, + origin_issuer=origin_issuer, ) await self._transaction_store.set( @@ -1687,6 +1701,16 @@ async def get_access_token_for_connection( else: state_data_dict = state_data or {} + # In MCD mode, verify session domain matches current domain + if self._domain_resolver: + current_domain = await self._resolve_current_domain(store_options) + session_domain = state_data_dict.get("domain") + if session_domain and session_domain != current_domain: + raise AccessTokenForConnectionError( + AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH, + "Session domain mismatch. User needs to re-authenticate with the current domain." + ) + # Find existing connection token connection_token_set = None if state_data_dict and len(state_data_dict["connection_token_sets"]) > 0: diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index 59f1c66..8af2c63 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -183,6 +183,7 @@ class AccessTokenErrorCode: REFRESH_TOKEN_ERROR = "refresh_token_error" AUTH_REQ_ID_ERROR = "auth_req_id_error" INCORRECT_AUDIENCE = "incorrect_audience" + DOMAIN_MISMATCH = "domain_mismatch" class AccessTokenForConnectionErrorCode: @@ -191,6 +192,7 @@ class AccessTokenForConnectionErrorCode: FAILED_TO_RETRIEVE = "failed_to_retrieve" API_ERROR = "api_error" FETCH_ERROR = "retrieval_error" + DOMAIN_MISMATCH = "domain_mismatch" class CustomTokenExchangeError(Auth0Error): From e40942dd3423cf22f03ba791121d83714f6873ae Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Fri, 27 Feb 2026 17:42:37 +0530 Subject: [PATCH 14/16] feat: implement domain validation for backchannel logout and enhance session handling in MCD mode --- .../auth_server/server_client.py | 5 + .../tests/test_server_client.py | 347 ++++++++++++++++++ 2 files changed, 352 insertions(+) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 2e78990..906d22d 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -806,6 +806,11 @@ async def handle_backchannel_logout( raise BackchannelLogoutError( "Cannot determine domain: logout token has no valid issuer" ) + if domain not in self._discovery_cache: + raise BackchannelLogoutError( + f"Unknown domain in logout token issuer: {domain}. " + f"Only domains from active sessions are accepted." + ) except BackchannelLogoutError: raise except Exception as e: diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 6ee89d3..67bc0bd 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -24,7 +24,10 @@ TransactionData, ) from auth0_server_python.error import ( + AccessTokenError, + AccessTokenErrorCode, AccessTokenForConnectionError, + AccessTokenForConnectionErrorCode, ApiError, BackchannelLogoutError, ConfigurationError, @@ -250,6 +253,98 @@ async def test_complete_link_user_returns_app_state(mocker): mock_tx_store.delete.assert_awaited_once() +@pytest.mark.asyncio +async def test_start_link_user_stores_origin_domain_in_mcd(mocker): + """Test that start_link_user stores origin_domain in transaction in MCD mode.""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "id_token": "existing_id_token", + "user": {"sub": "user123"} + } + + captured_transaction = None + async def capture_tx(identifier, transaction_data, options=None): + nonlocal captured_transaction + captured_transaction = transaction_data + + mock_tx_store = AsyncMock() + mock_tx_store.set = AsyncMock(side_effect=capture_tx) + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": "https://tenant1.auth0.com/", + "authorization_endpoint": "https://tenant1.auth0.com/authorize", + "token_endpoint": "https://tenant1.auth0.com/oauth/token" + }) + mocker.patch.object(client, "_build_link_user_url", return_value="https://tenant1.auth0.com/authorize?...") + + await client.start_link_user( + options={"connection": "google-oauth2"}, + store_options={"request": {}} + ) + + assert captured_transaction is not None + assert captured_transaction.origin_domain == "tenant1.auth0.com" + assert captured_transaction.origin_issuer == "https://tenant1.auth0.com/" + + +@pytest.mark.asyncio +async def test_start_unlink_user_stores_origin_domain_in_mcd(mocker): + """Test that start_unlink_user stores origin_domain in transaction in MCD mode.""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "id_token": "existing_id_token", + "user": {"sub": "user123"} + } + + captured_transaction = None + async def capture_tx(identifier, transaction_data, options=None): + nonlocal captured_transaction + captured_transaction = transaction_data + + mock_tx_store = AsyncMock() + mock_tx_store.set = AsyncMock(side_effect=capture_tx) + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": "https://tenant1.auth0.com/", + "authorization_endpoint": "https://tenant1.auth0.com/authorize", + "token_endpoint": "https://tenant1.auth0.com/oauth/token" + }) + mocker.patch.object(client, "_build_unlink_user_url", return_value="https://tenant1.auth0.com/authorize?...") + + await client.start_unlink_user( + options={"connection": "google-oauth2"}, + store_options={"request": {}} + ) + + assert captured_transaction is not None + assert captured_transaction.origin_domain == "tenant1.auth0.com" + assert captured_transaction.origin_issuer == "https://tenant1.auth0.com/" + + @pytest.mark.asyncio async def test_login_backchannel_stores_access_token(mocker): mock_transaction_store = AsyncMock() @@ -375,6 +470,36 @@ async def test_get_session_none(): session_data = await client.get_session() assert session_data is None + +@pytest.mark.asyncio +async def test_get_session_domain_mismatch_returns_none(): + """Test that get_session returns None on domain mismatch in MCD mode.""" + session_data = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + session = await client.get_session(store_options={"request": {}}) + assert session is None + + @pytest.mark.asyncio async def test_get_access_token_from_store(): mock_state_store = AsyncMock() @@ -784,6 +909,41 @@ async def test_get_access_token_from_store_returns_minimum_matching_scopes(mocke assert token == "minimum_scope_token" get_refresh_token_mock.assert_not_awaited() + +@pytest.mark.asyncio +async def test_get_access_token_domain_mismatch_raises_error(): + """Test that get_access_token raises AccessTokenError on domain mismatch.""" + session_data = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[{ + "audience": "default", + "access_token": "token123", + "expires_at": int(time.time()) + 500 + }], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(AccessTokenError) as exc: + await client.get_access_token(store_options={"request": {}}) + assert exc.value.code == AccessTokenErrorCode.DOMAIN_MISMATCH + + @pytest.mark.asyncio async def test_get_access_token_for_connection_cached(): mock_state_store = AsyncMock() @@ -827,6 +987,47 @@ async def test_get_access_token_for_connection_no_refresh(): await client.get_access_token_for_connection({"connection": "my_connection"}) assert "A refresh token was not found" in str(exc.value) + +@pytest.mark.asyncio +async def test_get_access_token_for_connection_domain_mismatch(): + """Test that get_access_token_for_connection raises error on domain mismatch.""" + session_data = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[], + connection_token_sets=[{ + "connection": "my_connection", + "audience": "default", + "access_token": "conn_token", + "login_hint": "hint", + "expires_at": int(time.time()) + 500 + }], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(AccessTokenForConnectionError) as exc: + await client.get_access_token_for_connection( + {"connection": "my_connection"}, + store_options={"request": {}} + ) + assert exc.value.code == AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH + + @pytest.mark.asyncio async def test_logout(): mock_state_store = AsyncMock() @@ -913,6 +1114,92 @@ async def test_handle_backchannel_logout_ok(mocker): None ) + +@pytest.mark.asyncio +async def test_backchannel_logout_mcd_known_domain(mocker): + """Test that backchannel logout works in MCD mode when domain is in cache.""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Pre-populate the discovery cache (simulates a prior login) + client._discovery_cache["tenant1.auth0.com"] = { + "metadata": { + "issuer": "https://tenant1.auth0.com/", + "jwks_uri": "https://tenant1.auth0.com/.well-known/jwks.json" + }, + "jwks": {"keys": [{"kty": "RSA", "kid": "test-key"}]}, + "expires_at": time.time() + 600 + } + + # Mock the unverified decode to extract issuer + mocker.patch("jwt.decode", return_value={ + "iss": "https://tenant1.auth0.com/", + "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, + "sub": "user123", + "sid": "session123" + }) + + mocker.patch.object( + client, "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "iss": "https://tenant1.auth0.com/", + "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, + "sub": "user123", + "sid": "session123" + }) + + await client.handle_backchannel_logout("some_logout_token") + + mock_state_store.delete_by_logout_token.assert_awaited_once_with( + {"sub": "user123", "sid": "session123"}, + None + ) + + +@pytest.mark.asyncio +async def test_backchannel_logout_mcd_unknown_domain_rejected(mocker): + """Test that backchannel logout rejects unknown domains in MCD mode (SSRF protection).""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + state_store=AsyncMock(), + secret="test_secret_key_32_chars_long!!", + ) + + # Discovery cache is empty — no prior logins + assert len(client._discovery_cache) == 0 + + # Mock unverified decode — attacker's token has evil issuer + mocker.patch("jwt.decode", return_value={ + "iss": "https://evil.internal.server/", + "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, + "sub": "user123", + "sid": "session123" + }) + + with pytest.raises(BackchannelLogoutError) as exc: + await client.handle_backchannel_logout("crafted_logout_token") + assert "Unknown domain" in str(exc.value) + assert "evil.internal.server" in str(exc.value) + + # Test For AuthLib Helpers @pytest.mark.asyncio @@ -3615,6 +3902,66 @@ async def bad_resolver(context): with pytest.raises(DomainResolverError, match="must return a string"): await client.start_interactive_login(store_options={"request": MagicMock()}) + +@pytest.mark.asyncio +async def test_sync_callable_as_domain_resolver_raises_error(): + """Test that a non-async (sync) callable raises DomainResolverError. + + The SDK always awaits the resolver, so sync callables are not supported. + Domain resolvers must be async functions. + """ + def sync_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + client = ServerClient( + domain=sync_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(DomainResolverError): + await client.get_user(store_options={"request": {}}) + + +@pytest.mark.asyncio +async def test_resolver_returns_domain_with_scheme_prefix(): + """Test that domain resolver returning 'https://domain' works with session matching.""" + async def resolver_with_scheme(context): + return "https://tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = StateData( + user={"sub": "user123"}, + domain="https://tenant1.auth0.com", + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + client = ServerClient( + domain=resolver_with_scheme, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + user = await client.get_user(store_options={"request": {}}) + assert user is not None + assert user["sub"] == "user123" + + # ============================================================================= # MCD Tests : Domain-specific Session Management Tests # ============================================================================= From 1c5f2d8cd8de360e682895f63589d24c460629f9 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Tue, 3 Mar 2026 19:38:13 +0530 Subject: [PATCH 15/16] chore: fix for review comments (set-1) --- examples/MultipleCustomDomains.md | 44 ++- .../auth_server/server_client.py | 184 ++++++--- src/auth0_server_python/error/__init__.py | 2 + .../tests/test_server_client.py | 348 +++++++++++++++++- 4 files changed, 510 insertions(+), 68 deletions(-) diff --git a/examples/MultipleCustomDomains.md b/examples/MultipleCustomDomains.md index 5afafda..9b272ac 100644 --- a/examples/MultipleCustomDomains.md +++ b/examples/MultipleCustomDomains.md @@ -100,7 +100,37 @@ async def domain_resolver(context: DomainResolverContext) -> str: return DOMAIN_MAP.get(hostname, DEFAULT_DOMAIN) ``` -> **Note:** In resolver mode, the SDK builds the `redirect_uri` dynamically from the resolved domain. You do not need to set it per request. If you override `redirect_uri` in `authorization_params`, the SDK uses your value as-is. +## Passing store_options + +In resolver mode, pass `store_options` to each SDK call so the resolver can inspect the +current request and select the correct domain. If `store_options` are omitted, the resolver +receives empty context (`request_url=None`, `request_headers=None`). + +All public SDK methods that interact with sessions or Auth0 endpoints accept `store_options`. +Here is an example using `get_user()`: + +```python +# In your route handler, pass the framework request via store_options +store_options = {"request": request, "response": response} + +# The SDK calls your domain_resolver with a DomainResolverContext +# built from the request in store_options +user = await client.get_user(store_options=store_options) +``` + +The same pattern applies to `get_session()`, `get_access_token()`, `start_interactive_login()`, +`logout()`, and all other session-aware methods. + +## Redirect URI Requirements + +In resolver mode, the SDK does not infer `redirect_uri` from the request. You must provide it +explicitly: + +- Set a default `redirect_uri` when constructing `ServerClient`, or +- Pass `redirect_uri` in `authorization_params` for each login call. + +Framework wrappers like `auth0-fastapi` handle this automatically by constructing the +`redirect_uri` from the incoming request's host and scheme. ## Resolver Patterns @@ -273,13 +303,17 @@ async def domain_resolver(context: DomainResolverContext) -> str: ## Session Behavior in Resolver Mode -In resolver mode, sessions are bound to the domain that created them. On each request, the SDK compares the session's stored domain against the current resolved domain: +In resolver mode, sessions are bound to the domain that created them. On each request, the SDK compares the session's stored domain against the current resolved domain. If the domain is missing or does not match: -- `get_user()` and `get_session()` return `None` on domain mismatch. -- `get_access_token()` raises `AccessTokenError` on domain mismatch. +- `get_user()` and `get_session()` return `None`. +- `get_access_token()` raises `AccessTokenError` (code `MISSING_SESSION_DOMAIN` if the session has no stored domain, `DOMAIN_MISMATCH` if the domains differ). +- `get_access_token_for_connection()` raises `AccessTokenForConnectionError` (same codes as above). +- `start_link_user()` and `start_unlink_user()` raise `StartLinkUserError`. - Token refresh uses the session's stored domain, not the current request domain. -> **Warning:** If you switch from a static domain string to a resolver function, existing sessions that do not include a stored domain continue to work — the SDK treats the absent domain field as valid. New sessions will store the resolved domain automatically. Once old sessions expire, all sessions will be domain-aware. +> **Warning:** If you switch from a static domain string to a resolver function, existing sessions that do not include a stored domain are treated as **missing sessions**. The SDK cannot verify which domain originally created the session, so users will need to re-authenticate. New sessions store the resolved domain automatically. + +> **Note:** If a login was started before the switch to resolver mode and completes after, the SDK falls back to the current resolved domain for token exchange. The resulting session will store the resolved domain and work normally going forward. ## Discovery Cache diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 906d22d..08e87c7 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -573,8 +573,8 @@ async def complete_interactive_login( if not code: raise MissingRequiredArgumentError("code") - # Get origin domain and issuer from transaction - origin_domain = transaction_data.origin_domain + # Get origin domain and issuer from transactiondata, or resolve domain if not present (resolver mode) + origin_domain = transaction_data.origin_domain or await self._resolve_current_domain(store_options) origin_issuer = transaction_data.origin_issuer # Fetch metadata from the origin domain @@ -714,17 +714,19 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti state_data = await self._state_store.get(self._state_identifier, store_options) if state_data: - # Validate session domain matches current request domain + # Domain check should work for both Pydantic models and plain dicts + if hasattr(state_data, "dict") and callable(state_data.dict): + state_data = state_data.dict() + + # In resolver mode, reject sessions without domain or with mismatched domain if self._domain_resolver: + session_domain = state_data.get('domain') + if not session_domain: + return None current_domain = await self._resolve_current_domain(store_options) - session_domain = getattr(state_data, 'domain', None) - - if session_domain and session_domain != current_domain: - # Session created with different domain - reject for security + if self._normalize_issuer(session_domain) != self._normalize_issuer(current_domain): return None - if hasattr(state_data, "dict") and callable(state_data.dict): - state_data = state_data.dict() return state_data.get("user") return None @@ -741,17 +743,19 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O state_data = await self._state_store.get(self._state_identifier, store_options) if state_data: - # Validate session domain matches current request domain + # Domain check should work for both Pydantic models and plain dicts + if hasattr(state_data, "dict") and callable(state_data.dict): + state_data = state_data.dict() + + # In resolver mode, reject sessions without domain or with mismatched domain if self._domain_resolver: + session_domain = state_data.get('domain') + if not session_domain: + return None current_domain = await self._resolve_current_domain(store_options) - session_domain = getattr(state_data, 'domain', None) - - if session_domain and session_domain != current_domain: - # Session created with different domain - reject for security + if self._normalize_issuer(session_domain) != self._normalize_issuer(current_domain): return None - if hasattr(state_data, "dict") and callable(state_data.dict): - state_data = state_data.dict() session_data = {k: v for k, v in state_data.items() if k != "internal"} return session_data @@ -890,13 +894,22 @@ async def get_access_token( """ state_data = await self._state_store.get(self._state_identifier, store_options) - # Validate session domain matches current request domain + # Domain check should work for both Pydantic models and plain dicts + if state_data and hasattr(state_data, "dict") and callable(state_data.dict): + state_data_dict = state_data.dict() + else: + state_data_dict = state_data or {} + + # In resolver mode, reject sessions without domain or with mismatched domain if state_data and self._domain_resolver: + session_domain = state_data_dict.get('domain') + if not session_domain: + raise AccessTokenError( + AccessTokenErrorCode.MISSING_SESSION_DOMAIN, + "Session is missing domain. User needs to re-authenticate." + ) current_domain = await self._resolve_current_domain(store_options) - session_domain = getattr(state_data, 'domain', None) - - if session_domain and session_domain != current_domain: - # Session created with different domain - reject for security + if self._normalize_issuer(session_domain) != self._normalize_issuer(current_domain): raise AccessTokenError( AccessTokenErrorCode.DOMAIN_MISMATCH, "Session domain mismatch. User needs to re-authenticate with the current domain." @@ -910,11 +923,6 @@ async def get_access_token( merged_scope = self._merge_scope_with_defaults(scope, audience) - if state_data and hasattr(state_data, "dict") and callable(state_data.dict): - state_data_dict = state_data.dict() - else: - state_data_dict = state_data or {} - # Find matching token set token_set = None if state_data_dict and "token_sets" in state_data_dict: @@ -1127,7 +1135,7 @@ async def login_backchannel( "binding_message": options.get("binding_message"), "login_hint": options.get("login_hint"), "authorization_params": options.get("authorization_params"), - }) + }, store_options=store_options) existing_state_data = await self._state_store.get(self._state_identifier, store_options) @@ -1140,6 +1148,10 @@ async def login_backchannel( token_endpoint_response ) + # Store domain for MCD session + domain = await self._resolve_current_domain(store_options) + state_data["domain"] = domain + await self._state_store.set(self._state_identifier, state_data, store_options) result = { @@ -1149,7 +1161,8 @@ async def login_backchannel( async def backchannel_authentication( self, - options: dict[str, Any] + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Performs backchannel authentication with Auth0. @@ -1174,7 +1187,7 @@ async def backchannel_authentication( Raises: ApiError: If the backchannel authentication fails """ - backchannel_data = await self.initiate_backchannel_authentication(options) + backchannel_data = await self.initiate_backchannel_authentication(options, store_options=store_options) auth_req_id = backchannel_data.get("auth_req_id") expires_in = backchannel_data.get( "expires_in", 120) # Default to 2 minutes @@ -1188,7 +1201,7 @@ async def backchannel_authentication( while time.time() < end_time: # Make token request try: - token_response = await self.backchannel_authentication_grant(auth_req_id) + token_response = await self.backchannel_authentication_grant(auth_req_id, store_options=store_options) return token_response except Exception as e: @@ -1213,7 +1226,8 @@ async def backchannel_authentication( async def initiate_backchannel_authentication( self, - options: dict[str, Any] + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Start backchannel authentication with Auth0. @@ -1266,16 +1280,16 @@ async def initiate_backchannel_authentication( ) try: - # Fetch OpenID Connect metadata if not already fetched - if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) + # Resolve domain + domain = await self._resolve_current_domain(store_options) + metadata = await self._get_oidc_metadata_cached(domain) # Get the issuer from metadata - issuer = self._oauth_metadata.get( - "issuer") or f"https://{self._domain}/" + issuer = metadata.get( + "issuer") or f"https://{domain}/" # Get backchannel authentication endpoint - backchannel_endpoint = self._oauth_metadata.get( + backchannel_endpoint = metadata.get( "backchannel_authentication_endpoint") if not backchannel_endpoint: raise ApiError( @@ -1344,12 +1358,17 @@ async def initiate_backchannel_authentication( e ) - async def backchannel_authentication_grant(self, auth_req_id: str) -> dict[str, Any]: + async def backchannel_authentication_grant( + self, + auth_req_id: str, + store_options: Optional[dict[str, Any]] = None + ) -> dict[str, Any]: """ Retrieves a token by exchanging an auth_req_id. Args: auth_req_id (str): The authentication request ID obtained from bc-authorize + store_options: Optional options used to pass to the Transaction and State Store. Raises: AccessTokenError: If there was an issue requesting the access token. @@ -1361,8 +1380,9 @@ async def backchannel_authentication_grant(self, auth_req_id: str) -> dict[str, raise MissingRequiredArgumentError("auth_req_id") try: - # Fetch OIDC metadata - metadata = await self._get_oidc_metadata_cached(self._domain) + # Resolve domain (supports MCD mode) + domain = await self._resolve_current_domain(store_options) + metadata = await self._get_oidc_metadata_cached(domain) token_endpoint = metadata.get("token_endpoint") if not token_endpoint: @@ -1450,6 +1470,19 @@ async def start_link_user( # Resolve domain for MCD origin_domain = await self._resolve_current_domain(store_options) + + # In resolver mode, reject sessions without domain or with mismatched domain + if self._domain_resolver: + session_domain = state_data.get('domain') if isinstance(state_data, dict) else getattr(state_data, 'domain', None) + if not session_domain: + raise StartLinkUserError( + "Session is missing domain. User needs to re-authenticate." + ) + if self._normalize_issuer(session_domain) != self._normalize_issuer(origin_domain): + raise StartLinkUserError( + "Session domain mismatch. User needs to re-authenticate with the current domain." + ) + metadata = await self._get_oidc_metadata_cached(origin_domain) origin_issuer = metadata.get('issuer') @@ -1464,7 +1497,8 @@ async def start_link_user( id_token=state_data["id_token"], code_verifier=code_verifier, state=state, - authorization_params=options.get("authorization_params") + authorization_params=options.get("authorization_params"), + domain=origin_domain ) # Store transaction data @@ -1531,6 +1565,19 @@ async def start_unlink_user( # Resolve domain for MCD origin_domain = await self._resolve_current_domain(store_options) + + # In resolver mode, reject sessions without domain or with mismatched domain + if self._domain_resolver: + session_domain = state_data.get('domain') if isinstance(state_data, dict) else getattr(state_data, 'domain', None) + if not session_domain: + raise StartLinkUserError( + "Session is missing domain. User needs to re-authenticate." + ) + if self._normalize_issuer(session_domain) != self._normalize_issuer(origin_domain): + raise StartLinkUserError( + "Session domain mismatch. User needs to re-authenticate with the current domain." + ) + metadata = await self._get_oidc_metadata_cached(origin_domain) origin_issuer = metadata.get('issuer') @@ -1544,7 +1591,8 @@ async def start_unlink_user( id_token=state_data["id_token"], code_verifier=code_verifier, state=state, - authorization_params=options.get("authorization_params") + authorization_params=options.get("authorization_params"), + domain=origin_domain ) # Store transaction data @@ -1594,19 +1642,20 @@ async def _build_link_user_url( code_verifier: str, state: str, connection_scope: Optional[str] = None, - authorization_params: Optional[dict[str, Any]] = None + authorization_params: Optional[dict[str, Any]] = None, + domain: Optional[str] = None ) -> str: """Build a URL for linking user accounts""" # Generate code challenge from verifier code_challenge = PKCE.generate_code_challenge(code_verifier) - # Get metadata if not already fetched - if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) + # Use provided domain or fall back to static domain + resolved_domain = domain or self._domain + metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = self._oauth_metadata.get("authorization_endpoint", - f"https://{self._domain}/authorize") + auth_endpoint = metadata.get("authorization_endpoint", + f"https://{resolved_domain}/authorize") # Build params params = { @@ -1637,19 +1686,20 @@ async def _build_unlink_user_url( id_token: str, code_verifier: str, state: str, - authorization_params: Optional[dict[str, Any]] = None + authorization_params: Optional[dict[str, Any]] = None, + domain: Optional[str] = None ) -> str: """Build a URL for unlinking user accounts""" # Generate code challenge from verifier code_challenge = PKCE.generate_code_challenge(code_verifier) - # Get metadata if not already fetched - if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._get_oidc_metadata_cached(self._domain) + # Use provided domain or fall back to static domain + resolved_domain = domain or self._domain + metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = self._oauth_metadata.get("authorization_endpoint", - f"https://{self._domain}/authorize") + auth_endpoint = metadata.get("authorization_endpoint", + f"https://{resolved_domain}/authorize") # Build params params = { @@ -1706,11 +1756,16 @@ async def get_access_token_for_connection( else: state_data_dict = state_data or {} - # In MCD mode, verify session domain matches current domain + # In resolver mode, reject sessions without domain or with mismatched domain if self._domain_resolver: - current_domain = await self._resolve_current_domain(store_options) session_domain = state_data_dict.get("domain") - if session_domain and session_domain != current_domain: + if not session_domain: + raise AccessTokenForConnectionError( + AccessTokenForConnectionErrorCode.MISSING_SESSION_DOMAIN, + "Session is missing domain. User needs to re-authenticate." + ) + current_domain = await self._resolve_current_domain(store_options) + if self._normalize_issuer(session_domain) != self._normalize_issuer(current_domain): raise AccessTokenForConnectionError( AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH, "Session domain mismatch. User needs to re-authenticate with the current domain." @@ -2073,7 +2128,8 @@ async def list_connected_account_connections( async def custom_token_exchange( self, - options: CustomTokenExchangeOptions + options: CustomTokenExchangeOptions, + store_options: Optional[dict[str, Any]] = None ) -> TokenExchangeResponse: """ Exchanges a custom token for Auth0 tokens using RFC 8693. @@ -2083,6 +2139,7 @@ async def custom_token_exchange( Args: options: Configuration for the token exchange + store_options: Optional options used to pass to the Transaction and State Store. Returns: TokenExchangeResponse containing access_token and metadata @@ -2112,8 +2169,9 @@ async def custom_token_exchange( if not isinstance(options, CustomTokenExchangeOptions): options = CustomTokenExchangeOptions(**options) - # Fetch OIDC metadata - metadata = await self._get_oidc_metadata_cached(self._domain) + # Resolve domain + domain = await self._resolve_current_domain(store_options) + metadata = await self._get_oidc_metadata_cached(domain) token_endpoint = metadata.get("token_endpoint") if not token_endpoint: @@ -2241,7 +2299,7 @@ async def login_with_custom_token_exchange( authorization_params=options.authorization_params ) - token_response = await self.custom_token_exchange(exchange_options) + token_response = await self.custom_token_exchange(exchange_options, store_options=store_options) # Extract user claims from ID token if present user_claims = None @@ -2263,12 +2321,16 @@ async def login_with_custom_token_exchange( expires_at=int(time.time()) + token_response.expires_in ) + # Resolve domain for session storage + domain = await self._resolve_current_domain(store_options) + # Construct state data state_data = StateData( user=user_claims, id_token=token_response.id_token, refresh_token=token_response.refresh_token, token_sets=[token_set], + domain=domain, internal={ "sid": sid, "created_at": int(time.time()) diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index 8af2c63..d3ff871 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -183,6 +183,7 @@ class AccessTokenErrorCode: REFRESH_TOKEN_ERROR = "refresh_token_error" AUTH_REQ_ID_ERROR = "auth_req_id_error" INCORRECT_AUDIENCE = "incorrect_audience" + MISSING_SESSION_DOMAIN = "missing_session_domain" DOMAIN_MISMATCH = "domain_mismatch" @@ -192,6 +193,7 @@ class AccessTokenForConnectionErrorCode: FAILED_TO_RETRIEVE = "failed_to_retrieve" API_ERROR = "api_error" FETCH_ERROR = "retrieval_error" + MISSING_SESSION_DOMAIN = "missing_session_domain" DOMAIN_MISMATCH = "domain_mismatch" diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 67bc0bd..3b4d7e0 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -262,7 +262,8 @@ async def domain_resolver(context): mock_state_store = AsyncMock() mock_state_store.get.return_value = { "id_token": "existing_id_token", - "user": {"sub": "user123"} + "user": {"sub": "user123"}, + "domain": "tenant1.auth0.com" } captured_transaction = None @@ -308,7 +309,8 @@ async def domain_resolver(context): mock_state_store = AsyncMock() mock_state_store.get.return_value = { "id_token": "existing_id_token", - "user": {"sub": "user123"} + "user": {"sub": "user123"}, + "domain": "tenant1.auth0.com" } captured_transaction = None @@ -500,6 +502,148 @@ async def domain_resolver(context): assert session is None +@pytest.mark.asyncio +async def test_get_session_domain_mismatch_with_dict_state(): + """Test domain mismatch works when state store returns plain dict (stateless cookie store).""" + session_data = { + "user": {"sub": "user123"}, + "domain": "tenant1.auth0.com", + "token_sets": [], + "internal": {"sid": "123", "created_at": int(time.time())} + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + session = await client.get_session(store_options={"request": {}}) + assert session is None + + +@pytest.mark.asyncio +async def test_get_user_domain_mismatch_with_dict_state(): + """Test domain mismatch works when state store returns plain dict (stateless cookie store).""" + session_data = { + "user": {"sub": "user123", "name": "Test User"}, + "domain": "tenant1.auth0.com", + "token_sets": [], + "internal": {"sid": "123", "created_at": int(time.time())} + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + user = await client.get_user(store_options={"request": {}}) + assert user is None + + +@pytest.mark.asyncio +async def test_get_user_legacy_session_rejected_in_resolver_mode(): + """Test that sessions without domain field are rejected in resolver mode.""" + session_data = { + "user": {"sub": "user123", "name": "Test User"}, + # No "domain" field — legacy session created before MCD was enabled + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant1.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + user = await client.get_user(store_options={"request": {}}) + assert user is None + + +@pytest.mark.asyncio +async def test_get_session_legacy_session_rejected_in_resolver_mode(): + """Test that sessions without domain field are rejected in resolver mode.""" + session_data = { + "user": {"sub": "user123"}, + "token_sets": [], + "internal": {"sid": "123", "created_at": int(time.time())} + # No "domain" field — legacy session + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant1.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + session = await client.get_session(store_options={"request": {}}) + assert session is None + + +@pytest.mark.asyncio +async def test_get_user_domain_normalization(): + """Test that domain comparison is case-insensitive and normalizes schemes.""" + session_data = { + "user": {"sub": "user123", "name": "Test User"}, + "domain": "Tenant1.Auth0.Com" # Mixed case + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant1.auth0.com" # Lowercase + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + user = await client.get_user(store_options={"request": {}}) + assert user is not None + assert user["sub"] == "user123" + + @pytest.mark.asyncio async def test_get_access_token_from_store(): mock_state_store = AsyncMock() @@ -944,6 +1088,40 @@ async def domain_resolver(context): assert exc.value.code == AccessTokenErrorCode.DOMAIN_MISMATCH +@pytest.mark.asyncio +async def test_get_access_token_domain_mismatch_with_dict_state(): + """Test domain mismatch works when state store returns plain dict (stateless cookie store).""" + session_data = { + "user": {"sub": "user123"}, + "domain": "tenant1.auth0.com", + "token_sets": [{ + "audience": "default", + "access_token": "token123", + "expires_at": int(time.time()) + 500 + }], + "internal": {"sid": "123", "created_at": int(time.time())} + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(AccessTokenError) as exc: + await client.get_access_token(store_options={"request": {}}) + assert exc.value.code == AccessTokenErrorCode.DOMAIN_MISMATCH + + @pytest.mark.asyncio async def test_get_access_token_for_connection_cached(): mock_state_store = AsyncMock() @@ -1028,6 +1206,172 @@ async def domain_resolver(context): assert exc.value.code == AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH +@pytest.mark.asyncio +async def test_get_access_token_legacy_session_rejected_in_resolver_mode(): + """Test that sessions without domain field raise MISSING_SESSION_DOMAIN in resolver mode.""" + session_data = { + "user": {"sub": "user123"}, + "token_sets": [{ + "audience": "default", + "access_token": "token123", + "expires_at": int(time.time()) + 500 + }], + "internal": {"sid": "123", "created_at": int(time.time())} + # No "domain" field — legacy session + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant1.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(AccessTokenError) as exc: + await client.get_access_token(store_options={"request": {}}) + assert exc.value.code == AccessTokenErrorCode.MISSING_SESSION_DOMAIN + + +@pytest.mark.asyncio +async def test_get_access_token_for_connection_legacy_session_rejected(): + """Test that sessions without domain field raise MISSING_SESSION_DOMAIN in resolver mode.""" + session_data = { + "user": {"sub": "user123"}, + "connection_token_sets": [{ + "connection": "my_connection", + "access_token": "conn_token", + "expires_at": int(time.time()) + 500 + }], + "internal": {"sid": "123", "created_at": int(time.time())} + # No "domain" field — legacy session + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant1.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(AccessTokenForConnectionError) as exc: + await client.get_access_token_for_connection( + {"connection": "my_connection"}, + store_options={"request": {}} + ) + assert exc.value.code == AccessTokenForConnectionErrorCode.MISSING_SESSION_DOMAIN + + +@pytest.mark.asyncio +async def test_get_access_token_for_connection_domain_mismatch_with_dict_state(): + """Test domain mismatch works when state store returns plain dict (stateless cookie store).""" + session_data = { + "user": {"sub": "user123"}, + "domain": "tenant1.auth0.com", + "connection_token_sets": [{ + "connection": "my_connection", + "access_token": "conn_token", + "expires_at": int(time.time()) + 500 + }], + "internal": {"sid": "123", "created_at": int(time.time())} + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(AccessTokenForConnectionError) as exc: + await client.get_access_token_for_connection( + {"connection": "my_connection"}, + store_options={"request": {}} + ) + assert exc.value.code == AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH + + +@pytest.mark.asyncio +async def test_start_link_user_rejects_legacy_session_in_resolver_mode(mocker): + """Test that start_link_user rejects sessions without domain in resolver mode.""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "id_token": "existing_id_token", + "user": {"sub": "user123"} + # No "domain" field — legacy session + } + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(StartLinkUserError): + await client.start_link_user( + options={"connection": "google-oauth2"}, + store_options={"request": {}} + ) + + +@pytest.mark.asyncio +async def test_start_unlink_user_rejects_legacy_session_in_resolver_mode(mocker): + """Test that start_unlink_user rejects sessions without domain in resolver mode.""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "id_token": "existing_id_token", + "user": {"sub": "user123"} + # No "domain" field — legacy session + } + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(StartLinkUserError): + await client.start_unlink_user( + options={"connection": "google-oauth2"}, + store_options={"request": {}} + ) + + @pytest.mark.asyncio async def test_logout(): mock_state_store = AsyncMock() From a9debdd98df5b4d466f0e9af1e220780238eee6f Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Wed, 4 Mar 2026 08:59:16 +0530 Subject: [PATCH 16/16] fix: scoped logout, backchannel logout --- .../auth_server/server_client.py | 36 ++++++++++++++----- src/auth0_server_python/store/abstract.py | 3 +- .../tests/test_server_client.py | 21 +++++++++-- 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 08e87c7..e37761a 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -768,13 +768,22 @@ async def logout( ) -> str: options = options or LogoutOptions() - # Delete the session from the state store - await self._state_store.delete(self._state_identifier, store_options) + if not self._domain_resolver: + await self._state_store.delete(self._state_identifier, store_options) + domain = self._domain + else: + # Resolver mode: delete session if domains match + domain = await self._resolve_current_domain(store_options) + state_data = await self._state_store.get(self._state_identifier, store_options) - # Resolve domain dynamically for MCD support - domain = await self._resolve_current_domain(store_options) + if state_data: + if hasattr(state_data, "dict") and callable(state_data.dict): + state_data = state_data.dict() + session_domain = state_data.get("domain") + if session_domain and self._normalize_issuer(session_domain) == self._normalize_issuer(domain): + await self._state_store.delete(self._state_identifier, store_options) - # Use the URL helper to create the logout URL. + # Return logout URL for the current resolved domain logout_url = URL.create_logout_url( domain, self._client_id, options.return_to) @@ -828,7 +837,7 @@ async def handle_backchannel_logout( jwks = await self._get_jwks_cached(domain) try: - claims = await self._verify_and_decode_jwt(logout_token, jwks) + claims = await self._verify_and_decode_jwt(logout_token, jwks, audience=self._client_id) # Normalized issuer validation token_issuer = claims.get("iss", "") @@ -861,7 +870,12 @@ async def handle_backchannel_logout( sid=claims.get("sid") ) - await self._state_store.delete_by_logout_token(logout_claims.dict(), store_options) + # In resolver mode, include iss for issuer-scoped deletion + claims_dict = logout_claims.dict() + if self._domain_resolver: + claims_dict["iss"] = claims.get("iss") + + await self._state_store.delete_by_logout_token(claims_dict, store_options) except (jwt.PyJWTError, ValidationError) as e: raise BackchannelLogoutError( @@ -1473,7 +1487,9 @@ async def start_link_user( # In resolver mode, reject sessions without domain or with mismatched domain if self._domain_resolver: - session_domain = state_data.get('domain') if isinstance(state_data, dict) else getattr(state_data, 'domain', None) + if hasattr(state_data, "dict") and callable(state_data.dict): + state_data = state_data.dict() + session_domain = state_data.get('domain') if not session_domain: raise StartLinkUserError( "Session is missing domain. User needs to re-authenticate." @@ -1568,7 +1584,9 @@ async def start_unlink_user( # In resolver mode, reject sessions without domain or with mismatched domain if self._domain_resolver: - session_domain = state_data.get('domain') if isinstance(state_data, dict) else getattr(state_data, 'domain', None) + if hasattr(state_data, "dict") and callable(state_data.dict): + state_data = state_data.dict() + session_domain = state_data.get('domain') if not session_domain: raise StartLinkUserError( "Session is missing domain. User needs to re-authenticate." diff --git a/src/auth0_server_python/store/abstract.py b/src/auth0_server_python/store/abstract.py index 45e433f..178757e 100644 --- a/src/auth0_server_python/store/abstract.py +++ b/src/auth0_server_python/store/abstract.py @@ -96,7 +96,8 @@ async def delete_by_logout_token(self, claims: dict[str, Any], options: Optional Delete sessions based on logout token claims. Args: - claims: Claims from the logout token + claims: Claims from the logout token (sub, sid, and optionally iss + in MCD mode for issuer-scoped deletion) options: Additional operation-specific options Note: diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 3b4d7e0..6c9b250 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -1445,7 +1445,7 @@ async def test_handle_backchannel_logout_ok(mocker): mock_signing_key = mocker.MagicMock() mock_signing_key.key = "mock_pem_key" mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) - mocker.patch("jwt.decode", return_value={ + mock_jwt_decode = mocker.patch("jwt.decode", return_value={ "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, "iss": "https://auth0.local", "sub": "user_sub", @@ -1453,6 +1453,12 @@ async def test_handle_backchannel_logout_ok(mocker): }) await client.handle_backchannel_logout("some_logout_token") + + # Verify audience is passed to jwt.decode + call_kwargs = mock_jwt_decode.call_args[1] + assert call_kwargs["audience"] == "client_id" + + # In static mode, iss should NOT be included mock_state_store.delete_by_logout_token.assert_awaited_once_with( {"sub": "user_sub", "sid": "session_id_123"}, None @@ -1498,7 +1504,7 @@ async def domain_resolver(context): return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} ) - mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + mock_verify = mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ "iss": "https://tenant1.auth0.com/", "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, "sub": "user123", @@ -1507,8 +1513,14 @@ async def domain_resolver(context): await client.handle_backchannel_logout("some_logout_token") + # Verify audience is passed to JWT verification + mock_verify.assert_awaited_once() + call_kwargs = mock_verify.call_args[1] + assert call_kwargs["audience"] == "test_client" + + # In resolver mode, iss should be included for issuer-scoped deletion mock_state_store.delete_by_logout_token.assert_awaited_once_with( - {"sub": "user123", "sid": "session123"}, + {"sub": "user123", "sid": "session123", "iss": "https://tenant1.auth0.com/"}, None ) @@ -4405,6 +4417,7 @@ async def domain_resolver(context): return current_domain mock_state_store = AsyncMock() + mock_state_store.get.return_value = {"domain": current_domain, "user": {"sub": "user1"}} client = ServerClient( domain=domain_resolver, @@ -4420,6 +4433,8 @@ async def domain_resolver(context): # Verify logout URL uses current domain assert current_domain in logout_url assert logout_url.startswith(f"https://{current_domain}") + # Verify session was deleted (domains match) + mock_state_store.delete.assert_called_once() @pytest.mark.asyncio