Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/mcp/server/auth/provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Generic, Literal, Protocol, TypeVar
from typing import Any, Generic, Literal, Protocol, TypeVar
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from pydantic import AnyUrl, BaseModel
Expand Down Expand Up @@ -40,6 +40,8 @@ class AccessToken(BaseModel):
scopes: list[str]
expires_at: int | None = None
resource: str | None = None # RFC 8707 resource indicator
subject: str | None = None # JWT sub claim (user ID)
claims: dict[str, Any] | None = None # Additional JWT claims beyond reserved fields


RegistrationErrorCode = Literal[
Expand Down
21 changes: 21 additions & 0 deletions src/mcp/server/mcpserver/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pydantic import AnyUrl, BaseModel

from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext
from mcp.server.elicitation import (
ElicitationResult,
Expand Down Expand Up @@ -218,6 +219,26 @@ def client_id(self) -> str | None:
"""Get the client ID if available."""
return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover

@property
def subject(self) -> str | None:
"""Get the authenticated user's subject (JWT sub claim), if available.

Returns the ``subject`` field from the current request's access token.
This is typically the user ID set by the token verifier when the token
is validated. Returns ``None`` when the request is unauthenticated or
the token verifier did not populate the field.

Example::

@server.tool()
async def my_tool(ctx: Context) -> str:
if ctx.subject is None:
return "unauthenticated"
return f"Hello, {ctx.subject}"
"""
token = get_access_token()
return token.subject if token else None

@property
def request_id(self) -> str:
"""Get the unique ID for this request."""
Expand Down
58 changes: 58 additions & 0 deletions tests/server/auth/middleware/test_bearer_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,64 @@ def no_expiry_access_token() -> AccessToken:
)


class TestAccessTokenFields:
"""Tests for AccessToken model fields including subject and claims."""

def test_backward_compat_without_subject_and_claims(self):
"""Existing code that omits subject/claims should still work."""
token = AccessToken(
token="tok",
client_id="client",
scopes=["read"],
)
assert token.subject is None
assert token.claims is None

def test_subject_field(self):
"""subject stores the JWT sub claim."""
token = AccessToken(
token="tok",
client_id="client",
scopes=["read"],
subject="user-123",
)
assert token.subject == "user-123"

def test_claims_field(self):
"""claims stores arbitrary additional JWT claims."""
custom_claims = {"org": "acme", "role": "admin", "tier": 2}
token = AccessToken(
token="tok",
client_id="client",
scopes=["read"],
claims=custom_claims,
)
assert token.claims == custom_claims

def test_subject_and_claims_together(self):
"""subject and claims can both be set simultaneously."""
token = AccessToken(
token="tok",
client_id="client",
scopes=["read"],
subject="user-456",
claims={"org": "acme"},
)
assert token.subject == "user-456"
assert token.claims == {"org": "acme"}

def test_subject_flows_through_authenticated_user(self):
"""AuthenticatedUser carries the subject via its access_token attribute."""
token = AccessToken(
token="tok",
client_id="client",
scopes=["read"],
subject="user-789",
)
user = AuthenticatedUser(token)
assert user.access_token.subject == "user-789"


@pytest.mark.anyio
class TestBearerAuthBackend:
"""Tests for the BearerAuthBackend class."""
Expand Down
45 changes: 45 additions & 0 deletions tests/server/mcpserver/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Tests for the mcpserver Context class."""

from mcp.server.auth.middleware.auth_context import auth_context_var
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
from mcp.server.mcpserver import Context


class TestContextSubject:
"""Tests for Context.subject property."""

def test_subject_returns_none_when_unauthenticated(self):
ctx = Context()
assert ctx.subject is None

def test_subject_returns_none_when_token_has_no_subject(self):
user = AuthenticatedUser(AccessToken(token="tok", client_id="client", scopes=["read"]))
token = auth_context_var.set(user)
try:
ctx = Context()
assert ctx.subject is None
finally:
auth_context_var.reset(token)

def test_subject_returns_value_from_access_token(self):
user = AuthenticatedUser(AccessToken(token="tok", client_id="client", scopes=["read"], subject="user-123"))
token = auth_context_var.set(user)
try:
ctx = Context()
assert ctx.subject == "user-123"
finally:
auth_context_var.reset(token)

def test_subject_reflects_current_context(self):
ctx = Context()
assert ctx.subject is None

user = AuthenticatedUser(AccessToken(token="a", client_id="c", scopes=[], subject="alice"))
cv_token = auth_context_var.set(user)
try:
assert ctx.subject == "alice"
finally:
auth_context_var.reset(cv_token)

assert ctx.subject is None