Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion .github/actions/conformance/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,27 @@ async def run_client_credentials_basic(server_url: str) -> None:
async def run_auth_code_client(server_url: str) -> None:
"""Authorization code flow (default for auth/* scenarios)."""
callback_handler = ConformanceOAuthCallbackHandler()
storage = InMemoryTokenStorage()

# Check for pre-registered client credentials from context
context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT")
if context_json:
try:
context = json.loads(context_json)
client_id = context.get("client_id")
client_secret = context.get("client_secret")
if client_id:
await storage.set_client_info(
OAuthClientInformationFull(
client_id=client_id,
client_secret=client_secret,
redirect_uris=[AnyUrl("http://localhost:3000/callback")],
token_endpoint_auth_method="client_secret_basic" if client_secret else "none",
)
)
logger.debug(f"Pre-loaded client credentials: client_id={client_id}")
except json.JSONDecodeError:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dont you want to see the exception if the json is wrong?

pass

oauth_auth = OAuthClientProvider(
server_url=server_url,
Expand All @@ -284,7 +305,7 @@ async def run_auth_code_client(server_url: str) -> None:
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
),
storage=InMemoryTokenStorage(),
storage=storage,
redirect_handler=callback_handler.handle_redirect,
callback_handler=callback_handler.handle_callback,
client_metadata_url="https://conformance-test.local/client-metadata.json",
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/conformance.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ jobs:
with:
node-version: 24
- run: uv sync --frozen --all-extras --package mcp
- run: npx @modelcontextprotocol/conformance@0.1.10 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
- run: npx @modelcontextprotocol/conformance@0.1.13 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
29 changes: 29 additions & 0 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def __init__(
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
timeout: float = 300.0,
client_metadata_url: str | None = None,
validate_resource_url: Callable[[str, str | None], Awaitable[str | None]] | None = None,
):
"""Initialize OAuth2 authentication.

Expand All @@ -243,6 +244,11 @@ def __init__(
advertises client_id_metadata_document_supported=true, this URL will be
used as the client_id instead of performing dynamic client registration.
Must be a valid HTTPS URL with a non-root pathname.
validate_resource_url: Optional callback to override resource URL validation.
Called with (server_url, prm_resource) where prm_resource is the resource
from Protected Resource Metadata (or None if not present). Must return the
resource URL to use, or None to omit it. If not provided, default validation
rejects mismatched resources per RFC 8707.

Raises:
ValueError: If client_metadata_url is provided but not a valid HTTPS URL
Expand All @@ -263,6 +269,7 @@ def __init__(
timeout=timeout,
client_metadata_url=client_metadata_url,
)
self._validate_resource_url_callback = validate_resource_url
self._initialized = False

async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
Expand Down Expand Up @@ -476,6 +483,26 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata

async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None:
"""Validate that PRM resource matches the server URL per RFC 8707."""
prm_resource = str(prm.resource) if prm.resource else None

if self._validate_resource_url_callback is not None:
await self._validate_resource_url_callback(self.context.server_url, prm_resource)
return

if not prm_resource:
return # pragma: no cover
default_resource = resource_url_from_server_url(self.context.server_url)
# Normalize: Pydantic AnyHttpUrl adds trailing slash to root URLs
# (e.g. "https://example.com/") while resource_url_from_server_url may not.
if not default_resource.endswith("/"):
default_resource += "/"
if not prm_resource.endswith("/"):
prm_resource += "/"
if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource):
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
Expand Down Expand Up @@ -517,6 +544,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.

prm = await handle_protected_resource_response(discovery_response)
if prm:
# Validate PRM resource matches server URL (RFC 8707)
await self._validate_resource_match(prm)
self.context.protected_resource_metadata = prm

# todo: try all authorization_servers to find the OASM
Expand Down
87 changes: 85 additions & 2 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pydantic import AnyHttpUrl, AnyUrl

from mcp.client.auth import OAuthClientProvider, PKCEParameters
from mcp.client.auth.exceptions import OAuthFlowError
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
Expand Down Expand Up @@ -818,6 +819,88 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
assert "resource=" in content


class TestResourceValidation:
"""Test PRM resource validation in OAuthClientProvider."""

@pytest.mark.anyio
async def test_rejects_mismatched_resource(self, client_metadata, mock_storage):
"""Client must reject PRM resource that doesn't match server URL."""
provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

prm = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://evil.example.com/mcp"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)
with pytest.raises(OAuthFlowError, match="does not match expected"):
await provider._validate_resource_match(prm)

@pytest.mark.anyio
async def test_accepts_matching_resource(self, client_metadata, mock_storage):
"""Client must accept PRM resource that matches server URL."""
provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

prm = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://api.example.com/v1/mcp"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)
# Should not raise
await provider._validate_resource_match(prm)

@pytest.mark.anyio
async def test_custom_validate_resource_url_callback(self, client_metadata, mock_storage):
"""Custom callback overrides default validation."""
callback_called_with: list[tuple[str, str | None]] = []

async def custom_validate(server_url: str, prm_resource: str | None) -> None:
callback_called_with.append((server_url, prm_resource))

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
validate_resource_url=custom_validate,
)
provider._initialized = True

# This would normally fail default validation (different origin),
# but custom callback accepts it
prm = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://evil.example.com/mcp"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)
await provider._validate_resource_match(prm)
assert len(callback_called_with) == 1
assert callback_called_with[0][0] == "https://api.example.com/v1/mcp"
assert callback_called_with[0][1] == "https://evil.example.com/mcp"

@pytest.mark.anyio
async def test_accepts_root_url_with_trailing_slash(self, client_metadata, mock_storage):
"""Root URLs with trailing slash normalization should match."""
provider = OAuthClientProvider(
server_url="https://api.example.com",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

prm = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://api.example.com/"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)
# Should not raise despite trailing slash difference
await provider._validate_resource_match(prm)


class TestRegistrationResponse:
"""Test client registration response handling."""

Expand Down Expand Up @@ -963,7 +1046,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
# Send a successful discovery response with minimal protected resource metadata
discovery_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
request=discovery_request,
)

Expand Down Expand Up @@ -1116,7 +1199,7 @@ async def test_token_exchange_accepts_201_status(
# Send a successful discovery response with minimal protected resource metadata
discovery_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
request=discovery_request,
)

Expand Down
Loading