diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 957082a85..a5cd9feca 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Generic, Literal, Protocol, TypeVar +from typing import Any, Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyUrl, BaseModel @@ -40,6 +40,8 @@ class AccessToken(BaseModel): scopes: list[str] expires_at: int | None = None resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # JWT sub claim (user ID) + claims: dict[str, Any] | None = None # Additional JWT claims beyond reserved fields RegistrationErrorCode = Literal[ diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 1538adc7c..3da3388d6 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -5,6 +5,7 @@ from pydantic import AnyUrl, BaseModel +from mcp.server.auth.middleware.auth_context import get_access_token from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext from mcp.server.elicitation import ( ElicitationResult, @@ -218,6 +219,26 @@ def client_id(self) -> str | None: """Get the client ID if available.""" return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover + @property + def subject(self) -> str | None: + """Get the authenticated user's subject (JWT sub claim), if available. + + Returns the ``subject`` field from the current request's access token. + This is typically the user ID set by the token verifier when the token + is validated. Returns ``None`` when the request is unauthenticated or + the token verifier did not populate the field. + + Example:: + + @server.tool() + async def my_tool(ctx: Context) -> str: + if ctx.subject is None: + return "unauthenticated" + return f"Hello, {ctx.subject}" + """ + token = get_access_token() + return token.subject if token else None + @property def request_id(self) -> str: """Get the unique ID for this request.""" diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index bd14e294c..3e471a8a1 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -102,6 +102,64 @@ def no_expiry_access_token() -> AccessToken: ) +class TestAccessTokenFields: + """Tests for AccessToken model fields including subject and claims.""" + + def test_backward_compat_without_subject_and_claims(self): + """Existing code that omits subject/claims should still work.""" + token = AccessToken( + token="tok", + client_id="client", + scopes=["read"], + ) + assert token.subject is None + assert token.claims is None + + def test_subject_field(self): + """subject stores the JWT sub claim.""" + token = AccessToken( + token="tok", + client_id="client", + scopes=["read"], + subject="user-123", + ) + assert token.subject == "user-123" + + def test_claims_field(self): + """claims stores arbitrary additional JWT claims.""" + custom_claims = {"org": "acme", "role": "admin", "tier": 2} + token = AccessToken( + token="tok", + client_id="client", + scopes=["read"], + claims=custom_claims, + ) + assert token.claims == custom_claims + + def test_subject_and_claims_together(self): + """subject and claims can both be set simultaneously.""" + token = AccessToken( + token="tok", + client_id="client", + scopes=["read"], + subject="user-456", + claims={"org": "acme"}, + ) + assert token.subject == "user-456" + assert token.claims == {"org": "acme"} + + def test_subject_flows_through_authenticated_user(self): + """AuthenticatedUser carries the subject via its access_token attribute.""" + token = AccessToken( + token="tok", + client_id="client", + scopes=["read"], + subject="user-789", + ) + user = AuthenticatedUser(token) + assert user.access_token.subject == "user-789" + + @pytest.mark.anyio class TestBearerAuthBackend: """Tests for the BearerAuthBackend class.""" diff --git a/tests/server/mcpserver/test_context.py b/tests/server/mcpserver/test_context.py new file mode 100644 index 000000000..51b79c890 --- /dev/null +++ b/tests/server/mcpserver/test_context.py @@ -0,0 +1,45 @@ +"""Tests for the mcpserver Context class.""" + +from mcp.server.auth.middleware.auth_context import auth_context_var +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken +from mcp.server.mcpserver import Context + + +class TestContextSubject: + """Tests for Context.subject property.""" + + def test_subject_returns_none_when_unauthenticated(self): + ctx = Context() + assert ctx.subject is None + + def test_subject_returns_none_when_token_has_no_subject(self): + user = AuthenticatedUser(AccessToken(token="tok", client_id="client", scopes=["read"])) + token = auth_context_var.set(user) + try: + ctx = Context() + assert ctx.subject is None + finally: + auth_context_var.reset(token) + + def test_subject_returns_value_from_access_token(self): + user = AuthenticatedUser(AccessToken(token="tok", client_id="client", scopes=["read"], subject="user-123")) + token = auth_context_var.set(user) + try: + ctx = Context() + assert ctx.subject == "user-123" + finally: + auth_context_var.reset(token) + + def test_subject_reflects_current_context(self): + ctx = Context() + assert ctx.subject is None + + user = AuthenticatedUser(AccessToken(token="a", client_id="c", scopes=[], subject="alice")) + cv_token = auth_context_var.set(user) + try: + assert ctx.subject == "alice" + finally: + auth_context_var.reset(cv_token) + + assert ctx.subject is None