diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py index 2e1e7788b..b39e1cf39 100644 --- a/.github/actions/conformance/client.py +++ b/.github/actions/conformance/client.py @@ -275,6 +275,27 @@ async def run_client_credentials_basic(server_url: str) -> None: async def run_auth_code_client(server_url: str) -> None: """Authorization code flow (default for auth/* scenarios).""" callback_handler = ConformanceOAuthCallbackHandler() + storage = InMemoryTokenStorage() + + # Check for pre-registered client credentials from context + context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT") + if context_json: + try: + context = json.loads(context_json) + client_id = context.get("client_id") + client_secret = context.get("client_secret") + if client_id: + await storage.set_client_info( + OAuthClientInformationFull( + client_id=client_id, + client_secret=client_secret, + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + token_endpoint_auth_method="client_secret_basic" if client_secret else "none", + ) + ) + logger.debug(f"Pre-loaded client credentials: client_id={client_id}") + except json.JSONDecodeError: + pass oauth_auth = OAuthClientProvider( server_url=server_url, @@ -284,7 +305,7 @@ async def run_auth_code_client(server_url: str) -> None: grant_types=["authorization_code", "refresh_token"], response_types=["code"], ), - storage=InMemoryTokenStorage(), + storage=storage, redirect_handler=callback_handler.handle_redirect, callback_handler=callback_handler.handle_callback, client_metadata_url="https://conformance-test.local/client-metadata.json", diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index cd9c4b01a..d876da00b 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -42,4 +42,4 @@ jobs: with: node-version: 24 - run: uv sync --frozen --all-extras --package mcp - - run: npx @modelcontextprotocol/conformance@0.1.10 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all + - run: npx @modelcontextprotocol/conformance@0.1.13 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 98df4d25d..1ce698f06 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -229,6 +229,7 @@ def __init__( callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, timeout: float = 300.0, client_metadata_url: str | None = None, + validate_resource_url: Callable[[str, str | None], Awaitable[str | None]] | None = None, ): """Initialize OAuth2 authentication. @@ -243,6 +244,11 @@ def __init__( advertises client_id_metadata_document_supported=true, this URL will be used as the client_id instead of performing dynamic client registration. Must be a valid HTTPS URL with a non-root pathname. + validate_resource_url: Optional callback to override resource URL validation. + Called with (server_url, prm_resource) where prm_resource is the resource + from Protected Resource Metadata (or None if not present). Must return the + resource URL to use, or None to omit it. If not provided, default validation + rejects mismatched resources per RFC 8707. Raises: ValueError: If client_metadata_url is provided but not a valid HTTPS URL @@ -263,6 +269,7 @@ def __init__( timeout=timeout, client_metadata_url=client_metadata_url, ) + self._validate_resource_url_callback = validate_resource_url self._initialized = False async def _handle_protected_resource_response(self, response: httpx.Response) -> bool: @@ -476,6 +483,26 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non metadata = OAuthMetadata.model_validate_json(content) self.context.oauth_metadata = metadata + async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None: + """Validate that PRM resource matches the server URL per RFC 8707.""" + prm_resource = str(prm.resource) if prm.resource else None + + if self._validate_resource_url_callback is not None: + await self._validate_resource_url_callback(self.context.server_url, prm_resource) + return + + if not prm_resource: + return # pragma: no cover + default_resource = resource_url_from_server_url(self.context.server_url) + # Normalize: Pydantic AnyHttpUrl adds trailing slash to root URLs + # (e.g. "https://example.com/") while resource_url_from_server_url may not. + if not default_resource.endswith("/"): + default_resource += "/" + if not prm_resource.endswith("/"): + prm_resource += "/" + if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource): + raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}") + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: """HTTPX auth flow integration.""" async with self.context.lock: @@ -517,6 +544,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. prm = await handle_protected_resource_response(discovery_response) if prm: + # Validate PRM resource matches server URL (RFC 8707) + await self._validate_resource_match(prm) self.context.protected_resource_metadata = prm # todo: try all authorization_servers to find the OASM diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 7ad24f2df..bd6b95d0c 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -11,6 +11,7 @@ from pydantic import AnyHttpUrl, AnyUrl from mcp.client.auth import OAuthClientProvider, PKCEParameters +from mcp.client.auth.exceptions import OAuthFlowError from mcp.client.auth.utils import ( build_oauth_authorization_server_metadata_discovery_urls, build_protected_resource_metadata_discovery_urls, @@ -818,6 +819,88 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa assert "resource=" in content +class TestResourceValidation: + """Test PRM resource validation in OAuthClientProvider.""" + + @pytest.mark.anyio + async def test_rejects_mismatched_resource(self, client_metadata, mock_storage): + """Client must reject PRM resource that doesn't match server URL.""" + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + ) + provider._initialized = True + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://evil.example.com/mcp"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + ) + with pytest.raises(OAuthFlowError, match="does not match expected"): + await provider._validate_resource_match(prm) + + @pytest.mark.anyio + async def test_accepts_matching_resource(self, client_metadata, mock_storage): + """Client must accept PRM resource that matches server URL.""" + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + ) + provider._initialized = True + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + ) + # Should not raise + await provider._validate_resource_match(prm) + + @pytest.mark.anyio + async def test_custom_validate_resource_url_callback(self, client_metadata, mock_storage): + """Custom callback overrides default validation.""" + callback_called_with: list[tuple[str, str | None]] = [] + + async def custom_validate(server_url: str, prm_resource: str | None) -> None: + callback_called_with.append((server_url, prm_resource)) + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + validate_resource_url=custom_validate, + ) + provider._initialized = True + + # This would normally fail default validation (different origin), + # but custom callback accepts it + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://evil.example.com/mcp"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + ) + await provider._validate_resource_match(prm) + assert len(callback_called_with) == 1 + assert callback_called_with[0][0] == "https://api.example.com/v1/mcp" + assert callback_called_with[0][1] == "https://evil.example.com/mcp" + + @pytest.mark.anyio + async def test_accepts_root_url_with_trailing_slash(self, client_metadata, mock_storage): + """Root URLs with trailing slash normalization should match.""" + provider = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=mock_storage, + ) + provider._initialized = True + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com/"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + ) + # Should not raise despite trailing slash difference + await provider._validate_resource_match(prm) + + class TestRegistrationResponse: """Test client registration response handling.""" @@ -963,7 +1046,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide # Send a successful discovery response with minimal protected resource metadata discovery_response = httpx.Response( 200, - content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}', + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', request=discovery_request, ) @@ -1116,7 +1199,7 @@ async def test_token_exchange_accepts_201_status( # Send a successful discovery response with minimal protected resource metadata discovery_response = httpx.Response( 200, - content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}', + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', request=discovery_request, )