diff --git a/.gitignore b/.gitignore index fe90143..c30eceb 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,4 @@ setup.py test.py test-script.py .coverage -coverage.xml - +coverage.xml \ No newline at end of file diff --git a/README.md b/README.md index f81b537..ff5e46c 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,34 @@ print(response.access_token) For more details and examples, see [examples/CustomTokenExchange.md](examples/CustomTokenExchange.md). +### 5. Multiple Custom Domains (MCD) + +For applications that use multiple custom domains on the same Auth0 tenant, pass a domain resolver function instead of a static domain string: + +```python +from auth0_server_python.auth_server.server_client import ServerClient +from auth0_server_python.auth_types import DomainResolverContext + +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get('host', '').split(':')[0] + domain_map = { + "acme.yourapp.com": "acme.auth0.com", + "globex.yourapp.com": "globex.auth0.com", + } + return domain_map.get(host, "default.auth0.com") + +auth0 = ServerClient( + domain=domain_resolver, # Callable enables MCD mode + client_id='', + client_secret='', + secret='', +) +``` + +The SDK handles per-domain OIDC discovery, JWKS fetching, issuer validation, and session isolation automatically. Static string domains continue to work unchanged. + +For more details and examples, see [examples/MultipleCustomDomains.md](examples/MultipleCustomDomains.md). + ## Feedback ### Contributing diff --git a/examples/MultipleCustomDomains.md b/examples/MultipleCustomDomains.md new file mode 100644 index 0000000..9b272ac --- /dev/null +++ b/examples/MultipleCustomDomains.md @@ -0,0 +1,370 @@ +# Multiple Custom Domains (MCD) + +MCD lets you resolve the Auth0 domain per request while keeping a single `ServerClient` instance. This is useful when your application uses multiple custom domains configured on the same Auth0 tenant. + +**Example:** +- `https://acme.yourapp.com` → Custom domain: `auth.acme.com` +- `https://globex.yourapp.com` → Custom domain: `auth.globex.com` + +MCD is enabled by providing a **domain resolver function** instead of a static domain string. + +## Configuration Methods + +### Method 1: Static Domain (Single Domain) + +For applications with a single Auth0 domain: + +```python +from auth0_server_python import ServerClient + +client = ServerClient( + domain="your-tenant.auth0.com", # Static string + client_id="your_client_id", + client_secret="your_client_secret", + secret="your_encryption_secret" +) +``` + +### Method 2: Dynamic Domain Resolver (MCD) + +For MCD support, provide a domain resolver function that receives a `DomainResolverContext`: + +```python +from auth0_server_python import ServerClient +from auth0_server_python.auth_types import DomainResolverContext + +# Map your app hostnames to Auth0 domains +DOMAIN_MAP = { + "acme.yourapp.com": "acme.auth0.com", + "globex.yourapp.com": "globex.auth0.com", +} +DEFAULT_DOMAIN = "default.auth0.com" + +async def domain_resolver(context: DomainResolverContext) -> str: + """ + Resolve Auth0 domain based on request hostname. + + Args: + context: Contains request_url and request_headers + + Returns: + Auth0 domain string (e.g., "acme.auth0.com") + """ + # Extract hostname from request headers + if not context.request_headers: + return DEFAULT_DOMAIN + + host = context.request_headers.get('host', DEFAULT_DOMAIN) + host_without_port = host.split(':')[0] + + # Look up Auth0 domain + return DOMAIN_MAP.get(host_without_port, DEFAULT_DOMAIN) + +client = ServerClient( + domain=domain_resolver, # Callable function + client_id="your_client_id", + client_secret="your_client_secret", + secret="your_encryption_secret" +) +``` + +## DomainResolverContext + +The `DomainResolverContext` object provides request information to your resolver: + +| Property | Type | Description | +|----------|------|-------------| +| `request_url` | `Optional[str]` | Full request URL (e.g., "https://acme.yourapp.com/auth/login") | +| `request_headers` | `Optional[dict[str, str]]` | Request headers dictionary | + +**Common headers:** +- `host`: Request hostname (e.g., "acme.yourapp.com") +- `x-forwarded-host`: Original host when behind proxy/load balancer + +**Example usage:** + +```python +async def domain_resolver(context: DomainResolverContext) -> str: + # Check if we have request headers + if not context.request_headers: + return DEFAULT_DOMAIN + + # Use x-forwarded-host if behind proxy, otherwise use host + host = (context.request_headers.get('x-forwarded-host') or + context.request_headers.get('host', '')) + + # Remove port number if present + hostname = host.split(':')[0].lower() + + # Look up in mapping + return DOMAIN_MAP.get(hostname, DEFAULT_DOMAIN) +``` + +## Passing store_options + +In resolver mode, pass `store_options` to each SDK call so the resolver can inspect the +current request and select the correct domain. If `store_options` are omitted, the resolver +receives empty context (`request_url=None`, `request_headers=None`). + +All public SDK methods that interact with sessions or Auth0 endpoints accept `store_options`. +Here is an example using `get_user()`: + +```python +# In your route handler, pass the framework request via store_options +store_options = {"request": request, "response": response} + +# The SDK calls your domain_resolver with a DomainResolverContext +# built from the request in store_options +user = await client.get_user(store_options=store_options) +``` + +The same pattern applies to `get_session()`, `get_access_token()`, `start_interactive_login()`, +`logout()`, and all other session-aware methods. + +## Redirect URI Requirements + +In resolver mode, the SDK does not infer `redirect_uri` from the request. You must provide it +explicitly: + +- Set a default `redirect_uri` when constructing `ServerClient`, or +- Pass `redirect_uri` in `authorization_params` for each login call. + +Framework wrappers like `auth0-fastapi` handle this automatically by constructing the +`redirect_uri` from the incoming request's host and scheme. + +## Resolver Patterns + +### Database Lookup (SQLAlchemy) + +Resolve domains from a database using async SQLAlchemy: + +```python +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy import text + +engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/mydb") + +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + tenant = host.split(".")[0] + + async with AsyncSession(engine) as session: + result = await session.execute( + text("SELECT auth0_domain FROM tenants WHERE slug = :slug"), + {"slug": tenant} + ) + row = result.fetchone() + if row: + return row[0] + + return DEFAULT_DOMAIN +``` + +### Database Lookup with In-Memory Cache + +Avoid hitting the database on every request by caching the tenant map: + +```python +import time + +_tenant_cache = {} +_cache_ttl = 300 # 5 minutes + +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + tenant = host.split(".")[0] + + now = time.time() + cached = _tenant_cache.get(tenant) + if cached and cached["expires_at"] > now: + return cached["domain"] + + # Cache miss - fetch from database + async with AsyncSession(engine) as session: + result = await session.execute( + text("SELECT auth0_domain FROM tenants WHERE slug = :slug"), + {"slug": tenant} + ) + row = result.fetchone() + domain = row[0] if row else DEFAULT_DOMAIN + + _tenant_cache[tenant] = {"domain": domain, "expires_at": now + _cache_ttl} + return domain +``` + +### Redis Lookup + +Use Redis for shared tenant configuration across multiple app instances: + +```python +import redis.asyncio as redis + +redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True) + +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + tenant = host.split(".")[0] + + # Key format: "tenant:acme" -> "acme.auth0.com" + domain = await redis_client.get(f"tenant:{tenant}") + if domain: + return domain + + return DEFAULT_DOMAIN +``` + +### Redis with Hash Map + +Store all tenant mappings in a single Redis hash: + +```python +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + tenant = host.split(".")[0] + + # All tenants in one hash: HGET tenant_domains acme -> "acme.auth0.com" + domain = await redis_client.hget("tenant_domains", tenant) + if domain: + return domain + + return DEFAULT_DOMAIN +``` + +### Path-Based Resolution + +Resolve tenant from URL path instead of hostname: + +```python +from urllib.parse import urlparse + +async def domain_resolver(context: DomainResolverContext) -> str: + if context.request_url: + path = urlparse(context.request_url).path + # URL pattern: /tenant/acme/auth/login + parts = path.strip("/").split("/") + if len(parts) >= 2 and parts[0] == "tenant": + tenant = parts[1] + return DOMAIN_MAP.get(tenant, DEFAULT_DOMAIN) + + return DEFAULT_DOMAIN +``` + +### Custom Header Resolution + +Use a custom header set by your API gateway or load balancer: + +```python +async def domain_resolver(context: DomainResolverContext) -> str: + headers = context.request_headers or {} + + # API gateway sets X-Tenant-Id header + tenant_id = headers.get("x-tenant-id") + if tenant_id: + return DOMAIN_MAP.get(tenant_id, DEFAULT_DOMAIN) + + # Fallback to host header + host = headers.get("host", "").split(":")[0] + return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) +``` + +## Error Handling + +### DomainResolverError + +The domain resolver should return a valid Auth0 domain string. Invalid returns will raise `DomainResolverError`: + +```python +from auth0_server_python.error import DomainResolverError + +async def domain_resolver(context: DomainResolverContext) -> str: + try: + domain = lookup_domain_from_db(context) + + if not domain: + # Return default instead of None + return DEFAULT_DOMAIN + + return domain # Must be a non-empty string + + except Exception as e: + # Log error and return default + logger.error(f"Domain resolution failed: {e}") + return DEFAULT_DOMAIN +``` + +**Invalid return values that raise `DomainResolverError`:** +- `None` +- Empty string `""` +- Non-string types (int, list, dict, etc.) + +**Exceptions raised by your resolver:** +- Automatically wrapped in `DomainResolverError` +- Original exception accessible via `.original_error` + +## Session Behavior in Resolver Mode + +In resolver mode, sessions are bound to the domain that created them. On each request, the SDK compares the session's stored domain against the current resolved domain. If the domain is missing or does not match: + +- `get_user()` and `get_session()` return `None`. +- `get_access_token()` raises `AccessTokenError` (code `MISSING_SESSION_DOMAIN` if the session has no stored domain, `DOMAIN_MISMATCH` if the domains differ). +- `get_access_token_for_connection()` raises `AccessTokenForConnectionError` (same codes as above). +- `start_link_user()` and `start_unlink_user()` raise `StartLinkUserError`. +- Token refresh uses the session's stored domain, not the current request domain. + +> **Warning:** If you switch from a static domain string to a resolver function, existing sessions that do not include a stored domain are treated as **missing sessions**. The SDK cannot verify which domain originally created the session, so users will need to re-authenticate. New sessions store the resolved domain automatically. + +> **Note:** If a login was started before the switch to resolver mode and completes after, the SDK falls back to the current resolved domain for token exchange. The resulting session will store the resolved domain and work normally going forward. + +## Discovery Cache + +The SDK caches OIDC metadata and JWKS per domain in memory (LRU eviction, 600-second TTL, up to 100 domains). This avoids repeated network calls when serving multiple domains. The cache is shared across all requests to the same `ServerClient` instance. + +## Security Best Practices + +### Use an Allowlist in Your Resolver + +The SDK passes request headers to your domain resolver via `DomainResolverContext`. These headers come directly from the HTTP request and can be spoofed by an attacker (e.g., `Host: evil.com` or `X-Forwarded-Host: evil.com`). + +The SDK uses the resolved domain to fetch OIDC metadata and JWKS. If an attacker can influence the resolved domain, they could point the SDK at an OIDC provider they control. + +**Always use a mapping or allowlist — never construct domains from raw header values:** + +```python +# Safe: allowlist lookup — unknown hosts fall back to default +DOMAIN_MAP = { + "acme.myapp.com": "auth.acme.com", + "globex.myapp.com": "auth.globex.com", +} + +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) +``` + +```python +# Risky: constructs domain from raw input — attacker can influence resolved domain +async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get("host", "").split(":")[0] + tenant = host.split(".")[0] + return f"{tenant}.auth0.com" # attacker sends Host: evil.myapp.com → evil.auth0.com +``` + +### Trust Forwarded Headers Only Behind a Proxy + +If your application is directly exposed to the internet (not behind a reverse proxy), do not trust `x-forwarded-host` or `x-forwarded-proto` — any client can set these headers. + +Only use forwarded headers when your application runs behind a trusted reverse proxy (nginx, AWS ALB, Cloudflare, etc.) that sets these headers and strips any client-provided values. + +```python +# Only trust x-forwarded-host if behind a trusted proxy +async def domain_resolver(context: DomainResolverContext) -> str: + headers = context.request_headers or {} + + if BEHIND_TRUSTED_PROXY: + host = headers.get("x-forwarded-host") or headers.get("host", "") + else: + host = headers.get("host", "") + + host = host.split(":")[0] + return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) +``` \ No newline at end of file diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 1e62aa5..e37761a 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -6,7 +6,8 @@ import asyncio import json import time -from typing import Any, Generic, Optional, TypeVar +from collections import OrderedDict +from typing import Any, Callable, Generic, Optional, TypeVar, Union from urllib.parse import parse_qs, urlencode, urlparse, urlunparse import httpx @@ -38,8 +39,10 @@ AccessTokenForConnectionErrorCode, ApiError, BackchannelLogoutError, + ConfigurationError, CustomTokenExchangeError, CustomTokenExchangeErrorCode, + DomainResolverError, InvalidArgumentError, MissingRequiredArgumentError, MissingTransactionError, @@ -47,13 +50,19 @@ StartLinkUserError, ) from auth0_server_python.utils import PKCE, URL, State +from auth0_server_python.utils.helpers import ( + build_domain_resolver_context, + validate_resolved_domain_value, +) from authlib.integrations.base_client.errors import OAuthError from authlib.integrations.httpx_client import AsyncOAuth2Client from pydantic import ValidationError # Generic type for store options TStoreOptions = TypeVar('TStoreOptions') -INTERNAL_AUTHORIZE_PARAMS = ["client_id", "redirect_uri", "response_type", +# redirect_uri is intentionally excluded — in MCD mode it is built +# dynamically from the resolved domain at login time. +INTERNAL_AUTHORIZE_PARAMS = ["client_id", "response_type", "code_challenge", "code_challenge_method", "state", "nonce", "scope"] @@ -70,9 +79,9 @@ class ServerClient(Generic[TStoreOptions]): def __init__( self, - domain: str, - client_id: str, - client_secret: str, + domain: Union[str, Callable[[Optional[dict[str, Any]]], str]] = None, + client_id: str = None, + client_secret: str = None, redirect_uri: Optional[str] = None, secret: str = None, transaction_store=None, @@ -80,13 +89,13 @@ def __init__( transaction_identifier: str = "_a0_tx", state_identifier: str = "_a0_session", authorization_params: Optional[dict[str, Any]] = None, - pushed_authorization_requests: bool = False + pushed_authorization_requests: bool = False, ): """ Initialize the Auth0 server client. Args: - domain: Auth0 domain (e.g., 'your-tenant.auth0.com') + domain: Auth0 domain - either a static string (e.g., 'tenant.auth0.com') or a callable that resolves domain dynamically. client_id: Auth0 client ID client_secret: Auth0 client secret redirect_uri: Default redirect URI for authentication @@ -96,12 +105,35 @@ def __init__( transaction_identifier: Identifier for transaction data state_identifier: Identifier for state data authorization_params: Default parameters for authorization requests + pushed_authorization_requests: Whether to use Pushed Authorization Requests """ if not secret: raise MissingRequiredArgumentError("secret") - # Store configuration - self._domain = domain + if domain is None: + raise ConfigurationError( + "Domain is required" + ) + + # Validate domain type + if not isinstance(domain, str) and not callable(domain): + raise ConfigurationError( + f"Domain must be either a string or a callable function. " + f"Got {type(domain).__name__} instead." + ) + + # Determine if domain is static string or dynamic callable + if callable(domain): + self._domain = None + self._domain_resolver = domain + else: + # Validate static domain string + domain_str = str(domain) + if not domain_str or domain_str.strip() == "": + raise ConfigurationError("Domain cannot be empty.") + self._domain = domain_str + self._domain_resolver = None + self._client_id = client_id self._client_secret = client_secret self._redirect_uri = redirect_uri @@ -122,14 +154,248 @@ def __init__( self._my_account_client = MyAccountClient(domain=domain) + # Unified cache for OIDC metadata and JWKS per domain (LRU eviction + TTL) + self._discovery_cache: OrderedDict[str, dict] = OrderedDict() + self._cache_ttl = 600 # 10 mins. TTL + self._cache_max_entries = 100 # Max 100 domains + + def _normalize_domain(self, domain: str) -> str: + """ + Normalize domain for comparison and URL construction. + Handles cases with/without https:// scheme. + """ + domain = domain.lower() + if domain.startswith('https://'): + return domain + elif domain.startswith('http://'): + return domain.replace('http://', 'https://') + else: + return f'https://{domain}' + + def _normalize_issuer(self, issuer: str) -> str: + """ + Normalize issuer URL for comparison. + + Args: + issuer: The issuer URL to normalize + + Returns: + Normalized issuer URL (lowercase) + """ + if not issuer: + return issuer + + # Lowercase first for case-insensitive comparison and scheme detection + issuer = issuer.lower() + + # Ensure https:// prefix + if issuer.startswith('http://'): + issuer = issuer.replace('http://', 'https://', 1) + elif not issuer.startswith('https://'): + issuer = f'https://{issuer}' + + # Remove trailing slash + return issuer.rstrip('/') + + async def _resolve_current_domain(self, store_options=None) -> str: + """Resolve domain from resolver function or return static domain.""" + if self._domain_resolver: + context = build_domain_resolver_context(store_options) + try: + resolved = await self._domain_resolver(context) + return validate_resolved_domain_value(resolved) + except DomainResolverError: + raise + except Exception as e: + raise DomainResolverError( + f"Domain resolver function raised an exception: {str(e)}", + original_error=e + ) + return self._domain + + async def _verify_and_decode_jwt( + self, token: str, jwks: dict, audience: Optional[str] = None + ) -> dict: + """ + Find signing key in JWKS by kid and decode JWT. + + Verifies signature but disables built-in issuer validation + (callers perform normalized issuer comparison separately). + + Args: + token: The JWT to verify and decode + jwks: The JWKS dict containing signing keys + audience: Optional expected audience claim + + Returns: + Decoded claims dictionary + + Raises: + ValueError: If no matching key found in JWKS for the token's kid + """ + unverified_header = jwt.get_unverified_header(token) + kid = unverified_header.get("kid") + + signing_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + signing_key = jwt.PyJWK.from_dict(key) + break + + if not signing_key: + raise ValueError(f"No matching key found in JWKS for kid: {kid}") + + decode_options = {"verify_signature": True, "verify_iss": False} + kwargs = {"algorithms": ["RS256"], "options": decode_options} + if audience: + kwargs["audience"] = audience + + return jwt.decode(token, signing_key.key, **kwargs) + async def _fetch_oidc_metadata(self, domain: str) -> dict: - """Fetch OpenID Connect discovery metadata from the Auth0 domain.""" - metadata_url = f"https://{domain}/.well-known/openid-configuration" + """Fetch OIDC metadata from domain.""" + normalized_domain = self._normalize_domain(domain) + metadata_url = f"{normalized_domain}/.well-known/openid-configuration" async with httpx.AsyncClient() as client: response = await client.get(metadata_url) response.raise_for_status() return response.json() + def _purge_expired_cache_entries(self): + """Purge all expired entries from the discovery cache.""" + now = time.time() + expired = [k for k, v in self._discovery_cache.items() if v["expires_at"] <= now] + for k in expired: + del self._discovery_cache[k] + + def _ensure_cache_capacity(self): + """Evict least recently used entry if cache is at capacity.""" + if len(self._discovery_cache) >= self._cache_max_entries: + self._discovery_cache.popitem(last=False) + + async def _get_oidc_metadata_cached(self, domain: str) -> dict: + """ + Get OIDC metadata with caching (LRU eviction + TTL). + + Uses a unified cache shared with JWKS when metadata expires, + the corresponding JWKS is also invalidated. + + Args: + domain: Auth0 domain + + Returns: + OIDC metadata document + """ + now = time.time() + + # Check cache + if domain in self._discovery_cache: + cached = self._discovery_cache[domain] + if cached["expires_at"] > now: + self._discovery_cache.move_to_end(domain) + return cached["metadata"] + # Expired — remove entire entry (metadata + jwks) + del self._discovery_cache[domain] + + # Cache miss/expired - fetch fresh + metadata = await self._fetch_oidc_metadata(domain) + + # Purge expired entries and ensure capacity + self._purge_expired_cache_entries() + self._ensure_cache_capacity() + + # Store in cache with jwks=None (lazily populated when needed) + self._discovery_cache[domain] = { + "metadata": metadata, + "jwks": None, + "expires_at": now + self._cache_ttl + } + + return metadata + + async def _fetch_jwks(self, jwks_uri: str) -> dict: + """ + Fetch JWKS (JSON Web Key Set) from jwks_uri. + + Args: + jwks_uri: The JWKS endpoint URL + + Returns: + JWKS document containing public keys + + Raises: + ApiError: If JWKS fetch fails + """ + try: + async with httpx.AsyncClient() as client: + response = await client.get(jwks_uri) + response.raise_for_status() + return response.json() + except Exception as e: + raise ApiError("jwks_fetch_error", f"Failed to fetch JWKS from {jwks_uri}", e) + + async def _get_jwks_cached(self, domain: str, metadata: dict = None) -> dict: + """ + Get JWKS with caching using OIDC discovery (LRU eviction + TTL). + + Uses a unified cache shared with metadata — JWKS is lazily populated + on first access and invalidated when the cache entry expires. + + Args: + domain: Auth0 domain + metadata: Optional OIDC metadata (if already fetched) + + Returns: + JWKS document + + Raises: + ApiError: If JWKS fetch fails or jwks_uri missing from metadata + """ + now = time.time() + + # Check cache — entry exists, not expired, and jwks already fetched + if domain in self._discovery_cache: + cached = self._discovery_cache[domain] + if cached["expires_at"] > now: + if cached["jwks"] is not None: + self._discovery_cache.move_to_end(domain) + return cached["jwks"] + # Entry valid but jwks not yet fetched — use cached metadata + metadata = cached["metadata"] + else: + # Expired — remove entire entry + del self._discovery_cache[domain] + + # Get metadata if not available from cache or parameter + if not metadata: + metadata = await self._get_oidc_metadata_cached(domain) + + jwks_uri = metadata.get('jwks_uri') + if not jwks_uri: + raise ApiError( + "missing_jwks_uri", + f"OIDC metadata for {domain} does not contain jwks_uri. Provider may be non-RFC-compliant." + ) + + # Fetch JWKS + jwks = await self._fetch_jwks(jwks_uri) + + # Update existing cache entry with jwks (entry created by _get_oidc_metadata_cached) + if domain in self._discovery_cache: + self._discovery_cache[domain]["jwks"] = jwks + self._discovery_cache.move_to_end(domain) + else: + # Edge case: entry was evicted between metadata and jwks fetch + self._purge_expired_cache_entries() + self._ensure_cache_capacity() + self._discovery_cache[domain] = { + "metadata": metadata, + "jwks": jwks, + "expires_at": now + self._cache_ttl + } + + return jwks + # ============================================================================ # INTERACTIVE LOGIN FLOW # Handles browser-based authentication using the Authorization Code flow @@ -146,12 +412,24 @@ async def start_interactive_login( Args: options: Configuration options for the login process + store_options: Store options containing request/response Returns: Authorization URL to redirect the user to """ options = options or StartInteractiveLoginOptions() + # Resolve domain (static or dynamic) + origin_domain = await self._resolve_current_domain(store_options) + + # Fetch OIDC metadata from resolved domain + try: + metadata = await self._get_oidc_metadata_cached(origin_domain) + origin_issuer = metadata.get('issuer') + except Exception as e: + raise ApiError("metadata_error", + "Failed to fetch OIDC metadata", e) + # Get effective authorization params (merge defaults with provided ones) auth_params = dict(self._default_authorization_params) if options.authorization_params: @@ -180,17 +458,20 @@ async def start_interactive_login( state = PKCE.generate_random_string(32) auth_params["state"] = state - #merge any requested scope with defaults + # Merge any requested scope with defaults requested_scope = options.authorization_params.get("scope", None) if options.authorization_params else None audience = auth_params.get("audience", None) merged_scope = self._merge_scope_with_defaults(requested_scope, audience) auth_params["scope"] = merged_scope - # Build the transaction data to store + # Build the transaction data to store with origin domain and issuer transaction_data = TransactionData( code_verifier=code_verifier, app_state=options.app_state, audience=audience, + origin_domain=origin_domain, + origin_issuer=origin_issuer, + redirect_uri=auth_params.get("redirect_uri"), ) # Store the transaction data @@ -199,11 +480,9 @@ async def start_interactive_login( transaction_data, options=store_options ) - try: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) - except Exception as e: - raise ApiError("metadata_error", - "Failed to fetch OIDC metadata", e) + + # Set metadata for OAuth client + self._oauth.metadata = metadata # If PAR is enabled, use the PAR endpoint if self._pushed_authorization_requests: par_endpoint = self._oauth.metadata.get( @@ -294,34 +573,84 @@ async def complete_interactive_login( if not code: raise MissingRequiredArgumentError("code") - if not self._oauth.metadata or "token_endpoint" not in self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + # Get origin domain and issuer from transactiondata, or resolve domain if not present (resolver mode) + origin_domain = transaction_data.origin_domain or await self._resolve_current_domain(store_options) + origin_issuer = transaction_data.origin_issuer + + # Fetch metadata from the origin domain + metadata = await self._get_oidc_metadata_cached(origin_domain) + self._oauth.metadata = metadata # Exchange the code for tokens + # Use redirect_uri from transaction if available, otherwise fall back to default + token_redirect_uri = transaction_data.redirect_uri or self._redirect_uri try: token_endpoint = self._oauth.metadata["token_endpoint"] token_response = await self._oauth.fetch_token( token_endpoint, code=code, code_verifier=transaction_data.code_verifier, - redirect_uri=self._redirect_uri, + redirect_uri=token_redirect_uri, ) except OAuthError as e: # Raise a custom error (or handle it as appropriate) raise ApiError( "token_error", f"Token exchange failed: {str(e)}", e) - # Use the userinfo field from the token_response for user claims + # Use the userinfo field from the token_response for user claims user_info = token_response.get("userinfo") user_claims = None + id_token = token_response.get("id_token") + if user_info: user_claims = UserClaims.parse_obj(user_info) - else: - id_token = token_response.get("id_token") - if id_token: - claims = jwt.decode(id_token, options={ - "verify_signature": False}) + elif id_token: + # Fetch JWKS for signature verification + jwks = await self._get_jwks_cached(origin_domain, metadata) + + # Decode and verify ID token with signature verification enabled + try: + claims = await self._verify_and_decode_jwt( + id_token, jwks, audience=self._client_id + ) + + # Custom normalized issuer validation + token_issuer = claims.get("iss", "") + if self._normalize_issuer(token_issuer) != self._normalize_issuer(origin_issuer): + raise ApiError( + "invalid_issuer", + f"ID token issuer mismatch. Token issuer: {token_issuer}, Expected: {origin_issuer}. " + f"Ensure your Auth0 domain is configured correctly." + ) + user_claims = UserClaims.parse_obj(claims) + except ValueError as e: + raise ApiError("jwks_key_not_found", str(e)) + except jwt.InvalidSignatureError as e: + raise ApiError( + "invalid_signature", + f"ID token signature verification failed. The token may have been tampered with or is from an untrusted source: {str(e)}", + e + ) + except jwt.InvalidAudienceError as e: + raise ApiError( + "invalid_audience", + f"ID token audience mismatch. Expected: {self._client_id}. Ensure your client_id is configured correctly: {str(e)}", + e + ) + except jwt.ExpiredSignatureError as e: + raise ApiError( + "token_expired", + f"ID token has expired: {str(e)}", + e + ) + except jwt.InvalidTokenError as e: + raise ApiError( + "invalid_token", + f"ID token verification failed: {str(e)}", + e + ) + # Build a token set using the token response data token_set = TokenSet( @@ -343,6 +672,7 @@ async def complete_interactive_login( # might be None if not provided refresh_token=token_response.get("refresh_token"), token_sets=[token_set], + domain=origin_domain, internal={ "sid": sid, "created_at": int(time.time()) @@ -384,8 +714,19 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti state_data = await self._state_store.get(self._state_identifier, store_options) if state_data: + # Domain check should work for both Pydantic models and plain dicts if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() + + # In resolver mode, reject sessions without domain or with mismatched domain + if self._domain_resolver: + session_domain = state_data.get('domain') + if not session_domain: + return None + current_domain = await self._resolve_current_domain(store_options) + if self._normalize_issuer(session_domain) != self._normalize_issuer(current_domain): + return None + return state_data.get("user") return None @@ -402,8 +743,19 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O state_data = await self._state_store.get(self._state_identifier, store_options) if state_data: + # Domain check should work for both Pydantic models and plain dicts if hasattr(state_data, "dict") and callable(state_data.dict): state_data = state_data.dict() + + # In resolver mode, reject sessions without domain or with mismatched domain + if self._domain_resolver: + session_domain = state_data.get('domain') + if not session_domain: + return None + current_domain = await self._resolve_current_domain(store_options) + if self._normalize_issuer(session_domain) != self._normalize_issuer(current_domain): + return None + session_data = {k: v for k, v in state_data.items() if k != "internal"} return session_data @@ -414,24 +766,26 @@ async def logout( options: Optional[LogoutOptions] = None, store_options: Optional[dict[str, Any]] = None ) -> str: - """ - Logs the user out and returns the Auth0 logout URL. - - Args: - options: Logout options including return_to URL. - store_options: Optional options used to pass to the State Store. - - Returns: - The Auth0 logout URL to redirect the user to. - """ options = options or LogoutOptions() - # Delete the session from the state store - await self._state_store.delete(self._state_identifier, store_options) - - # Use the URL helper to create the logout URL. + if not self._domain_resolver: + await self._state_store.delete(self._state_identifier, store_options) + domain = self._domain + else: + # Resolver mode: delete session if domains match + domain = await self._resolve_current_domain(store_options) + state_data = await self._state_store.get(self._state_identifier, store_options) + + if state_data: + if hasattr(state_data, "dict") and callable(state_data.dict): + state_data = state_data.dict() + session_domain = state_data.get("domain") + if session_domain and self._normalize_issuer(session_domain) == self._normalize_issuer(domain): + await self._state_store.delete(self._state_identifier, store_options) + + # Return logout URL for the current resolved domain logout_url = URL.create_logout_url( - self._domain, self._client_id, options.return_to) + domain, self._client_id, options.return_to) return logout_url @@ -441,7 +795,7 @@ async def handle_backchannel_logout( store_options: Optional[dict[str, Any]] = None ) -> None: """ - Handles backchannel logout requests (OIDC Back-Channel Logout specification). + Handles backchannel logout requests. Args: logout_token: The logout token sent by Auth0 @@ -451,9 +805,58 @@ async def handle_backchannel_logout( raise BackchannelLogoutError("Missing logout token") try: - # Decode the token without verification - claims = jwt.decode(logout_token, options={ - "verify_signature": False}) + # Determine domain for JWKS fetch and issuer validation + if self._domain_resolver: + try: + unverified = jwt.decode( + logout_token, algorithms=["RS256"], + options={"verify_signature": False} + ) + token_issuer = unverified.get("iss", "") + parsed = urlparse(token_issuer) + domain = parsed.hostname + if not domain: + raise BackchannelLogoutError( + "Cannot determine domain: logout token has no valid issuer" + ) + if domain not in self._discovery_cache: + raise BackchannelLogoutError( + f"Unknown domain in logout token issuer: {domain}. " + f"Only domains from active sessions are accepted." + ) + except BackchannelLogoutError: + raise + except Exception as e: + raise BackchannelLogoutError( + f"Failed to extract domain from logout token: {str(e)}" + ) + else: + domain = self._domain + + # Fetch JWKS and verify logout token + jwks = await self._get_jwks_cached(domain) + + try: + claims = await self._verify_and_decode_jwt(logout_token, jwks, audience=self._client_id) + + # Normalized issuer validation + token_issuer = claims.get("iss", "") + expected_issuer = self._normalize_domain(domain) + if self._normalize_issuer(token_issuer) != self._normalize_issuer(expected_issuer): + raise BackchannelLogoutError( + f"Logout token issuer mismatch. Token issuer: {token_issuer}, Expected: {expected_issuer}. " + f"Ensure your Auth0 domain is configured correctly." + ) + except ValueError as e: + raise BackchannelLogoutError(str(e)) + except jwt.InvalidSignatureError as e: + raise BackchannelLogoutError( + f"Logout token signature verification failed: {str(e)}" + ) + except jwt.InvalidTokenError as e: + raise BackchannelLogoutError( + f"Logout token verification failed: {str(e)}" + ) # Validate the token is a logout token events = claims.get("events", {}) @@ -467,9 +870,14 @@ async def handle_backchannel_logout( sid=claims.get("sid") ) - await self._state_store.delete_by_logout_token(logout_claims.dict(), store_options) + # In resolver mode, include iss for issuer-scoped deletion + claims_dict = logout_claims.dict() + if self._domain_resolver: + claims_dict["iss"] = claims.get("iss") - except (jwt.JoseError, ValidationError) as e: + await self._state_store.delete_by_logout_token(claims_dict, store_options) + + except (jwt.PyJWTError, ValidationError) as e: raise BackchannelLogoutError( f"Error processing logout token: {str(e)}") @@ -500,6 +908,27 @@ async def get_access_token( """ state_data = await self._state_store.get(self._state_identifier, store_options) + # Domain check should work for both Pydantic models and plain dicts + if state_data and hasattr(state_data, "dict") and callable(state_data.dict): + state_data_dict = state_data.dict() + else: + state_data_dict = state_data or {} + + # In resolver mode, reject sessions without domain or with mismatched domain + if state_data and self._domain_resolver: + session_domain = state_data_dict.get('domain') + if not session_domain: + raise AccessTokenError( + AccessTokenErrorCode.MISSING_SESSION_DOMAIN, + "Session is missing domain. User needs to re-authenticate." + ) + current_domain = await self._resolve_current_domain(store_options) + if self._normalize_issuer(session_domain) != self._normalize_issuer(current_domain): + raise AccessTokenError( + AccessTokenErrorCode.DOMAIN_MISMATCH, + "Session domain mismatch. User needs to re-authenticate with the current domain." + ) + auth_params = self._default_authorization_params or {} # Get audience passed in on options or use defaults @@ -508,11 +937,6 @@ async def get_access_token( merged_scope = self._merge_scope_with_defaults(scope, audience) - if state_data and hasattr(state_data, "dict") and callable(state_data.dict): - state_data_dict = state_data.dict() - else: - state_data_dict = state_data or {} - # Find matching token set token_set = None if state_data_dict and "token_sets" in state_data_dict: @@ -531,7 +955,12 @@ async def get_access_token( # Get new token with refresh token try: - get_refresh_token_options = {"refresh_token": state_data_dict["refresh_token"]} + # Use session's domain for token refresh + session_domain = state_data_dict.get("domain") or self._domain + get_refresh_token_options = { + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain + } if audience: get_refresh_token_options["audience"] = audience @@ -557,6 +986,7 @@ async def get_access_token( f"Failed to get token with refresh token: {str(e)}" ) + async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, Any]: """ Retrieves a token by exchanging a refresh token. @@ -576,11 +1006,13 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, raise MissingRequiredArgumentError("refresh_token") try: - # Ensure we have the OIDC metadata - if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + # Use session domain if provided, otherwise fallback to static domain + domain = options.get("domain") or self._domain + + # Fetch OIDC metadata from the correct domain + metadata = await self._get_oidc_metadata_cached(domain) - token_endpoint = self._oauth.metadata.get("token_endpoint") + token_endpoint = metadata.get("token_endpoint") if not token_endpoint: raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") @@ -717,7 +1149,7 @@ async def login_backchannel( "binding_message": options.get("binding_message"), "login_hint": options.get("login_hint"), "authorization_params": options.get("authorization_params"), - }) + }, store_options=store_options) existing_state_data = await self._state_store.get(self._state_identifier, store_options) @@ -730,6 +1162,10 @@ async def login_backchannel( token_endpoint_response ) + # Store domain for MCD session + domain = await self._resolve_current_domain(store_options) + state_data["domain"] = domain + await self._state_store.set(self._state_identifier, state_data, store_options) result = { @@ -739,7 +1175,8 @@ async def login_backchannel( async def backchannel_authentication( self, - options: dict[str, Any] + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Performs backchannel authentication with Auth0. @@ -764,7 +1201,7 @@ async def backchannel_authentication( Raises: ApiError: If the backchannel authentication fails """ - backchannel_data = await self.initiate_backchannel_authentication(options) + backchannel_data = await self.initiate_backchannel_authentication(options, store_options=store_options) auth_req_id = backchannel_data.get("auth_req_id") expires_in = backchannel_data.get( "expires_in", 120) # Default to 2 minutes @@ -778,7 +1215,7 @@ async def backchannel_authentication( while time.time() < end_time: # Make token request try: - token_response = await self.backchannel_authentication_grant(auth_req_id) + token_response = await self.backchannel_authentication_grant(auth_req_id, store_options=store_options) return token_response except Exception as e: @@ -803,7 +1240,8 @@ async def backchannel_authentication( async def initiate_backchannel_authentication( self, - options: dict[str, Any] + options: dict[str, Any], + store_options: Optional[dict[str, Any]] = None ) -> dict[str, Any]: """ Start backchannel authentication with Auth0. @@ -856,16 +1294,16 @@ async def initiate_backchannel_authentication( ) try: - # Fetch OpenID Connect metadata if not already fetched - if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + # Resolve domain + domain = await self._resolve_current_domain(store_options) + metadata = await self._get_oidc_metadata_cached(domain) # Get the issuer from metadata - issuer = self._oauth_metadata.get( - "issuer") or f"https://{self._domain}/" + issuer = metadata.get( + "issuer") or f"https://{domain}/" # Get backchannel authentication endpoint - backchannel_endpoint = self._oauth_metadata.get( + backchannel_endpoint = metadata.get( "backchannel_authentication_endpoint") if not backchannel_endpoint: raise ApiError( @@ -934,12 +1372,17 @@ async def initiate_backchannel_authentication( e ) - async def backchannel_authentication_grant(self, auth_req_id: str) -> dict[str, Any]: + async def backchannel_authentication_grant( + self, + auth_req_id: str, + store_options: Optional[dict[str, Any]] = None + ) -> dict[str, Any]: """ Retrieves a token by exchanging an auth_req_id. Args: auth_req_id (str): The authentication request ID obtained from bc-authorize + store_options: Optional options used to pass to the Transaction and State Store. Raises: AccessTokenError: If there was an issue requesting the access token. @@ -951,11 +1394,11 @@ async def backchannel_authentication_grant(self, auth_req_id: str) -> dict[str, raise MissingRequiredArgumentError("auth_req_id") try: - # Ensure we have the OIDC metadata - if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + # Resolve domain (supports MCD mode) + domain = await self._resolve_current_domain(store_options) + metadata = await self._get_oidc_metadata_cached(domain) - token_endpoint = self._oauth.metadata.get("token_endpoint") + token_endpoint = metadata.get("token_endpoint") if not token_endpoint: raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") @@ -1039,6 +1482,26 @@ async def start_link_user( "Unable to start the user linking process without a logged in user. Ensure to login using the SDK before starting the user linking process." ) + # Resolve domain for MCD + origin_domain = await self._resolve_current_domain(store_options) + + # In resolver mode, reject sessions without domain or with mismatched domain + if self._domain_resolver: + if hasattr(state_data, "dict") and callable(state_data.dict): + state_data = state_data.dict() + session_domain = state_data.get('domain') + if not session_domain: + raise StartLinkUserError( + "Session is missing domain. User needs to re-authenticate." + ) + if self._normalize_issuer(session_domain) != self._normalize_issuer(origin_domain): + raise StartLinkUserError( + "Session domain mismatch. User needs to re-authenticate with the current domain." + ) + + metadata = await self._get_oidc_metadata_cached(origin_domain) + origin_issuer = metadata.get('issuer') + # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() state = PKCE.generate_random_string(32) @@ -1050,13 +1513,16 @@ async def start_link_user( id_token=state_data["id_token"], code_verifier=code_verifier, state=state, - authorization_params=options.get("authorization_params") + authorization_params=options.get("authorization_params"), + domain=origin_domain ) # Store transaction data transaction_data = TransactionData( code_verifier=code_verifier, - app_state=options.get("app_state") + app_state=options.get("app_state"), + origin_domain=origin_domain, + origin_issuer=origin_issuer, ) await self._transaction_store.set( @@ -1113,6 +1579,26 @@ async def start_unlink_user( "Unable to start the user linking process without a logged in user. Ensure to login using the SDK before starting the user linking process." ) + # Resolve domain for MCD + origin_domain = await self._resolve_current_domain(store_options) + + # In resolver mode, reject sessions without domain or with mismatched domain + if self._domain_resolver: + if hasattr(state_data, "dict") and callable(state_data.dict): + state_data = state_data.dict() + session_domain = state_data.get('domain') + if not session_domain: + raise StartLinkUserError( + "Session is missing domain. User needs to re-authenticate." + ) + if self._normalize_issuer(session_domain) != self._normalize_issuer(origin_domain): + raise StartLinkUserError( + "Session domain mismatch. User needs to re-authenticate with the current domain." + ) + + metadata = await self._get_oidc_metadata_cached(origin_domain) + origin_issuer = metadata.get('issuer') + # Generate PKCE and state for security code_verifier = PKCE.generate_code_verifier() state = PKCE.generate_random_string(32) @@ -1123,13 +1609,16 @@ async def start_unlink_user( id_token=state_data["id_token"], code_verifier=code_verifier, state=state, - authorization_params=options.get("authorization_params") + authorization_params=options.get("authorization_params"), + domain=origin_domain ) # Store transaction data transaction_data = TransactionData( code_verifier=code_verifier, - app_state=options.get("app_state") + app_state=options.get("app_state"), + origin_domain=origin_domain, + origin_issuer=origin_issuer, ) await self._transaction_store.set( @@ -1171,19 +1660,20 @@ async def _build_link_user_url( code_verifier: str, state: str, connection_scope: Optional[str] = None, - authorization_params: Optional[dict[str, Any]] = None + authorization_params: Optional[dict[str, Any]] = None, + domain: Optional[str] = None ) -> str: - """Helper: Builds the authorization URL for linking user accounts.""" + """Build a URL for linking user accounts""" # Generate code challenge from verifier code_challenge = PKCE.generate_code_challenge(code_verifier) - # Get metadata if not already fetched - if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + # Use provided domain or fall back to static domain + resolved_domain = domain or self._domain + metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = self._oauth_metadata.get("authorization_endpoint", - f"https://{self._domain}/authorize") + auth_endpoint = metadata.get("authorization_endpoint", + f"https://{resolved_domain}/authorize") # Build params params = { @@ -1214,19 +1704,20 @@ async def _build_unlink_user_url( id_token: str, code_verifier: str, state: str, - authorization_params: Optional[dict[str, Any]] = None + authorization_params: Optional[dict[str, Any]] = None, + domain: Optional[str] = None ) -> str: - """Helper: Builds the authorization URL for unlinking user accounts.""" + """Build a URL for unlinking user accounts""" # Generate code challenge from verifier code_challenge = PKCE.generate_code_challenge(code_verifier) - # Get metadata if not already fetched - if not hasattr(self, '_oauth_metadata'): - self._oauth_metadata = await self._fetch_oidc_metadata(self._domain) + # Use provided domain or fall back to static domain + resolved_domain = domain or self._domain + metadata = await self._get_oidc_metadata_cached(resolved_domain) # Get authorization endpoint - auth_endpoint = self._oauth_metadata.get("authorization_endpoint", - f"https://{self._domain}/authorize") + auth_endpoint = metadata.get("authorization_endpoint", + f"https://{resolved_domain}/authorize") # Build params params = { @@ -1283,6 +1774,21 @@ async def get_access_token_for_connection( else: state_data_dict = state_data or {} + # In resolver mode, reject sessions without domain or with mismatched domain + if self._domain_resolver: + session_domain = state_data_dict.get("domain") + if not session_domain: + raise AccessTokenForConnectionError( + AccessTokenForConnectionErrorCode.MISSING_SESSION_DOMAIN, + "Session is missing domain. User needs to re-authenticate." + ) + current_domain = await self._resolve_current_domain(store_options) + if self._normalize_issuer(session_domain) != self._normalize_issuer(current_domain): + raise AccessTokenForConnectionError( + AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH, + "Session domain mismatch. User needs to re-authenticate with the current domain." + ) + # Find existing connection token connection_token_set = None if state_data_dict and len(state_data_dict["connection_token_sets"]) > 0: @@ -1302,10 +1808,13 @@ async def get_access_token_for_connection( "A refresh token was not found but is required to be able to retrieve an access token for a connection." ) # Get new token for connection + # Use session's domain for token exchange + session_domain = state_data_dict.get("domain") or self._domain token_endpoint_response = await self.get_token_for_connection({ "connection": options.get("connection"), "login_hint": options.get("login_hint"), - "refresh_token": state_data_dict["refresh_token"] + "refresh_token": state_data_dict["refresh_token"], + "domain": session_domain }) # Update state data with new token @@ -1337,11 +1846,13 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "http://auth0.com/oauth/token-type/federated-connection-access-token" GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" try: - # Ensure we have OIDC metadata - if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + # Use session domain if provided, otherwise fallback to static domain + domain = options.get("domain") or self._domain + + # Fetch OIDC metadata from the correct domain + metadata = await self._get_oidc_metadata_cached(domain) - token_endpoint = self._oauth.metadata.get("token_endpoint") + token_endpoint = metadata.get("token_endpoint") if not token_endpoint: raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") @@ -1635,7 +2146,8 @@ async def list_connected_account_connections( async def custom_token_exchange( self, - options: CustomTokenExchangeOptions + options: CustomTokenExchangeOptions, + store_options: Optional[dict[str, Any]] = None ) -> TokenExchangeResponse: """ Exchanges a custom token for Auth0 tokens using RFC 8693. @@ -1645,6 +2157,7 @@ async def custom_token_exchange( Args: options: Configuration for the token exchange + store_options: Optional options used to pass to the Transaction and State Store. Returns: TokenExchangeResponse containing access_token and metadata @@ -1674,11 +2187,11 @@ async def custom_token_exchange( if not isinstance(options, CustomTokenExchangeOptions): options = CustomTokenExchangeOptions(**options) - # Ensure we have OIDC metadata - if not hasattr(self._oauth, "metadata") or not self._oauth.metadata: - self._oauth.metadata = await self._fetch_oidc_metadata(self._domain) + # Resolve domain + domain = await self._resolve_current_domain(store_options) + metadata = await self._get_oidc_metadata_cached(domain) - token_endpoint = self._oauth.metadata.get("token_endpoint") + token_endpoint = metadata.get("token_endpoint") if not token_endpoint: raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata") @@ -1804,7 +2317,7 @@ async def login_with_custom_token_exchange( authorization_params=options.authorization_params ) - token_response = await self.custom_token_exchange(exchange_options) + token_response = await self.custom_token_exchange(exchange_options, store_options=store_options) # Extract user claims from ID token if present user_claims = None @@ -1826,12 +2339,16 @@ async def login_with_custom_token_exchange( expires_at=int(time.time()) + token_response.expires_in ) + # Resolve domain for session storage + domain = await self._resolve_current_domain(store_options) + # Construct state data state_data = StateData( user=user_claims, id_token=token_response.id_token, refresh_token=token_response.refresh_token, token_sets=[token_set], + domain=domain, internal={ "sid": sid, "created_at": int(time.time()) diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index 4b36ca3..bbd77d4 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -66,6 +66,7 @@ class SessionData(BaseModel): refresh_token: Optional[str] = None token_sets: list[TokenSet] = Field(default_factory=list) connection_token_sets: list[ConnectionTokenSet] = Field(default_factory=list) + domain: Optional[str] = None class Config: extra = "allow" # Allow additional fields not defined in the model @@ -89,6 +90,8 @@ class TransactionData(BaseModel): app_state: Optional[Any] = None auth_session: Optional[str] = None redirect_uri: Optional[str] = None + origin_domain: Optional[str] = None + origin_issuer: Optional[str] = None class Config: extra = "allow" # Allow additional fields not defined in the model @@ -213,6 +216,29 @@ class StartLinkUserOptions(BaseModel): authorization_params: Optional[dict[str, Any]] = None app_state: Optional[Any] = None +# ============================================================================= +# Multiple Custom Domain +# ============================================================================= + +class DomainResolverContext(BaseModel): + """ + Context passed to domain resolver function for MCD support. + + Contains request information needed to determine the correct Auth0 domain + based on the incoming request's hostname or headers. + + Attributes: + request_url: The full request URL (e.g., "https://a.my-app.com/auth/login") + request_headers: Dictionary of request headers (e.g., {"host": "a.my-app.com", "x-forwarded-host": "..."}) + + Example: + async def domain_resolver(context: DomainResolverContext) -> str: + host = context.request_headers.get('host', '').split(':')[0] + return DOMAIN_MAP.get(host, DEFAULT_DOMAIN) + """ + request_url: Optional[str] = None + request_headers: Optional[dict[str, str]] = None + # ============================================================================= # Custom Token Exchange Types # ============================================================================= diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index c593368..d3ff871 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -101,6 +101,17 @@ def __init__(self, argument: str): self.argument = argument +class ConfigurationError(Auth0Error): + """ + Error raised when SDK configuration is invalid. + This includes invalid combinations of parameters or incorrect configuration values. + """ + code = "configuration_error" + + def __init__(self, message: str): + super().__init__(message) + self.name = "ConfigurationError" + class InvalidArgumentError(Auth0Error): """ Error raised when a given argument is an invalid value. @@ -125,6 +136,21 @@ def __init__(self, message: str): self.name = "BackchannelLogoutError" +class DomainResolverError(Auth0Error): + """ + Error raised when domain resolver function fails or returns invalid value. + + This error indicates an issue with the custom domain resolver function + provided for MCD (Multiple Custom Domains) support. + """ + code = "domain_resolver_error" + + def __init__(self, message: str, original_error: Exception = None): + super().__init__(message) + self.name = "DomainResolverError" + self.original_error = original_error + + class AccessTokenForConnectionError(Auth0Error): """Error when retrieving access tokens for a specific connection fails.""" @@ -157,6 +183,8 @@ class AccessTokenErrorCode: REFRESH_TOKEN_ERROR = "refresh_token_error" AUTH_REQ_ID_ERROR = "auth_req_id_error" INCORRECT_AUDIENCE = "incorrect_audience" + MISSING_SESSION_DOMAIN = "missing_session_domain" + DOMAIN_MISMATCH = "domain_mismatch" class AccessTokenForConnectionErrorCode: @@ -165,6 +193,8 @@ class AccessTokenForConnectionErrorCode: FAILED_TO_RETRIEVE = "failed_to_retrieve" API_ERROR = "api_error" FETCH_ERROR = "retrieval_error" + MISSING_SESSION_DOMAIN = "missing_session_domain" + DOMAIN_MISMATCH = "domain_mismatch" class CustomTokenExchangeError(Auth0Error): diff --git a/src/auth0_server_python/store/abstract.py b/src/auth0_server_python/store/abstract.py index 45e433f..178757e 100644 --- a/src/auth0_server_python/store/abstract.py +++ b/src/auth0_server_python/store/abstract.py @@ -96,7 +96,8 @@ async def delete_by_logout_token(self, claims: dict[str, Any], options: Optional Delete sessions based on logout token claims. Args: - claims: Claims from the logout token + claims: Claims from the logout token (sub, sid, and optionally iss + in MCD mode for issuer-scoped deletion) options: Additional operation-specific options Note: diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index 4f1b90b..6c9b250 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -1,6 +1,6 @@ import json import time -from unittest.mock import ANY, AsyncMock, MagicMock +from unittest.mock import ANY, AsyncMock, MagicMock, patch from urllib.parse import parse_qs, urlparse import pytest @@ -15,18 +15,25 @@ ConnectedAccountConnection, ConnectParams, CustomTokenExchangeOptions, + DomainResolverContext, ListConnectedAccountConnectionsResponse, ListConnectedAccountsResponse, LoginWithCustomTokenExchangeOptions, LogoutOptions, + StateData, TransactionData, ) from auth0_server_python.error import ( + AccessTokenError, + AccessTokenErrorCode, AccessTokenForConnectionError, + AccessTokenForConnectionErrorCode, ApiError, BackchannelLogoutError, + ConfigurationError, CustomTokenExchangeError, CustomTokenExchangeErrorCode, + DomainResolverError, InvalidArgumentError, MissingRequiredArgumentError, MissingTransactionError, @@ -52,7 +59,7 @@ async def test_init_no_secret_raises(): @pytest.mark.asyncio -async def test_start_interactive_login_no_redirect_uri(): +async def test_start_interactive_login_no_redirect_uri(mocker): client = ServerClient( domain="auth0.local", client_id="", @@ -61,6 +68,14 @@ async def test_start_interactive_login_no_redirect_uri(): transaction_store=AsyncMock(), secret="some-secret" ) + + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://auth0.local/", "authorization_endpoint": "https://auth0.local/authorize"} + ) + with pytest.raises(MissingRequiredArgumentError) as exc: await client.start_interactive_login() # Check the error message @@ -84,7 +99,7 @@ async def test_start_interactive_login_builds_auth_url(mocker): # Mock out HTTP calls or the internal methods that create the auth URL mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) mock_oauth = mocker.patch.object( @@ -125,8 +140,13 @@ async def test_complete_interactive_login_no_transaction(): @pytest.mark.asyncio async def test_complete_interactive_login_returns_app_state(mocker): mock_tx_store = AsyncMock() - # The stored transaction includes an appState - mock_tx_store.get.return_value = TransactionData(code_verifier="123", app_state={"foo": "bar"}) + # The stored transaction includes an appState with origin_domain and origin_issuer + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + app_state={"foo": "bar"}, + origin_domain="auth0.local", + origin_issuer="https://auth0.local/" + ) mock_state_store = AsyncMock() @@ -139,6 +159,13 @@ async def test_complete_interactive_login_returns_app_state(mocker): secret="some-secret", ) + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://auth0.local/", "token_endpoint": "https://auth0.local/token"} + ) + # Patch token exchange mocker.patch.object(client._oauth, "metadata", {"token_endpoint": "https://auth0.local/token"}) @@ -214,7 +241,7 @@ async def test_complete_link_user_returns_app_state(mocker): ) # Patch token exchange - mocker.patch.object(client, "_fetch_oidc_metadata", return_value={"token_endpoint": "https://auth0.local/token"}) + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={"token_endpoint": "https://auth0.local/token"}) async_fetch_token = AsyncMock() async_fetch_token.return_value = { "access_token": "token123", @@ -226,6 +253,100 @@ async def test_complete_link_user_returns_app_state(mocker): mock_tx_store.delete.assert_awaited_once() +@pytest.mark.asyncio +async def test_start_link_user_stores_origin_domain_in_mcd(mocker): + """Test that start_link_user stores origin_domain in transaction in MCD mode.""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "id_token": "existing_id_token", + "user": {"sub": "user123"}, + "domain": "tenant1.auth0.com" + } + + captured_transaction = None + async def capture_tx(identifier, transaction_data, options=None): + nonlocal captured_transaction + captured_transaction = transaction_data + + mock_tx_store = AsyncMock() + mock_tx_store.set = AsyncMock(side_effect=capture_tx) + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": "https://tenant1.auth0.com/", + "authorization_endpoint": "https://tenant1.auth0.com/authorize", + "token_endpoint": "https://tenant1.auth0.com/oauth/token" + }) + mocker.patch.object(client, "_build_link_user_url", return_value="https://tenant1.auth0.com/authorize?...") + + await client.start_link_user( + options={"connection": "google-oauth2"}, + store_options={"request": {}} + ) + + assert captured_transaction is not None + assert captured_transaction.origin_domain == "tenant1.auth0.com" + assert captured_transaction.origin_issuer == "https://tenant1.auth0.com/" + + +@pytest.mark.asyncio +async def test_start_unlink_user_stores_origin_domain_in_mcd(mocker): + """Test that start_unlink_user stores origin_domain in transaction in MCD mode.""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "id_token": "existing_id_token", + "user": {"sub": "user123"}, + "domain": "tenant1.auth0.com" + } + + captured_transaction = None + async def capture_tx(identifier, transaction_data, options=None): + nonlocal captured_transaction + captured_transaction = transaction_data + + mock_tx_store = AsyncMock() + mock_tx_store.set = AsyncMock(side_effect=capture_tx) + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": "https://tenant1.auth0.com/", + "authorization_endpoint": "https://tenant1.auth0.com/authorize", + "token_endpoint": "https://tenant1.auth0.com/oauth/token" + }) + mocker.patch.object(client, "_build_unlink_user_url", return_value="https://tenant1.auth0.com/authorize?...") + + await client.start_unlink_user( + options={"connection": "google-oauth2"}, + store_options={"request": {}} + ) + + assert captured_transaction is not None + assert captured_transaction.origin_domain == "tenant1.auth0.com" + assert captured_transaction.origin_issuer == "https://tenant1.auth0.com/" + + @pytest.mark.asyncio async def test_login_backchannel_stores_access_token(mocker): mock_transaction_store = AsyncMock() @@ -351,6 +472,178 @@ async def test_get_session_none(): session_data = await client.get_session() assert session_data is None + +@pytest.mark.asyncio +async def test_get_session_domain_mismatch_returns_none(): + """Test that get_session returns None on domain mismatch in MCD mode.""" + session_data = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + session = await client.get_session(store_options={"request": {}}) + assert session is None + + +@pytest.mark.asyncio +async def test_get_session_domain_mismatch_with_dict_state(): + """Test domain mismatch works when state store returns plain dict (stateless cookie store).""" + session_data = { + "user": {"sub": "user123"}, + "domain": "tenant1.auth0.com", + "token_sets": [], + "internal": {"sid": "123", "created_at": int(time.time())} + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + session = await client.get_session(store_options={"request": {}}) + assert session is None + + +@pytest.mark.asyncio +async def test_get_user_domain_mismatch_with_dict_state(): + """Test domain mismatch works when state store returns plain dict (stateless cookie store).""" + session_data = { + "user": {"sub": "user123", "name": "Test User"}, + "domain": "tenant1.auth0.com", + "token_sets": [], + "internal": {"sid": "123", "created_at": int(time.time())} + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + user = await client.get_user(store_options={"request": {}}) + assert user is None + + +@pytest.mark.asyncio +async def test_get_user_legacy_session_rejected_in_resolver_mode(): + """Test that sessions without domain field are rejected in resolver mode.""" + session_data = { + "user": {"sub": "user123", "name": "Test User"}, + # No "domain" field — legacy session created before MCD was enabled + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant1.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + user = await client.get_user(store_options={"request": {}}) + assert user is None + + +@pytest.mark.asyncio +async def test_get_session_legacy_session_rejected_in_resolver_mode(): + """Test that sessions without domain field are rejected in resolver mode.""" + session_data = { + "user": {"sub": "user123"}, + "token_sets": [], + "internal": {"sid": "123", "created_at": int(time.time())} + # No "domain" field — legacy session + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant1.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + session = await client.get_session(store_options={"request": {}}) + assert session is None + + +@pytest.mark.asyncio +async def test_get_user_domain_normalization(): + """Test that domain comparison is case-insensitive and normalizes schemes.""" + session_data = { + "user": {"sub": "user123", "name": "Test User"}, + "domain": "Tenant1.Auth0.Com" # Mixed case + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant1.auth0.com" # Lowercase + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + user = await client.get_user(store_options={"request": {}}) + assert user is not None + assert user["sub"] == "user123" + + @pytest.mark.asyncio async def test_get_access_token_from_store(): mock_state_store = AsyncMock() @@ -410,7 +703,8 @@ async def test_get_access_token_refresh_expired(mocker): assert token == "new_token" mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ - "refresh_token": "refresh_xyz" + "refresh_token": "refresh_xyz", + "domain": "auth0.local" }) @pytest.mark.asyncio @@ -451,6 +745,7 @@ async def test_get_access_token_refresh_merging_default_scope(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "default", "scope": "openid profile email foo:bar" }) @@ -492,6 +787,7 @@ async def test_get_access_token_refresh_with_auth_params_scope(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "scope": "openid profile email" }) @@ -532,6 +828,7 @@ async def test_get_access_token_refresh_with_auth_params_audience(mocker): mock_state_store.set.assert_awaited_once() get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "my_audience" }) @@ -578,6 +875,7 @@ async def test_get_access_token_mrrt(mocker): assert len(stored_state["token_sets"]) == 2 get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "some_audience", "scope": "foo:bar", }) @@ -631,6 +929,7 @@ async def test_get_access_token_mrrt_with_auth_params_scope(mocker): assert len(stored_state["token_sets"]) == 2 get_refresh_token_mock.assert_awaited_with({ "refresh_token": "refresh_xyz", + "domain": "auth0.local", "audience": "some_audience", "scope": "foo:bar", }) @@ -754,6 +1053,75 @@ async def test_get_access_token_from_store_returns_minimum_matching_scopes(mocke assert token == "minimum_scope_token" get_refresh_token_mock.assert_not_awaited() + +@pytest.mark.asyncio +async def test_get_access_token_domain_mismatch_raises_error(): + """Test that get_access_token raises AccessTokenError on domain mismatch.""" + session_data = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[{ + "audience": "default", + "access_token": "token123", + "expires_at": int(time.time()) + 500 + }], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(AccessTokenError) as exc: + await client.get_access_token(store_options={"request": {}}) + assert exc.value.code == AccessTokenErrorCode.DOMAIN_MISMATCH + + +@pytest.mark.asyncio +async def test_get_access_token_domain_mismatch_with_dict_state(): + """Test domain mismatch works when state store returns plain dict (stateless cookie store).""" + session_data = { + "user": {"sub": "user123"}, + "domain": "tenant1.auth0.com", + "token_sets": [{ + "audience": "default", + "access_token": "token123", + "expires_at": int(time.time()) + 500 + }], + "internal": {"sid": "123", "created_at": int(time.time())} + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(AccessTokenError) as exc: + await client.get_access_token(store_options={"request": {}}) + assert exc.value.code == AccessTokenErrorCode.DOMAIN_MISMATCH + + @pytest.mark.asyncio async def test_get_access_token_for_connection_cached(): mock_state_store = AsyncMock() @@ -797,80 +1165,398 @@ async def test_get_access_token_for_connection_no_refresh(): await client.get_access_token_for_connection({"connection": "my_connection"}) assert "A refresh token was not found" in str(exc.value) + @pytest.mark.asyncio -async def test_logout(): +async def test_get_access_token_for_connection_domain_mismatch(): + """Test that get_access_token_for_connection raises error on domain mismatch.""" + session_data = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[], + connection_token_sets=[{ + "connection": "my_connection", + "audience": "default", + "access_token": "conn_token", + "login_hint": "hint", + "expires_at": int(time.time()) + 500 + }], + internal={"sid": "123", "created_at": int(time.time())} + ) + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" client = ServerClient( - domain="auth0.local", - client_id="client_id", - client_secret="client_secret", + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), state_store=mock_state_store, - secret="some-secret" + secret="test_secret_key_32_chars_long!!", ) - url = await client.logout(LogoutOptions(return_to="/after_logout")) - mock_state_store.delete.assert_awaited_once() - # Check returned URL - assert "auth0.local/v2/logout" in url - assert "client_id=" in url - assert "returnTo=%2Fafter_logout" in url + with pytest.raises(AccessTokenForConnectionError) as exc: + await client.get_access_token_for_connection( + {"connection": "my_connection"}, + store_options={"request": {}} + ) + assert exc.value.code == AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH + @pytest.mark.asyncio -async def test_logout_no_session(): +async def test_get_access_token_legacy_session_rejected_in_resolver_mode(): + """Test that sessions without domain field raise MISSING_SESSION_DOMAIN in resolver mode.""" + session_data = { + "user": {"sub": "user123"}, + "token_sets": [{ + "audience": "default", + "access_token": "token123", + "expires_at": int(time.time()) + 500 + }], + "internal": {"sid": "123", "created_at": int(time.time())} + # No "domain" field — legacy session + } + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant1.auth0.com" client = ServerClient( - domain="auth0.local", - client_id="client_id", - client_secret="client_secret", + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), state_store=mock_state_store, - secret="some-secret" + secret="test_secret_key_32_chars_long!!", ) - mock_state_store.delete.side_effect = None # Even if it's empty - url = await client.logout(LogoutOptions(return_to= "/bye")) + with pytest.raises(AccessTokenError) as exc: + await client.get_access_token(store_options={"request": {}}) + assert exc.value.code == AccessTokenErrorCode.MISSING_SESSION_DOMAIN - mock_state_store.delete.assert_awaited_once() # No error if already empty - assert "logout" in url @pytest.mark.asyncio -async def test_handle_backchannel_logout_no_token(): +async def test_get_access_token_for_connection_legacy_session_rejected(): + """Test that sessions without domain field raise MISSING_SESSION_DOMAIN in resolver mode.""" + session_data = { + "user": {"sub": "user123"}, + "connection_token_sets": [{ + "connection": "my_connection", + "access_token": "conn_token", + "expires_at": int(time.time()) + 500 + }], + "internal": {"sid": "123", "created_at": int(time.time())} + # No "domain" field — legacy session + } + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant1.auth0.com" + client = ServerClient( - domain="auth0.local", - client_id="client_id", - client_secret="client_secret", - secret="some-secret" + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", ) - with pytest.raises(BackchannelLogoutError) as exc: - await client.handle_backchannel_logout("") - assert "Missing logout token" in str(exc.value) + with pytest.raises(AccessTokenForConnectionError) as exc: + await client.get_access_token_for_connection( + {"connection": "my_connection"}, + store_options={"request": {}} + ) + assert exc.value.code == AccessTokenForConnectionErrorCode.MISSING_SESSION_DOMAIN + @pytest.mark.asyncio -async def test_handle_backchannel_logout_ok(mocker): +async def test_get_access_token_for_connection_domain_mismatch_with_dict_state(): + """Test domain mismatch works when state store returns plain dict (stateless cookie store).""" + session_data = { + "user": {"sub": "user123"}, + "domain": "tenant1.auth0.com", + "connection_token_sets": [{ + "connection": "my_connection", + "access_token": "conn_token", + "expires_at": int(time.time()) + 500 + }], + "internal": {"sid": "123", "created_at": int(time.time())} + } + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + async def domain_resolver(context): + return "tenant2.auth0.com" + client = ServerClient( - domain="auth0.local", - client_id="client_id", - client_secret="client_secret", + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), state_store=mock_state_store, - secret="some-secret" + secret="test_secret_key_32_chars_long!!", ) - mocker.patch("jwt.decode", return_value={ - "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, - "sub": "user_sub", - "sid": "session_id_123" - }) + with pytest.raises(AccessTokenForConnectionError) as exc: + await client.get_access_token_for_connection( + {"connection": "my_connection"}, + store_options={"request": {}} + ) + assert exc.value.code == AccessTokenForConnectionErrorCode.DOMAIN_MISMATCH - await client.handle_backchannel_logout("some_logout_token") - mock_state_store.delete_by_logout_token.assert_awaited_once_with( - {"sub": "user_sub", "sid": "session_id_123"}, - None + +@pytest.mark.asyncio +async def test_start_link_user_rejects_legacy_session_in_resolver_mode(mocker): + """Test that start_link_user rejects sessions without domain in resolver mode.""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "id_token": "existing_id_token", + "user": {"sub": "user123"} + # No "domain" field — legacy session + } + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", ) -# Test For AuthLib Helpers + with pytest.raises(StartLinkUserError): + await client.start_link_user( + options={"connection": "google-oauth2"}, + store_options={"request": {}} + ) + + +@pytest.mark.asyncio +async def test_start_unlink_user_rejects_legacy_session_in_resolver_mode(mocker): + """Test that start_unlink_user rejects sessions without domain in resolver mode.""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "id_token": "existing_id_token", + "user": {"sub": "user123"} + # No "domain" field — legacy session + } + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(StartLinkUserError): + await client.start_unlink_user( + options={"connection": "google-oauth2"}, + store_options={"request": {}} + ) + + +@pytest.mark.asyncio +async def test_logout(): + mock_state_store = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + state_store=mock_state_store, + secret="some-secret" + ) + url = await client.logout(LogoutOptions(return_to="/after_logout")) + + mock_state_store.delete.assert_awaited_once() + # Check returned URL + assert "auth0.local/v2/logout" in url + assert "client_id=" in url + assert "returnTo=%2Fafter_logout" in url + +@pytest.mark.asyncio +async def test_logout_no_session(): + mock_state_store = AsyncMock() + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + state_store=mock_state_store, + secret="some-secret" + ) + mock_state_store.delete.side_effect = None # Even if it's empty + + url = await client.logout(LogoutOptions(return_to= "/bye")) + + mock_state_store.delete.assert_awaited_once() # No error if already empty + assert "logout" in url + +@pytest.mark.asyncio +async def test_handle_backchannel_logout_no_token(): + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + secret="some-secret" + ) + + with pytest.raises(BackchannelLogoutError) as exc: + await client.handle_backchannel_logout("") + assert "Missing logout token" in str(exc.value) + +@pytest.mark.asyncio +async def test_handle_backchannel_logout_ok(mocker): + mock_state_store = AsyncMock() + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + state_store=mock_state_store, + secret="some-secret" + ) + + # Mock JWKS fetch to prevent network call + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + mock_jwt_decode = mocker.patch("jwt.decode", return_value={ + "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, + "iss": "https://auth0.local", + "sub": "user_sub", + "sid": "session_id_123" + }) + + await client.handle_backchannel_logout("some_logout_token") + + # Verify audience is passed to jwt.decode + call_kwargs = mock_jwt_decode.call_args[1] + assert call_kwargs["audience"] == "client_id" + + # In static mode, iss should NOT be included + mock_state_store.delete_by_logout_token.assert_awaited_once_with( + {"sub": "user_sub", "sid": "session_id_123"}, + None + ) + + +@pytest.mark.asyncio +async def test_backchannel_logout_mcd_known_domain(mocker): + """Test that backchannel logout works in MCD mode when domain is in cache.""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Pre-populate the discovery cache (simulates a prior login) + client._discovery_cache["tenant1.auth0.com"] = { + "metadata": { + "issuer": "https://tenant1.auth0.com/", + "jwks_uri": "https://tenant1.auth0.com/.well-known/jwks.json" + }, + "jwks": {"keys": [{"kty": "RSA", "kid": "test-key"}]}, + "expires_at": time.time() + 600 + } + + # Mock the unverified decode to extract issuer + mocker.patch("jwt.decode", return_value={ + "iss": "https://tenant1.auth0.com/", + "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, + "sub": "user123", + "sid": "session123" + }) + + mocker.patch.object( + client, "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + mock_verify = mocker.patch.object(client, "_verify_and_decode_jwt", return_value={ + "iss": "https://tenant1.auth0.com/", + "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, + "sub": "user123", + "sid": "session123" + }) + + await client.handle_backchannel_logout("some_logout_token") + + # Verify audience is passed to JWT verification + mock_verify.assert_awaited_once() + call_kwargs = mock_verify.call_args[1] + assert call_kwargs["audience"] == "test_client" + + # In resolver mode, iss should be included for issuer-scoped deletion + mock_state_store.delete_by_logout_token.assert_awaited_once_with( + {"sub": "user123", "sid": "session123", "iss": "https://tenant1.auth0.com/"}, + None + ) + + +@pytest.mark.asyncio +async def test_backchannel_logout_mcd_unknown_domain_rejected(mocker): + """Test that backchannel logout rejects unknown domains in MCD mode (SSRF protection).""" + async def domain_resolver(context): + return "tenant1.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + state_store=AsyncMock(), + secret="test_secret_key_32_chars_long!!", + ) + + # Discovery cache is empty — no prior logins + assert len(client._discovery_cache) == 0 + + # Mock unverified decode — attacker's token has evil issuer + mocker.patch("jwt.decode", return_value={ + "iss": "https://evil.internal.server/", + "events": {"http://schemas.openid.net/event/backchannel-logout": {}}, + "sub": "user123", + "sid": "session123" + }) + + with pytest.raises(BackchannelLogoutError) as exc: + await client.handle_backchannel_logout("crafted_logout_token") + assert "Unknown domain" in str(exc.value) + assert "evil.internal.server" in str(exc.value) + + +# Test For AuthLib Helpers @pytest.mark.asyncio async def test_build_link_user_url_success(mocker): @@ -884,7 +1570,7 @@ async def test_build_link_user_url_success(mocker): # Patch _fetch_oidc_metadata to return an authorization_endpoint mock_fetch = mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -942,7 +1628,7 @@ async def test_build_link_user_url_fallback_authorize(mocker): # Patch _fetch_oidc_metadata to NOT have an authorization_endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={} # empty dict, triggers fallback ) @@ -979,7 +1665,7 @@ async def test_build_unlink_user_url_success(mocker): # Patch out metadata mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -1012,7 +1698,7 @@ async def test_build_unlink_user_url_fallback_authorize(mocker): ) # No 'authorization_endpoint' - mocker.patch.object(client, "_fetch_oidc_metadata", return_value={}) + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={}) result_url = await client._build_unlink_user_url( connection="", @@ -1043,7 +1729,7 @@ async def test_build_unlink_user_url_with_metadata(mocker): # Patch the metadata fetch to include a valid authorization endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={"authorization_endpoint": "https://auth0.local/authorize"} ) @@ -1096,7 +1782,7 @@ async def test_build_unlink_user_url_no_authorization_endpoint(mocker): # Patch _fetch_oidc_metadata to return no authorization_endpoint mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={} ) result_url = await client._build_unlink_user_url( @@ -1127,7 +1813,7 @@ async def test_backchannel_auth_with_audience_and_binding_message(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1176,7 +1862,7 @@ async def test_backchannel_auth_rar(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1227,7 +1913,7 @@ async def test_backchannel_auth_token_exchange_failed(mocker): mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/custom-authorize", @@ -1277,7 +1963,7 @@ async def test_initiate_backchannel_authentication_success(mocker): # Mock OIDC metadata mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/backchannel" @@ -1325,7 +2011,7 @@ async def test_initiate_backchannel_authentication_error_response(mocker): ) mocker.patch.object( client, - "_fetch_oidc_metadata", + "_get_oidc_metadata_cached", return_value={ "issuer": "https://auth0.local/", "backchannel_authentication_endpoint": "https://auth0.local/backchannel" @@ -1373,7 +2059,11 @@ async def test_backchannel_authentication_grant_success(mocker): secret="some-secret" ) # Mock OIDC metadata - client._oauth.metadata = {"token_endpoint": "https://auth0.local/token"} + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} + ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() @@ -1407,7 +2097,11 @@ async def test_backchannel_authentication_grant_error_response(mocker): client_secret="client_secret", secret="some-secret" ) - client._oauth.metadata = {"token_endpoint": "https://auth0.local/token"} + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} + ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) mock_response = AsyncMock() @@ -1434,7 +2128,11 @@ async def test_backchannel_authentication_grant_json_decode_error(mocker): client_secret="client_secret", secret="some-secret" ) - client._oauth.metadata = {"token_endpoint": "https://auth0.local/token"} + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} + ) # Mock httpx.AsyncClient.post to return a response whose .json() raises JSONDecodeError mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) @@ -1459,9 +2157,9 @@ async def test_get_token_for_connection_success(mocker): ) mocker.patch.object( - client._oauth, - "metadata", - {"token_endpoint": "https://auth0.local/token"} + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) @@ -1506,9 +2204,9 @@ async def test_get_token_for_connection_exchange_failed(mocker): ) mocker.patch.object( - client._oauth, - "metadata", - {"token_endpoint": "https://auth0.local/token"} + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) @@ -1544,9 +2242,9 @@ async def test_get_token_by_refresh_token_success(mocker): ) mocker.patch.object( - client._oauth, - "metadata", - {"token_endpoint": "https://auth0.local/token"} + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) @@ -1588,9 +2286,9 @@ async def test_get_token_by_refresh_token_exchange_failed(mocker): ) mocker.patch.object( - client._oauth, - "metadata", - {"token_endpoint": "https://auth0.local/token"} + client, + "_get_oidc_metadata_cached", + return_value={"token_endpoint": "https://auth0.local/token"} ) mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=AsyncMock) @@ -2150,7 +2848,6 @@ async def test_list_connected_account_connections_with_invalid_take_param(mocker assert "The 'take' parameter must be a positive integer." in str(exc.value) mock_my_account_client.list_connected_account_connections.assert_not_awaited() - # ============================================================================= # Custom Token Exchange Tests # ============================================================================= @@ -2177,7 +2874,6 @@ async def test_custom_token_exchange_success(mocker): "_fetch_oidc_metadata", return_value={"token_endpoint": "https://auth0.local/oauth/token"} ) - # Mock httpx response mock_response = MagicMock() mock_response.status_code = 200 @@ -2824,3 +3520,1063 @@ async def test_login_with_custom_token_exchange_failure_propagates(mocker): ) ) assert exc.value.code == "unauthorized" + +# ============================================================================= +# OIDC Metadata and JWKS Fetching Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_fetch_jwks_success(): + """Test successful JWKS fetch from URI.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_jwks = { + "keys": [ + { + "kty": "RSA", + "use": "sig", + "kid": "test-key-id", + "n": "test-modulus", + "e": "AQAB" + } + ] + } + + # Mock httpx client + mock_response = MagicMock() + mock_response.json.return_value = mock_jwks + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = mock_response + + with patch('httpx.AsyncClient', return_value=mock_client): + jwks = await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") + + assert jwks == mock_jwks + assert "keys" in jwks + mock_client.get.assert_awaited_once_with("https://tenant.auth0.com/.well-known/jwks.json") + + +@pytest.mark.asyncio +async def test_fetch_jwks_failure(): + """Test JWKS fetch failure raises ApiError.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Mock httpx client to raise exception + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.side_effect = Exception("Network error") + + with patch('httpx.AsyncClient', return_value=mock_client): + with pytest.raises(ApiError, match="Failed to fetch JWKS"): + await client._fetch_jwks("https://tenant.auth0.com/.well-known/jwks.json") + + +@pytest.mark.asyncio +async def test_oidc_metadata_caching(): + """Test OIDC metadata is cached and reused.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "authorization_endpoint": "https://tenant.auth0.com/authorize", + "token_endpoint": "https://tenant.auth0.com/oauth/token", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + # Mock _fetch_oidc_metadata to track calls + fetch_count = 0 + async def mock_fetch(domain): + nonlocal fetch_count + fetch_count += 1 + return mock_metadata + + client._fetch_oidc_metadata = mock_fetch + + # First call - should fetch + result1 = await client._get_oidc_metadata_cached("tenant.auth0.com") + assert result1 == mock_metadata + assert fetch_count == 1 + first_fetch_count = fetch_count + + # Second call - should use cache + result2 = await client._get_oidc_metadata_cached("tenant.auth0.com") + assert result2 == mock_metadata + assert fetch_count == first_fetch_count # Should NOT increment + + # Verify cache contains data + assert "tenant.auth0.com" in client._discovery_cache + assert client._discovery_cache["tenant.auth0.com"]["metadata"] == mock_metadata + + +@pytest.mark.asyncio +async def test_oidc_metadata_cache_expiration(): + """Test OIDC metadata cache expires after TTL.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Set short TTL for testing + client._cache_ttl = 1 # 1 second + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + fetch_count = 0 + async def mock_fetch(domain): + nonlocal fetch_count + fetch_count += 1 + return mock_metadata + + client._fetch_oidc_metadata = mock_fetch + + # First call + await client._get_oidc_metadata_cached("tenant.auth0.com") + assert fetch_count == 1 + + # Wait for cache to expire + time.sleep(1.1) + + # Second call after expiration - should fetch again + await client._get_oidc_metadata_cached("tenant.auth0.com") + assert fetch_count == 2 + + +@pytest.mark.asyncio +async def test_jwks_caching(): + """Test JWKS is cached and reused.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mock_metadata = { + "issuer": "https://tenant.auth0.com/", + "jwks_uri": "https://tenant.auth0.com/.well-known/jwks.json" + } + + mock_jwks = { + "keys": [{"kty": "RSA", "kid": "key1"}] + } + + # Mock the fetch methods + client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata) + + fetch_count = 0 + async def mock_fetch_jwks(uri): + nonlocal fetch_count + fetch_count += 1 + return mock_jwks + + client._fetch_jwks = mock_fetch_jwks + + # First call - should fetch + result1 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) + assert result1 == mock_jwks + assert fetch_count == 1 + first_fetch_count = fetch_count + + # Second call - should use cache + result2 = await client._get_jwks_cached("tenant.auth0.com", mock_metadata) + assert result2 == mock_jwks + assert fetch_count == first_fetch_count # Should NOT increment + + +@pytest.mark.asyncio +async def test_jwks_cache_size_limit(): + """Test JWKS cache enforces max size limit with LRU eviction.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Set small cache size for testing + client._cache_max_entries = 3 + + mock_jwks = {"keys": [{"kty": "RSA"}]} + + # Mock methods + async def mock_fetch_metadata(domain): + return {"jwks_uri": f"https://{domain}/.well-known/jwks.json"} + + async def mock_fetch_jwks(uri): + return mock_jwks + + client._fetch_oidc_metadata = mock_fetch_metadata + client._fetch_jwks = mock_fetch_jwks + + # Fill cache to limit + await client._get_jwks_cached("domain1.auth0.com") + await client._get_jwks_cached("domain2.auth0.com") + await client._get_jwks_cached("domain3.auth0.com") + + assert len(client._discovery_cache) == 3 + assert "domain1.auth0.com" in client._discovery_cache + + # Add one more - should evict oldest (domain1) + await client._get_jwks_cached("domain4.auth0.com") + + assert len(client._discovery_cache) == 3 + assert "domain1.auth0.com" not in client._discovery_cache # Evicted + assert "domain4.auth0.com" in client._discovery_cache + + +@pytest.mark.asyncio +async def test_jwks_missing_uri_raises_error(): + """Test that missing jwks_uri in metadata raises ApiError.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Metadata WITHOUT jwks_uri + mock_metadata_no_jwks_uri = { + "issuer": "https://tenant.auth0.com/", + "authorization_endpoint": "https://tenant.auth0.com/authorize" + # No jwks_uri + } + + client._get_oidc_metadata_cached = AsyncMock(return_value=mock_metadata_no_jwks_uri) + + # Should raise ApiError when jwks_uri is missing + with pytest.raises(ApiError) as exc_info: + await client._get_jwks_cached("tenant.auth0.com") + + assert exc_info.value.code == "missing_jwks_uri" + assert "non-RFC-compliant" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_metadata_cache_size_limit(): + """Test OIDC metadata cache enforces max size limit.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + client._cache_max_entries = 2 + + async def mock_fetch(domain): + return {"issuer": f"https://{domain}/"} + + client._fetch_oidc_metadata = mock_fetch + + # Fill cache + await client._get_oidc_metadata_cached("domain1.auth0.com") + await client._get_oidc_metadata_cached("domain2.auth0.com") + + assert len(client._discovery_cache) == 2 + + # Add third - should evict first + await client._get_oidc_metadata_cached("domain3.auth0.com") + + assert len(client._discovery_cache) == 2 + assert "domain1.auth0.com" not in client._discovery_cache + assert "domain3.auth0.com" in client._discovery_cache + + +# ============================================================================= +# Issuer Validation Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_complete_login_issuer_validation_success(mocker): + """Test complete login with valid issuer in ID token.""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant.auth0.com", + origin_issuer="https://tenant.auth0.com/" + ) + + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode with valid issuer + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://tenant.auth0.com/", # Matches origin_issuer + "aud": "test_client" + }) + + # Should succeed without raising error + result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert result is not None + assert "state_data" in result + + +@pytest.mark.asyncio +async def test_complete_login_issuer_mismatch_raises_error(mocker): + """Test that issuer mismatch in ID token raises ApiError.""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant.auth0.com", + origin_issuer="https://tenant.auth0.com/" + ) + + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode to return claims with a WRONG issuer + # Our custom normalized issuer validation should catch this mismatch + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://wrong-issuer.auth0.com/", # Different from expected: https://tenant.auth0.com/ + "aud": "test_client", + "exp": 9999999999 + }) + + # Should raise ApiError with invalid_issuer code + with pytest.raises(ApiError) as exc_info: + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert exc_info.value.code == "invalid_issuer" + assert "issuer mismatch" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_normalize_domain_handles_different_schemes(): + """Test that _normalize_domain handles various URL schemes correctly.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Test domain without scheme + assert client._normalize_domain("auth0.com") == "https://auth0.com" + + # Test domain with https scheme (should remain unchanged) + assert client._normalize_domain("https://auth0.com") == "https://auth0.com" + + # Test domain with http scheme (should convert to https) + assert client._normalize_domain("http://auth0.com") == "https://auth0.com" + + # Test domain with trailing slash + assert client._normalize_domain("https://auth0.com/") == "https://auth0.com/" + + +@pytest.mark.asyncio +async def test_normalize_issuer_handles_edge_cases(): + """Test that _normalize_issuer handles edge cases for robust issuer comparison. + + This test documents the edge cases that could cause issuer validation failures + with PyJWT's strict string comparison: + - Trailing slash differences + - Case sensitivity + - HTTP vs HTTPS schemes + - Missing scheme + """ + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Test trailing slash normalization + assert client._normalize_issuer("https://auth0.com/") == "https://auth0.com" + assert client._normalize_issuer("https://auth0.com") == "https://auth0.com" + assert client._normalize_issuer("https://auth0.com/") == client._normalize_issuer("https://auth0.com") + + # Test case insensitivity + assert client._normalize_issuer("HTTPS://AUTH0.COM/") == "https://auth0.com" + assert client._normalize_issuer("Https://Auth0.Com") == "https://auth0.com" + assert client._normalize_issuer("HTTPS://AUTH0.COM/") == client._normalize_issuer("https://auth0.com") + + # Test HTTP to HTTPS conversion + assert client._normalize_issuer("http://auth0.com") == "https://auth0.com" + assert client._normalize_issuer("HTTP://AUTH0.COM/") == "https://auth0.com" + + # Test missing scheme + assert client._normalize_issuer("auth0.com") == "https://auth0.com" + assert client._normalize_issuer("AUTH0.COM/") == "https://auth0.com" + + # Test empty/None handling + assert client._normalize_issuer("") == "" + assert client._normalize_issuer(None) is None + + +# ============================================================================= +# MCD Tests : Multiple Issuer Configuration Methods Tests +# ============================================================================= + +@pytest.mark.asyncio +async def test_domain_as_static_string(): + """Test Method 1: Static domain string configuration.""" + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + assert client._domain == "tenant.auth0.com" + assert client._domain_resolver is None + + +@pytest.mark.asyncio +async def test_domain_as_callable_function(): + """Test Method 2: Domain resolver function configuration.""" + async def domain_resolver(store_options): + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + assert client._domain is None + assert client._domain_resolver == domain_resolver + + +@pytest.mark.asyncio +async def test_missing_domain_raises_configuration_error(): + """Test that missing domain parameter raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="Domain is required"): + ServerClient( + domain=None, + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +@pytest.mark.asyncio +async def test_invalid_domain_type_list(): + """Test that list domain raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="must be either a string or a callable"): + ServerClient( + domain=["tenant.auth0.com"], + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +@pytest.mark.asyncio +async def test_empty_domain_string(): + """Test that empty domain string raises ConfigurationError.""" + with pytest.raises(ConfigurationError, match="Domain cannot be empty"): + ServerClient( + domain="", + client_id="test_client_id", + client_secret="test_client_secret", + secret="test_secret_key_32_chars_long!!" + ) + + +# ============================================================================= +# MCD Tests : Domain Resolver Context Tests +# ============================================================================= + +@pytest.mark.asyncio +async def test_domain_resolver_receives_context(mocker): + """Test that domain resolver receives DomainResolverContext with request data.""" + received_context = None + + async def domain_resolver(context): + nonlocal received_context + received_context = context + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + # Mock request with headers + mock_request = MagicMock() + mock_request.url = "https://a.my-app.com/auth/login" + mock_request.headers = {"host": "a.my-app.com", "x-forwarded-host": "a.my-app.com"} + + # Mock OIDC metadata fetch + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} + ) + + try: + await client.start_interactive_login(store_options={"request": mock_request}) + except Exception: # noqa: S110 + pass # We only care about context being passed + + assert received_context is not None + assert isinstance(received_context, DomainResolverContext) + assert received_context.request_url == "https://a.my-app.com/auth/login" + assert received_context.request_headers.get("host") == "a.my-app.com" + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_none(): + """Test that domain resolver returning None raises DomainResolverError.""" + async def bad_resolver(context): + return None + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="returned None"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_empty_string(): + """Test that domain resolver returning empty string raises DomainResolverError.""" + async def bad_resolver(context): + return "" + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="empty string"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_exception(): + """Test that domain resolver exceptions are wrapped in DomainResolverError.""" + async def bad_resolver(context): + raise ValueError("Something went wrong") + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="raised an exception"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_domain_resolver_with_no_request(mocker): + """Test that domain resolver works with empty context when no request.""" + received_context = None + + async def domain_resolver(context): + nonlocal received_context + received_context = context + return "tenant.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "authorization_endpoint": "https://tenant.auth0.com/authorize"} + ) + + try: + await client.start_interactive_login(store_options=None) + except Exception: # noqa: S110 + pass # We only care about context being passed + assert received_context is not None + assert received_context.request_url is None + assert received_context.request_headers is None + + +@pytest.mark.asyncio +async def test_domain_resolver_error_on_non_string_type(): + """Test that domain resolver returning non-string raises DomainResolverError.""" + async def bad_resolver(context): + return 12345 + + client = ServerClient( + domain=bad_resolver, + client_id="test_client", + client_secret="test_secret", + secret="test_secret_key_32_chars_long!!", + transaction_store=AsyncMock(), + state_store=AsyncMock() + ) + + with pytest.raises(DomainResolverError, match="must return a string"): + await client.start_interactive_login(store_options={"request": MagicMock()}) + + +@pytest.mark.asyncio +async def test_sync_callable_as_domain_resolver_raises_error(): + """Test that a non-async (sync) callable raises DomainResolverError. + + The SDK always awaits the resolver, so sync callables are not supported. + Domain resolvers must be async functions. + """ + def sync_resolver(context): + return "tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + client = ServerClient( + domain=sync_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + with pytest.raises(DomainResolverError): + await client.get_user(store_options={"request": {}}) + + +@pytest.mark.asyncio +async def test_resolver_returns_domain_with_scheme_prefix(): + """Test that domain resolver returning 'https://domain' works with session matching.""" + async def resolver_with_scheme(context): + return "https://tenant1.auth0.com" + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = StateData( + user={"sub": "user123"}, + domain="https://tenant1.auth0.com", + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + client = ServerClient( + domain=resolver_with_scheme, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + user = await client.get_user(store_options={"request": {}}) + assert user is not None + assert user["sub"] == "user123" + + +# ============================================================================= +# MCD Tests : Domain-specific Session Management Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_session_stores_origin_domain(mocker): + """Test that session stores origin domain from login (Requirement 5).""" + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain="tenant1.auth0.com", + origin_issuer="https://tenant1.auth0.com/" + ) + + captured_state = None + async def capture_state(identifier, state_data, options=None): + nonlocal captured_state + captured_state = state_data + + mock_state_store = AsyncMock() + mock_state_store.set = AsyncMock(side_effect=capture_state) + + client = ServerClient( + domain="tenant1.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": "https://tenant1.auth0.com/", + "token_endpoint": "https://tenant1.auth0.com/token" + }) + mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) + + async_fetch_token = AsyncMock(return_value={ + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid" + }) + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": "https://tenant1.auth0.com/"}) + + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + # Verify session has domain field set + assert captured_state.domain == "tenant1.auth0.com" + + +@pytest.mark.asyncio +async def test_cross_domain_session_rejected(): + """Test that session from domain1 cannot be used with domain2 (Requirement 5).""" + # Create session with domain1 + session_data = StateData( + user={"sub": "user123"}, + domain="tenant1.auth0.com", + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns domain2 (different from session) + async def domain_resolver(context): + return "tenant2.auth0.com" + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # get_user should return None (session rejected) + user = await client.get_user(store_options={"request": {}}) + assert user is None + + +@pytest.mark.asyncio +async def test_logout_uses_current_domain(mocker): + """Test that logout uses current resolved domain (Requirement 7).""" + current_domain = "tenant2.auth0.com" + + async def domain_resolver(context): + return current_domain + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = {"domain": current_domain, "user": {"sub": "user1"}} + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + logout_url = await client.logout(store_options={"request": {}}) + + # Verify logout URL uses current domain + assert current_domain in logout_url + assert logout_url.startswith(f"https://{current_domain}") + # Verify session was deleted (domains match) + mock_state_store.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_logout_clears_session_for_current_domain(): + """Test that logout clears session (Requirement 7).""" + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + await client.logout() + + # Verify session was deleted + mock_state_store.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_domain_migration_old_sessions_remain_valid(): + """Test that old sessions remain valid with old domain requests (Requirement 8).""" + old_domain = "old-tenant.auth0.com" + + # Session from old domain + session_data = StateData( + user={"sub": "user123"}, + domain=old_domain, + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns old domain + async def domain_resolver(context): + return old_domain + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Should successfully retrieve user + user = await client.get_user(store_options={"request": {}}) + assert user is not None + assert user["sub"] == "user123" + + +@pytest.mark.asyncio +async def test_domain_migration_new_sessions_use_new_domain(mocker): + """Test that new logins create sessions with new domain (Requirement 8).""" + new_domain = "new-tenant.auth0.com" + + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + origin_domain=new_domain, + origin_issuer=f"https://{new_domain}/" + ) + + captured_state = None + async def capture_state(identifier, state_data, options=None): + nonlocal captured_state + captured_state = state_data + + mock_state_store = AsyncMock() + mock_state_store.set = AsyncMock(side_effect=capture_state) + + client = ServerClient( + domain=new_domain, + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object(client, "_get_oidc_metadata_cached", return_value={ + "issuer": f"https://{new_domain}/", + "token_endpoint": f"https://{new_domain}/token" + }) + mocker.patch.object(client, "_get_jwks_cached", return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]}) + + async_fetch_token = AsyncMock(return_value={ + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid" + }) + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock JWT verification + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + mocker.patch("jwt.decode", return_value={"sub": "user123", "iss": f"https://{new_domain}/"}) + + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + # Verify new session has new domain + assert captured_state.domain == new_domain + + +@pytest.mark.asyncio +async def test_domain_migration_sessions_isolated(): + """Test that old domain sessions cannot be used with new domain (Requirement 8).""" + old_domain = "old-tenant.auth0.com" + new_domain = "new-tenant.auth0.com" + + # Session from old domain + session_data = StateData( + user={"sub": "user123"}, + domain=old_domain, + token_sets=[], + internal={"sid": "123", "created_at": int(time.time())} + ) + + mock_state_store = AsyncMock() + mock_state_store.get.return_value = session_data + + # Domain resolver returns NEW domain (migration happened) + async def domain_resolver(context): + return new_domain + + client = ServerClient( + domain=domain_resolver, + client_id="test_client", + client_secret="test_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Should reject old session + user = await client.get_user(store_options={"request": {}}) + assert user is None diff --git a/src/auth0_server_python/utils/helpers.py b/src/auth0_server_python/utils/helpers.py index c57ab18..05cb0f8 100644 --- a/src/auth0_server_python/utils/helpers.py +++ b/src/auth0_server_python/utils/helpers.py @@ -6,6 +6,9 @@ from typing import Any, Optional from urllib.parse import parse_qs, urlencode, urlparse +from auth0_server_python.auth_types import DomainResolverContext +from auth0_server_python.error import DomainResolverError + class PKCE: @classmethod @@ -224,3 +227,69 @@ def create_logout_url(domain: str, client_id: str, return_to: Optional[str] = No if return_to: params["returnTo"] = return_to return URL.build_url(base_url, params) + + +# ============================================================================= +# Domain Resolver Utilities +# ============================================================================= + +def build_domain_resolver_context(store_options: Optional[dict[str, Any]]) -> 'DomainResolverContext': + """ + Build DomainResolverContext from store_options. + + Extracts request information in a framework-agnostic way using duck typing. + + Args: + store_options: Dictionary containing 'request' and 'response' objects + + Returns: + DomainResolverContext with extracted request data + """ + + if not store_options: + return DomainResolverContext() + + request = store_options.get('request') + if not request: + return DomainResolverContext() + + # Framework-agnostic extraction using duck typing + request_url = str(request.url) if hasattr(request, 'url') else None + request_headers = dict(request.headers) if hasattr(request, 'headers') else None + + return DomainResolverContext( + request_url=request_url, + request_headers=request_headers + ) + + +def validate_resolved_domain_value(domain_value: Any) -> str: + """ + Validate the value returned by domain resolver. + + Args: + domain_value: The value returned by the domain resolver + + Returns: + The validated domain string + + Raises: + DomainResolverError: If the returned value is invalid + """ + + if domain_value is None: + raise DomainResolverError( + "Domain resolver returned None. Must return a valid domain string." + ) + + if not isinstance(domain_value, str): + raise DomainResolverError( + f"Domain resolver must return a string. Got {type(domain_value).__name__} instead." + ) + + if not domain_value.strip(): + raise DomainResolverError( + "Domain resolver returned an empty string. Must return a valid domain." + ) + + return domain_value