diff --git a/.vscode/launch.json b/.vscode/launch.json index cea7fd917..0d90e162a 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -23,6 +23,16 @@ "program": "cmd/github-mcp-server/main.go", "args": ["stdio", "--read-only"], "console": "integratedTerminal", + }, + { + "name": "Launch http server", + "type": "go", + "request": "launch", + "mode": "auto", + "cwd": "${workspaceFolder}", + "program": "cmd/github-mcp-server/main.go", + "args": ["http", "--port", "8082"], + "console": "integratedTerminal", } ] } \ No newline at end of file diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index c361a4d5a..b8002d456 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -9,6 +9,7 @@ import ( "github.com/github/github-mcp-server/internal/ghmcp" "github.com/github/github-mcp-server/pkg/github" + ghhttp "github.com/github/github-mcp-server/pkg/http" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -89,6 +90,31 @@ var ( return ghmcp.RunStdioServer(stdioServerConfig) }, } + + httpCmd = &cobra.Command{ + Use: "http", + Short: "Start HTTP server", + Long: `Start an HTTP server that listens for MCP requests over HTTP.`, + RunE: func(_ *cobra.Command, _ []string) error { + ttl := viper.GetDuration("repo-access-cache-ttl") + httpConfig := ghhttp.ServerConfig{ + Version: version, + Host: viper.GetString("host"), + Port: viper.GetInt("port"), + BaseURL: viper.GetString("base-url"), + ResourcePath: viper.GetString("base-path"), + ExportTranslations: viper.GetBool("export-translations"), + EnableCommandLogging: viper.GetBool("enable-command-logging"), + LogFilePath: viper.GetString("log-file"), + ContentWindowSize: viper.GetInt("content-window-size"), + LockdownMode: viper.GetBool("lockdown-mode"), + RepoAccessCacheTTL: &ttl, + ScopeChallenge: viper.GetBool("scope-challenge"), + } + + return ghhttp.RunHTTPServer(httpConfig) + }, + } ) func init() { @@ -112,6 +138,12 @@ func init() { rootCmd.PersistentFlags().Bool("insiders", false, "Enable insiders features") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") + // HTTP-specific flags + httpCmd.Flags().Int("port", 8082, "HTTP server port") + httpCmd.Flags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") + httpCmd.Flags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") + httpCmd.Flags().Bool("scope-challenge", false, "Enable OAuth scope challenge responses") + // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) _ = viper.BindPFlag("tools", rootCmd.PersistentFlags().Lookup("tools")) @@ -126,9 +158,13 @@ func init() { _ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode")) _ = viper.BindPFlag("insiders", rootCmd.PersistentFlags().Lookup("insiders")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) - + _ = viper.BindPFlag("port", httpCmd.Flags().Lookup("port")) + _ = viper.BindPFlag("base-url", httpCmd.Flags().Lookup("base-url")) + _ = viper.BindPFlag("base-path", httpCmd.Flags().Lookup("base-path")) + _ = viper.BindPFlag("scope-challenge", httpCmd.Flags().Lookup("scope-challenge")) // Add subcommands rootCmd.AddCommand(stdioCmd) + rootCmd.AddCommand(httpCmd) } func initConfig() { diff --git a/docs/remote-server.md b/docs/remote-server.md index 149667393..cad9ed604 100644 --- a/docs/remote-server.md +++ b/docs/remote-server.md @@ -121,13 +121,15 @@ The Remote GitHub MCP server supports the following URL path patterns: - `/` - Default toolset (see ["default" toolset](../README.md#default-toolset)) - `/readonly` - Default toolset in read-only mode - `/insiders` - Default toolset with insiders mode enabled -- `/insiders/readonly` - Default toolset with insiders mode in read-only mode +- `/readonly/insiders` - Default toolset in read-only mode with insiders mode enabled - `/x/all` - All available toolsets - `/x/all/readonly` - All available toolsets in read-only mode - `/x/all/insiders` - All available toolsets with insiders mode enabled +- `/x/all/readonly/insiders` - All available toolsets in read-only mode with insiders mode enabled - `/x/{toolset}` - Single specific toolset - `/x/{toolset}/readonly` - Single specific toolset in read-only mode - `/x/{toolset}/insiders` - Single specific toolset with insiders mode enabled +- `/x/{toolset}/readonly/insiders` - Single specific toolset in read-only mode with insiders mode enabled Note: `{toolset}` can only be a single toolset, not a comma-separated list. To combine multiple toolsets, use the `X-MCP-Toolsets` header instead. Path modifiers like `/readonly` and `/insiders` can be combined with the `X-MCP-Insiders` or `X-MCP-Readonly` headers. diff --git a/go.mod b/go.mod index 10bbde9d1..c6c6e2967 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/aymerick/douceur v0.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-chi/chi/v5 v5.2.3 github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/swag v0.21.1 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 diff --git a/go.sum b/go.sum index b364f2ef3..d525cb0a1 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= +github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index b6e744d3a..1fd56b7ab 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -6,7 +6,6 @@ import ( "io" "log/slog" "net/http" - "net/url" "os" "os/signal" "strings" @@ -15,69 +14,19 @@ import ( "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/transport" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/shurcooL/githubv4" ) -type MCPServerConfig struct { - // Version of the server - Version string - - // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) - Host string - - // GitHub Token to authenticate with the GitHub API - Token string - - // EnabledToolsets is a list of toolsets to enable - // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#tool-configuration - EnabledToolsets []string - - // EnabledTools is a list of specific tools to enable (additive to toolsets) - // When specified, these tools are registered in addition to any specified toolset tools - EnabledTools []string - - // EnabledFeatures is a list of feature flags that are enabled - // Items with FeatureFlagEnable matching an entry in this list will be available - EnabledFeatures []string - - // Whether to enable dynamic toolsets - // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery - DynamicToolsets bool - - // ReadOnly indicates if we should only offer read-only tools - ReadOnly bool - - // Translator provides translated text for the server tooling - Translator translations.TranslationHelperFunc - - // Content window size - ContentWindowSize int - - // LockdownMode indicates if we should enable lockdown mode - LockdownMode bool - - // InsidersMode indicates if we should enable experimental features - InsidersMode bool - - // Logger is used for logging within the server - Logger *slog.Logger - // RepoAccessTTL overrides the default TTL for repository access cache entries. - RepoAccessTTL *time.Duration - - // TokenScopes contains the OAuth scopes available to the token. - // When non-nil, tools requiring scopes not in this list will be hidden. - // This is used for PAT scope filtering where we can't issue scope challenges. - TokenScopes []string -} - // githubClients holds all the GitHub API clients created for a server instance. type githubClients struct { rest *gogithub.Client @@ -88,27 +37,48 @@ type githubClients struct { } // createGitHubClients creates all the GitHub API clients needed by the server. -func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, error) { +func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolver) (*githubClients, error) { + restURL, err := apiHost.BaseRESTURL(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get base REST URL: %w", err) + } + + uploadURL, err := apiHost.UploadURL(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get upload URL: %w", err) + } + + graphQLURL, err := apiHost.GraphqlURL(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get GraphQL URL: %w", err) + } + + rawURL, err := apiHost.RawURL(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get Raw URL: %w", err) + } + // Construct REST client restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token) restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version) - restClient.BaseURL = apiHost.baseRESTURL - restClient.UploadURL = apiHost.uploadURL + restClient.BaseURL = restURL + restClient.UploadURL = uploadURL // Construct GraphQL client // We use NewEnterpriseClient unconditionally since we already parsed the API host gqlHTTPClient := &http.Client{ - Transport: &bearerAuthTransport{ - transport: &github.GraphQLFeaturesTransport{ + Transport: &transport.BearerAuthTransport{ + Transport: &transport.GraphQLFeaturesTransport{ Transport: http.DefaultTransport, }, - token: cfg.Token, + Token: cfg.Token, }, } - gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) + + gqlClient := githubv4.NewEnterpriseClient(graphQLURL.String(), gqlHTTPClient) // Create raw content client (shares REST client's HTTP transport) - rawClient := raw.NewClient(restClient, apiHost.rawURL) + rawClient := raw.NewClient(restClient, rawURL) // Set up repo access cache for lockdown mode var repoAccessCache *lockdown.RepoAccessCache @@ -131,35 +101,8 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, }, nil } -// resolveEnabledToolsets determines which toolsets should be enabled based on config. -// Returns nil for "use defaults", empty slice for "none", or explicit list. -func resolveEnabledToolsets(cfg MCPServerConfig) []string { - enabledToolsets := cfg.EnabledToolsets - - // In dynamic mode, remove "all" and "default" since users enable toolsets on demand - if cfg.DynamicToolsets && enabledToolsets != nil { - enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataAll.ID)) - enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataDefault.ID)) - } - - if enabledToolsets != nil { - return enabledToolsets - } - if cfg.DynamicToolsets { - // Dynamic mode with no toolsets specified: start empty so users enable on demand - return []string{} - } - if len(cfg.EnabledTools) > 0 { - // When specific tools are requested but no toolsets, don't use default toolsets - // This matches the original behavior: --tools=X alone registers only X - return []string{} - } - // nil means "use defaults" in WithToolsets - return nil -} - -func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { - apiHost, err := parseAPIHost(cfg.Host) +func NewStdioMCPServer(ctx context.Context, cfg github.MCPServerConfig) (*mcp.Server, error) { + apiHost, err := utils.NewAPIHost(cfg.Host) if err != nil { return nil, fmt.Errorf("failed to parse API host: %w", err) } @@ -169,55 +112,9 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { return nil, fmt.Errorf("failed to create GitHub clients: %w", err) } - enabledToolsets := resolveEnabledToolsets(cfg) - // Create feature checker featureChecker := createFeatureChecker(cfg.EnabledFeatures) - // Build and register the tool/resource/prompt inventory - inventoryBuilder := github.NewInventory(cfg.Translator). - WithDeprecatedAliases(github.DeprecatedToolAliases). - WithReadOnly(cfg.ReadOnly). - WithToolsets(enabledToolsets). - WithTools(cfg.EnabledTools). - WithFeatureChecker(featureChecker). - WithServerInstructions() - - // Apply token scope filtering if scopes are known (for PAT filtering) - if cfg.TokenScopes != nil { - inventoryBuilder = inventoryBuilder.WithFilter(github.CreateToolScopeFilter(cfg.TokenScopes)) - } - - inventory, err := inventoryBuilder.Build() - if err != nil { - return nil, fmt.Errorf("failed to build inventory: %w", err) - } - - // Create the MCP server - serverOpts := &mcp.ServerOptions{ - Instructions: inventory.Instructions(), - Logger: cfg.Logger, - CompletionHandler: github.CompletionsHandler(func(_ context.Context) (*gogithub.Client, error) { - return clients.rest, nil - }), - } - - // In dynamic mode, explicitly advertise capabilities since tools/resources/prompts - // may be enabled at runtime even if none are registered initially. - if cfg.DynamicToolsets { - serverOpts.Capabilities = &mcp.ServerCapabilities{ - Tools: &mcp.ToolCapabilities{}, - Resources: &mcp.ResourceCapabilities{}, - Prompts: &mcp.PromptCapabilities{}, - } - } - - ghServer := github.NewServer(cfg.Version, serverOpts) - - // Add middlewares - ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) - ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, clients.rest, clients.gqlHTTP)) - // Create dependencies for tool handlers deps := github.NewBaseDeps( clients.rest, @@ -232,58 +129,33 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { cfg.ContentWindowSize, featureChecker, ) + // Build and register the tool/resource/prompt inventory + inventoryBuilder := github.NewInventory(cfg.Translator). + WithDeprecatedAliases(github.DeprecatedToolAliases). + WithReadOnly(cfg.ReadOnly). + WithToolsets(github.ResolvedEnabledToolsets(cfg.DynamicToolsets, cfg.EnabledToolsets, cfg.EnabledTools)). + WithTools(github.CleanTools(cfg.EnabledTools)). + WithServerInstructions(). + WithFeatureChecker(featureChecker) - // Inject dependencies into context for all tool handlers - ghServer.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler { - return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - return next(github.ContextWithDeps(ctx, deps), method, req) - } - }) - - if unrecognized := inventory.UnrecognizedToolsets(); len(unrecognized) > 0 { - fmt.Fprintf(os.Stderr, "Warning: unrecognized toolsets ignored: %s\n", strings.Join(unrecognized, ", ")) + // Apply token scope filtering if scopes are known (for PAT filtering) + if cfg.TokenScopes != nil { + inventoryBuilder = inventoryBuilder.WithFilter(github.CreateToolScopeFilter(cfg.TokenScopes)) } - // Register GitHub tools/resources/prompts from the inventory. - // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets - // is empty - users enable toolsets at runtime via the dynamic tools below (but can - // enable toolsets or tools explicitly that do need registration). - inventory.RegisterAll(context.Background(), ghServer, deps) - - // Register dynamic toolset management tools (enable/disable) - these are separate - // meta-tools that control the inventory, not part of the inventory itself - if cfg.DynamicToolsets { - registerDynamicTools(ghServer, inventory, deps, cfg.Translator) + inventory, err := inventoryBuilder.Build() + if err != nil { + return nil, fmt.Errorf("failed to build inventory: %w", err) } - return ghServer, nil -} - -// registerDynamicTools adds the dynamic toolset enable/disable tools to the server. -func registerDynamicTools(server *mcp.Server, inventory *inventory.Inventory, deps *github.BaseDeps, t translations.TranslationHelperFunc) { - dynamicDeps := github.DynamicToolDependencies{ - Server: server, - Inventory: inventory, - ToolDeps: deps, - T: t, - } - for _, tool := range github.DynamicTools(inventory) { - tool.RegisterFunc(server, dynamicDeps) + ghServer, err := github.NewMCPServer(ctx, &cfg, deps, inventory) + if err != nil { + return nil, fmt.Errorf("failed to create GitHub MCP server: %w", err) } -} -// createFeatureChecker returns a FeatureFlagChecker that checks if a flag name -// is present in the provided list of enabled features. For the local server, -// this is populated from the --features CLI flag. -func createFeatureChecker(enabledFeatures []string) inventory.FeatureFlagChecker { - // Build a set for O(1) lookup - featureSet := make(map[string]bool, len(enabledFeatures)) - for _, f := range enabledFeatures { - featureSet[f] = true - } - return func(_ context.Context, flagName string) (bool, error) { - return featureSet[flagName], nil - } + ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, clients.rest, clients.gqlHTTP)) + + return ghServer, nil } type StdioServerConfig struct { @@ -378,7 +250,7 @@ func RunStdioServer(cfg StdioServerConfig) error { logger.Debug("skipping scope filtering for non-PAT token") } - ghServer, err := NewMCPServer(MCPServerConfig{ + ghServer, err := NewStdioMCPServer(ctx, github.MCPServerConfig{ Version: cfg.Version, Host: cfg.Host, Token: cfg.Token, @@ -440,214 +312,21 @@ func RunStdioServer(cfg StdioServerConfig) error { return nil } -type apiHost struct { - baseRESTURL *url.URL - graphqlURL *url.URL - uploadURL *url.URL - rawURL *url.URL -} - -func newDotcomHost() (apiHost, error) { - baseRestURL, err := url.Parse("https://api.github.com/") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom REST URL: %w", err) - } - - gqlURL, err := url.Parse("https://api.github.com/graphql") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom GraphQL URL: %w", err) - } - - uploadURL, err := url.Parse("https://uploads.github.com") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom Upload URL: %w", err) - } - - rawURL, err := url.Parse("https://raw.githubusercontent.com/") - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err) - } - - return apiHost{ - baseRESTURL: baseRestURL, - graphqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - }, nil -} - -func newGHECHost(hostname string) (apiHost, error) { - u, err := url.Parse(hostname) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC URL: %w", err) - } - - // Unsecured GHEC would be an error - if u.Scheme == "http" { - return apiHost{}, fmt.Errorf("GHEC URL must be HTTPS") - } - - restURL, err := url.Parse(fmt.Sprintf("https://api.%s/", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC REST URL: %w", err) - } - - gqlURL, err := url.Parse(fmt.Sprintf("https://api.%s/graphql", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC GraphQL URL: %w", err) - } - - uploadURL, err := url.Parse(fmt.Sprintf("https://uploads.%s/", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC Upload URL: %w", err) - } - - rawURL, err := url.Parse(fmt.Sprintf("https://raw.%s/", u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err) - } - - return apiHost{ - baseRESTURL: restURL, - graphqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - }, nil -} - -func newGHESHost(hostname string) (apiHost, error) { - u, err := url.Parse(hostname) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES URL: %w", err) - } - - restURL, err := url.Parse(fmt.Sprintf("%s://%s/api/v3/", u.Scheme, u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES REST URL: %w", err) - } - - gqlURL, err := url.Parse(fmt.Sprintf("%s://%s/api/graphql", u.Scheme, u.Hostname())) - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES GraphQL URL: %w", err) - } - - // Check if subdomain isolation is enabled - // See https://docs.github.com/en/enterprise-server@3.17/admin/configuring-settings/hardening-security-for-your-enterprise/enabling-subdomain-isolation#about-subdomain-isolation - hasSubdomainIsolation := checkSubdomainIsolation(u.Scheme, u.Hostname()) - - var uploadURL *url.URL - if hasSubdomainIsolation { - // With subdomain isolation: https://uploads.hostname/ - uploadURL, err = url.Parse(fmt.Sprintf("%s://uploads.%s/", u.Scheme, u.Hostname())) - } else { - // Without subdomain isolation: https://hostname/api/uploads/ - uploadURL, err = url.Parse(fmt.Sprintf("%s://%s/api/uploads/", u.Scheme, u.Hostname())) - } - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES Upload URL: %w", err) - } - - var rawURL *url.URL - if hasSubdomainIsolation { - // With subdomain isolation: https://raw.hostname/ - rawURL, err = url.Parse(fmt.Sprintf("%s://raw.%s/", u.Scheme, u.Hostname())) - } else { - // Without subdomain isolation: https://hostname/raw/ - rawURL, err = url.Parse(fmt.Sprintf("%s://%s/raw/", u.Scheme, u.Hostname())) - } - if err != nil { - return apiHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err) - } - - return apiHost{ - baseRESTURL: restURL, - graphqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - }, nil -} - -// checkSubdomainIsolation detects if GitHub Enterprise Server has subdomain isolation enabled -// by attempting to ping the raw./_ping endpoint on the subdomain. The raw subdomain must always exist for subdomain isolation. -func checkSubdomainIsolation(scheme, hostname string) bool { - subdomainURL := fmt.Sprintf("%s://raw.%s/_ping", scheme, hostname) - - client := &http.Client{ - Timeout: 5 * time.Second, - // Don't follow redirects - we just want to check if the endpoint exists - //nolint:revive // parameters are required by http.Client.CheckRedirect signature - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - resp, err := client.Get(subdomainURL) - if err != nil { - return false - } - defer resp.Body.Close() - - return resp.StatusCode == http.StatusOK -} - -// Note that this does not handle ports yet, so development environments are out. -func parseAPIHost(s string) (apiHost, error) { - if s == "" { - return newDotcomHost() - } - - u, err := url.Parse(s) - if err != nil { - return apiHost{}, fmt.Errorf("could not parse host as URL: %s", s) - } - - if u.Scheme == "" { - return apiHost{}, fmt.Errorf("host must have a scheme (http or https): %s", s) - } - - if strings.HasSuffix(u.Hostname(), "github.com") { - return newDotcomHost() - } - - if strings.HasSuffix(u.Hostname(), "ghe.com") { - return newGHECHost(s) +// createFeatureChecker returns a FeatureFlagChecker that checks if a flag name +// is present in the provided list of enabled features. For the local server, +// this is populated from the --features CLI flag. +func createFeatureChecker(enabledFeatures []string) inventory.FeatureFlagChecker { + // Build a set for O(1) lookup + featureSet := make(map[string]bool, len(enabledFeatures)) + for _, f := range enabledFeatures { + featureSet[f] = true } - - return newGHESHost(s) -} - -type userAgentTransport struct { - transport http.RoundTripper - agent string -} - -func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.Header.Set("User-Agent", t.agent) - return t.transport.RoundTrip(req) -} - -type bearerAuthTransport struct { - transport http.RoundTripper - token string -} - -func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.Header.Set("Authorization", "Bearer "+t.token) - return t.transport.RoundTrip(req) -} - -func addGitHubAPIErrorToContext(next mcp.MethodHandler) mcp.MethodHandler { - return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { - // Ensure the context is cleared of any previous errors - // as context isn't propagated through middleware - ctx = errors.ContextWithGitHubErrors(ctx) - return next(ctx, method, req) + return func(_ context.Context, flagName string) (bool, error) { + return featureSet[flagName], nil } } -func addUserAgentsMiddleware(cfg MCPServerConfig, restClient *gogithub.Client, gqlHTTPClient *http.Client) func(next mcp.MethodHandler) mcp.MethodHandler { +func addUserAgentsMiddleware(cfg github.MCPServerConfig, restClient *gogithub.Client, gqlHTTPClient *http.Client) func(next mcp.MethodHandler) mcp.MethodHandler { return func(next mcp.MethodHandler) mcp.MethodHandler { return func(ctx context.Context, method string, request mcp.Request) (result mcp.Result, err error) { if method != "initialize" { @@ -669,9 +348,9 @@ func addUserAgentsMiddleware(cfg MCPServerConfig, restClient *gogithub.Client, g restClient.UserAgent = userAgent - gqlHTTPClient.Transport = &userAgentTransport{ - transport: gqlHTTPClient.Transport, - agent: userAgent, + gqlHTTPClient.Transport = &transport.UserAgentTransport{ + Transport: gqlHTTPClient.Transport, + Agent: userAgent, } return next(ctx, method, request) @@ -682,14 +361,12 @@ func addUserAgentsMiddleware(cfg MCPServerConfig, restClient *gogithub.Client, g // fetchTokenScopesForHost fetches the OAuth scopes for a token from the GitHub API. // It constructs the appropriate API host URL based on the configured host. func fetchTokenScopesForHost(ctx context.Context, token, host string) ([]string, error) { - apiHost, err := parseAPIHost(host) + apiHost, err := utils.NewAPIHost(host) if err != nil { return nil, fmt.Errorf("failed to parse API host: %w", err) } - fetcher := scopes.NewFetcher(scopes.FetcherOptions{ - APIHost: apiHost.baseRESTURL.String(), - }) + fetcher := scopes.NewFetcher(apiHost, scopes.FetcherOptions{}) return fetcher.FetchTokenScopes(ctx, token) } diff --git a/internal/ghmcp/server_test.go b/internal/ghmcp/server_test.go index 2139aa280..6f0e3ac3f 100644 --- a/internal/ghmcp/server_test.go +++ b/internal/ghmcp/server_test.go @@ -1,113 +1 @@ package ghmcp - -import ( - "testing" - - "github.com/github/github-mcp-server/pkg/translations" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestNewMCPServer_CreatesSuccessfully verifies that the server can be created -// with the deps injection middleware properly configured. -func TestNewMCPServer_CreatesSuccessfully(t *testing.T) { - t.Parallel() - - // Create a minimal server configuration - cfg := MCPServerConfig{ - Version: "test", - Host: "", // defaults to github.com - Token: "test-token", - EnabledToolsets: []string{"context"}, - ReadOnly: false, - Translator: translations.NullTranslationHelper, - ContentWindowSize: 5000, - LockdownMode: false, - InsidersMode: false, - } - - // Create the server - server, err := NewMCPServer(cfg) - require.NoError(t, err, "expected server creation to succeed") - require.NotNil(t, server, "expected server to be non-nil") - - // The fact that the server was created successfully indicates that: - // 1. The deps injection middleware is properly added - // 2. Tools can be registered without panicking - // - // If the middleware wasn't properly added, tool calls would panic with - // "ToolDependencies not found in context" when executed. - // - // The actual middleware functionality and tool execution with ContextWithDeps - // is already tested in pkg/github/*_test.go. -} - -// TestResolveEnabledToolsets verifies the toolset resolution logic. -func TestResolveEnabledToolsets(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - cfg MCPServerConfig - expectedResult []string - }{ - { - name: "nil toolsets without dynamic mode and no tools - use defaults", - cfg: MCPServerConfig{ - EnabledToolsets: nil, - DynamicToolsets: false, - EnabledTools: nil, - }, - expectedResult: nil, // nil means "use defaults" - }, - { - name: "nil toolsets with dynamic mode - start empty", - cfg: MCPServerConfig{ - EnabledToolsets: nil, - DynamicToolsets: true, - EnabledTools: nil, - }, - expectedResult: []string{}, // empty slice means no toolsets - }, - { - name: "explicit toolsets", - cfg: MCPServerConfig{ - EnabledToolsets: []string{"repos", "issues"}, - DynamicToolsets: false, - }, - expectedResult: []string{"repos", "issues"}, - }, - { - name: "empty toolsets - disable all", - cfg: MCPServerConfig{ - EnabledToolsets: []string{}, - DynamicToolsets: false, - }, - expectedResult: []string{}, // empty slice means no toolsets - }, - { - name: "specific tools without toolsets - no default toolsets", - cfg: MCPServerConfig{ - EnabledToolsets: nil, - DynamicToolsets: false, - EnabledTools: []string{"get_me"}, - }, - expectedResult: []string{}, // empty slice when tools specified but no toolsets - }, - { - name: "dynamic mode with explicit toolsets removes all and default", - cfg: MCPServerConfig{ - EnabledToolsets: []string{"all", "repos"}, - DynamicToolsets: true, - }, - expectedResult: []string{"repos"}, // "all" is removed in dynamic mode - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := resolveEnabledToolsets(tc.cfg) - assert.Equal(t, tc.expectedResult, result) - }) - } -} diff --git a/pkg/context/graphql_features.go b/pkg/context/graphql_features.go new file mode 100644 index 000000000..ebba3f757 --- /dev/null +++ b/pkg/context/graphql_features.go @@ -0,0 +1,19 @@ +package context + +import "context" + +// graphQLFeaturesKey is a context key for GraphQL feature flags +type graphQLFeaturesKey struct{} + +// withGraphQLFeatures adds GraphQL feature flags to the context +func WithGraphQLFeatures(ctx context.Context, features ...string) context.Context { + return context.WithValue(ctx, graphQLFeaturesKey{}, features) +} + +// GetGraphQLFeatures retrieves GraphQL feature flags from the context +func GetGraphQLFeatures(ctx context.Context) []string { + if features, ok := ctx.Value(graphQLFeaturesKey{}).([]string); ok { + return features + } + return nil +} diff --git a/pkg/context/mcp_info.go b/pkg/context/mcp_info.go new file mode 100644 index 000000000..ce5505682 --- /dev/null +++ b/pkg/context/mcp_info.go @@ -0,0 +1,39 @@ +package context + +import "context" + +type mcpMethodInfoCtx string + +var mcpMethodInfoCtxKey mcpMethodInfoCtx = "mcpmethodinfo" + +// MCPMethodInfo contains pre-parsed MCP method information extracted from the JSON-RPC request. +// This is populated early in the request lifecycle to enable: +// - Inventory filtering via ForMCPRequest (only register needed tools/resources/prompts) +// - Avoiding duplicate JSON parsing in middlewares (secret-scanning, scope-challenge) +// - Performance optimization for per-request server creation +type MCPMethodInfo struct { + // Method is the MCP method being called (e.g., "tools/call", "tools/list", "initialize") + Method string + // ItemName is the name of the specific item being accessed (tool name, resource URI, prompt name) + // Only populated for call/get methods (tools/call, prompts/get, resources/read) + ItemName string + // Owner is the repository owner from tool call arguments, if present + Owner string + // Repo is the repository name from tool call arguments, if present + Repo string + // Arguments contains the raw tool arguments for tools/call requests + Arguments map[string]any +} + +// WithMCPMethodInfo stores the MCPMethodInfo in the context. +func WithMCPMethodInfo(ctx context.Context, info *MCPMethodInfo) context.Context { + return context.WithValue(ctx, mcpMethodInfoCtxKey, info) +} + +// MCPMethod retrieves the MCPMethodInfo from the context. +func MCPMethod(ctx context.Context) (*MCPMethodInfo, bool) { + if info, ok := ctx.Value(mcpMethodInfoCtxKey).(*MCPMethodInfo); ok { + return info, true + } + return nil, false +} diff --git a/pkg/context/request.go b/pkg/context/request.go new file mode 100644 index 000000000..70867f32e --- /dev/null +++ b/pkg/context/request.go @@ -0,0 +1,99 @@ +package context + +import "context" + +// readonlyCtxKey is a context key for read-only mode +type readonlyCtxKey struct{} + +// WithReadonly adds read-only mode state to the context +func WithReadonly(ctx context.Context, enabled bool) context.Context { + return context.WithValue(ctx, readonlyCtxKey{}, enabled) +} + +// IsReadonly retrieves the read-only mode state from the context +func IsReadonly(ctx context.Context) bool { + if enabled, ok := ctx.Value(readonlyCtxKey{}).(bool); ok { + return enabled + } + return false +} + +// toolsetsCtxKey is a context key for the active toolsets +type toolsetsCtxKey struct{} + +// WithToolsets adds the active toolsets to the context +func WithToolsets(ctx context.Context, toolsets []string) context.Context { + return context.WithValue(ctx, toolsetsCtxKey{}, toolsets) +} + +// GetToolsets retrieves the active toolsets from the context +func GetToolsets(ctx context.Context) []string { + if toolsets, ok := ctx.Value(toolsetsCtxKey{}).([]string); ok { + return toolsets + } + return nil +} + +// toolsCtxKey is a context key for tools +type toolsCtxKey struct{} + +// WithTools adds the tools to the context +func WithTools(ctx context.Context, tools []string) context.Context { + return context.WithValue(ctx, toolsCtxKey{}, tools) +} + +// GetTools retrieves the tools from the context +func GetTools(ctx context.Context) []string { + if tools, ok := ctx.Value(toolsCtxKey{}).([]string); ok { + return tools + } + return nil +} + +// lockdownCtxKey is a context key for lockdown mode +type lockdownCtxKey struct{} + +// WithLockdownMode adds lockdown mode state to the context +func WithLockdownMode(ctx context.Context, enabled bool) context.Context { + return context.WithValue(ctx, lockdownCtxKey{}, enabled) +} + +// IsLockdownMode retrieves the lockdown mode state from the context +func IsLockdownMode(ctx context.Context) bool { + if enabled, ok := ctx.Value(lockdownCtxKey{}).(bool); ok { + return enabled + } + return false +} + +// insidersCtxKey is a context key for insiders mode +type insidersCtxKey struct{} + +// WithInsidersMode adds insiders mode state to the context +func WithInsidersMode(ctx context.Context, enabled bool) context.Context { + return context.WithValue(ctx, insidersCtxKey{}, enabled) +} + +// IsInsidersMode retrieves the insiders mode state from the context +func IsInsidersMode(ctx context.Context) bool { + if enabled, ok := ctx.Value(insidersCtxKey{}).(bool); ok { + return enabled + } + return false +} + +// headerFeaturesCtxKey is a context key for raw header feature flags +type headerFeaturesCtxKey struct{} + +// WithHeaderFeatures stores the raw feature flags from the X-MCP-Features header into context +func WithHeaderFeatures(ctx context.Context, features []string) context.Context { + return context.WithValue(ctx, headerFeaturesCtxKey{}, features) +} + +// GetHeaderFeatures retrieves the raw feature flags from context +func GetHeaderFeatures(ctx context.Context) []string { + if features, ok := ctx.Value(headerFeaturesCtxKey{}).([]string); ok { + return features + } + return nil +} diff --git a/pkg/context/token.go b/pkg/context/token.go new file mode 100644 index 000000000..beddb02b2 --- /dev/null +++ b/pkg/context/token.go @@ -0,0 +1,32 @@ +package context + +import ( + "context" + + "github.com/github/github-mcp-server/pkg/utils" +) + +// tokenCtxKey is a context key for authentication token information +type tokenCtx string + +var tokenCtxKey tokenCtx = "tokenctx" + +type TokenInfo struct { + Token string + TokenType utils.TokenType + ScopesFetched bool + Scopes []string +} + +// WithTokenInfo adds TokenInfo to the context +func WithTokenInfo(ctx context.Context, tokenInfo *TokenInfo) context.Context { + return context.WithValue(ctx, tokenCtxKey, tokenInfo) +} + +// GetTokenInfo retrieves the authentication token from the context +func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) { + if tokenInfo, ok := ctx.Value(tokenCtxKey).(*TokenInfo); ok { + return tokenInfo, true + } + return nil, false +} diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index 15d807a24..b16bbee00 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -4,13 +4,17 @@ import ( "context" "errors" "fmt" + "net/http" "os" + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/transport" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/shurcooL/githubv4" @@ -23,6 +27,14 @@ type depsContextKey struct{} // ErrDepsNotInContext is returned when ToolDependencies is not found in context. var ErrDepsNotInContext = errors.New("ToolDependencies not found in context; use ContextWithDeps to inject") +func InjectDepsMiddleware(deps ToolDependencies) mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { + return next(ContextWithDeps(ctx, deps), method, req) + } + } +} + // ContextWithDeps returns a new context with the ToolDependencies stored in it. // This is used to inject dependencies at request time rather than at registration time, // avoiding expensive closure creation during server initialization. @@ -69,13 +81,13 @@ type ToolDependencies interface { GetRawClient(ctx context.Context) (*raw.Client, error) // GetRepoAccessCache returns the lockdown mode repo access cache - GetRepoAccessCache() *lockdown.RepoAccessCache + GetRepoAccessCache(ctx context.Context) (*lockdown.RepoAccessCache, error) // GetT returns the translation helper function GetT() translations.TranslationHelperFunc // GetFlags returns feature flags - GetFlags() FeatureFlags + GetFlags(ctx context.Context) FeatureFlags // GetContentWindowSize returns the content window size for log truncation GetContentWindowSize() int @@ -145,13 +157,15 @@ func (d BaseDeps) GetRawClient(_ context.Context) (*raw.Client, error) { } // GetRepoAccessCache implements ToolDependencies. -func (d BaseDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return d.RepoAccessCache } +func (d BaseDeps) GetRepoAccessCache(_ context.Context) (*lockdown.RepoAccessCache, error) { + return d.RepoAccessCache, nil +} // GetT implements ToolDependencies. func (d BaseDeps) GetT() translations.TranslationHelperFunc { return d.T } // GetFlags implements ToolDependencies. -func (d BaseDeps) GetFlags() FeatureFlags { return d.Flags } +func (d BaseDeps) GetFlags(_ context.Context) FeatureFlags { return d.Flags } // GetContentWindowSize implements ToolDependencies. func (d BaseDeps) GetContentWindowSize() int { return d.ContentWindowSize } @@ -221,3 +235,157 @@ func NewToolFromHandler( st.AcceptedScopes = scopes.ExpandScopes(requiredScopes...) return st } + +type RequestDeps struct { + // Static dependencies + apiHosts utils.APIHostResolver + version string + lockdownMode bool + RepoAccessOpts []lockdown.RepoAccessOption + T translations.TranslationHelperFunc + ContentWindowSize int + + // Feature flag checker for runtime checks + featureChecker inventory.FeatureFlagChecker +} + +// NewRequestDeps creates a RequestDeps with the provided clients and configuration. +func NewRequestDeps( + apiHosts utils.APIHostResolver, + version string, + lockdownMode bool, + repoAccessOpts []lockdown.RepoAccessOption, + t translations.TranslationHelperFunc, + contentWindowSize int, + featureChecker inventory.FeatureFlagChecker, +) *RequestDeps { + return &RequestDeps{ + apiHosts: apiHosts, + version: version, + lockdownMode: lockdownMode, + RepoAccessOpts: repoAccessOpts, + T: t, + ContentWindowSize: contentWindowSize, + featureChecker: featureChecker, + } +} + +// GetClient implements ToolDependencies. +func (d *RequestDeps) GetClient(ctx context.Context) (*gogithub.Client, error) { + // extract the token from the context + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok { + return nil, fmt.Errorf("no token info in context") + } + token := tokenInfo.Token + + baseRestURL, err := d.apiHosts.BaseRESTURL(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get base REST URL: %w", err) + } + uploadURL, err := d.apiHosts.UploadURL(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get upload URL: %w", err) + } + + // Construct REST client + restClient := gogithub.NewClient(nil).WithAuthToken(token) + restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", d.version) + restClient.BaseURL = baseRestURL + restClient.UploadURL = uploadURL + return restClient, nil +} + +// GetGQLClient implements ToolDependencies. +func (d *RequestDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error) { + // extract the token from the context + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok { + return nil, fmt.Errorf("no token info in context") + } + token := tokenInfo.Token + + // Construct GraphQL client + // We use NewEnterpriseClient unconditionally since we already parsed the API host + // Wrap transport with GraphQLFeaturesTransport to inject feature flags from context, + // matching the transport chain used by the remote server. + gqlHTTPClient := &http.Client{ + Transport: &transport.BearerAuthTransport{ + Transport: &transport.GraphQLFeaturesTransport{ + Transport: http.DefaultTransport, + }, + Token: token, + }, + } + + graphqlURL, err := d.apiHosts.GraphqlURL(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GraphQL URL: %w", err) + } + + gqlClient := githubv4.NewEnterpriseClient(graphqlURL.String(), gqlHTTPClient) + return gqlClient, nil +} + +// GetRawClient implements ToolDependencies. +func (d *RequestDeps) GetRawClient(ctx context.Context) (*raw.Client, error) { + client, err := d.GetClient(ctx) + if err != nil { + return nil, err + } + + rawURL, err := d.apiHosts.RawURL(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get Raw URL: %w", err) + } + + rawClient := raw.NewClient(client, rawURL) + + return rawClient, nil +} + +// GetRepoAccessCache implements ToolDependencies. +func (d *RequestDeps) GetRepoAccessCache(ctx context.Context) (*lockdown.RepoAccessCache, error) { + if !d.lockdownMode { + return nil, nil + } + + gqlClient, err := d.GetGQLClient(ctx) + if err != nil { + return nil, err + } + + // Create repo access cache + instance := lockdown.GetInstance(gqlClient, d.RepoAccessOpts...) + return instance, nil +} + +// GetT implements ToolDependencies. +func (d *RequestDeps) GetT() translations.TranslationHelperFunc { return d.T } + +// GetFlags implements ToolDependencies. +func (d *RequestDeps) GetFlags(ctx context.Context) FeatureFlags { + return FeatureFlags{ + LockdownMode: d.lockdownMode && ghcontext.IsLockdownMode(ctx), + InsidersMode: ghcontext.IsInsidersMode(ctx), + } +} + +// GetContentWindowSize implements ToolDependencies. +func (d *RequestDeps) GetContentWindowSize() int { return d.ContentWindowSize } + +// IsFeatureEnabled checks if a feature flag is enabled. +func (d *RequestDeps) IsFeatureEnabled(ctx context.Context, flagName string) bool { + if d.featureChecker == nil || flagName == "" { + return false + } + + enabled, err := d.featureChecker(ctx, flagName) + if err != nil { + // Log error but don't fail the tool - treat as disabled + fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) + return false + } + + return enabled +} diff --git a/pkg/github/feature_flags_test.go b/pkg/github/feature_flags_test.go index 498c6e487..2f0a435c9 100644 --- a/pkg/github/feature_flags_test.go +++ b/pkg/github/feature_flags_test.go @@ -45,7 +45,7 @@ func HelloWorldTool(t translations.TranslationHelperFunc) inventory.ServerTool { if deps.IsFeatureEnabled(ctx, RemoteMCPEnthusiasticGreeting) { greeting += " Welcome to the future of MCP! 🎉" } - if deps.GetFlags().InsidersMode { + if deps.GetFlags(ctx).InsidersMode { greeting += " Experimental features are enabled! 🚀" } diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 62e1a0bac..c4cc54175 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -9,9 +9,9 @@ import ( "strings" "time" + ghcontext "github.com/github/github-mcp-server/pkg/context" ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/inventory" - "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/octicons" "github.com/github/github-mcp-server/pkg/sanitize" "github.com/github/github-mcp-server/pkg/scopes" @@ -312,13 +312,13 @@ Options are: switch method { case "get": - result, err := GetIssue(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, deps.GetFlags()) + result, err := GetIssue(ctx, client, deps, owner, repo, issueNumber) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) + result, err := GetIssueComments(ctx, client, deps, owner, repo, issueNumber, pagination) return result, nil, err case "get_sub_issues": - result, err := GetSubIssues(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) + result, err := GetSubIssues(ctx, client, deps, owner, repo, issueNumber, pagination) return result, nil, err case "get_labels": result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) @@ -329,7 +329,13 @@ Options are: }) } -func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, flags FeatureFlags) (*mcp.CallToolResult, error) { +func GetIssue(ctx context.Context, client *github.Client, deps ToolDependencies, owner string, repo string, issueNumber int) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + flags := deps.GetFlags(ctx) + issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { return nil, fmt.Errorf("failed to get issue: %w", err) @@ -378,7 +384,13 @@ func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAc return utils.NewToolResultText(string(r)), nil } -func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) { +func GetIssueComments(ctx context.Context, client *github.Client, deps ToolDependencies, owner string, repo string, issueNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + flags := deps.GetFlags(ctx) + opts := &github.IssueListCommentsOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -432,7 +444,13 @@ func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdow return utils.NewToolResultText(string(r)), nil } -func GetSubIssues(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, pagination PaginationParams, featureFlags FeatureFlags) (*mcp.CallToolResult, error) { +func GetSubIssues(ctx context.Context, client *github.Client, deps ToolDependencies, owner string, repo string, issueNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + featureFlags := deps.GetFlags(ctx) + opts := &github.IssueListOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -1898,7 +1916,7 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server // Add the GraphQL-Features header for the agent assignment API // The header will be read by the HTTP transport if it's configured to do so - ctxWithFeatures := withGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") + ctxWithFeatures := ghcontext.WithGraphQLFeatures(ctx, "issues_copilot_assignment_api_support") // Capture the time before assignment to filter out older PRs during polling assignmentTime := time.Now().UTC() @@ -2096,19 +2114,3 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) inventory.Ser }, ) } - -// graphQLFeaturesKey is a context key for GraphQL feature flags -type graphQLFeaturesKey struct{} - -// withGraphQLFeatures adds GraphQL feature flags to the context -func withGraphQLFeatures(ctx context.Context, features ...string) context.Context { - return context.WithValue(ctx, graphQLFeaturesKey{}, features) -} - -// GetGraphQLFeatures retrieves GraphQL feature flags from the context -func GetGraphQLFeatures(ctx context.Context) []string { - if features, ok := ctx.Value(graphQLFeaturesKey{}).([]string); ok { - return features - } - return nil -} diff --git a/pkg/github/params.go b/pkg/github/params.go new file mode 100644 index 000000000..42803a392 --- /dev/null +++ b/pkg/github/params.go @@ -0,0 +1,393 @@ +package github + +import ( + "errors" + "fmt" + "strconv" + + "github.com/google/go-github/v79/github" + "github.com/google/jsonschema-go/jsonschema" +) + +// OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. +// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong. +func OptionalParamOK[T any, A map[string]any](args A, p string) (value T, ok bool, err error) { + // Check if the parameter is present in the request + val, exists := args[p] + if !exists { + // Not present, return zero value, false, no error + return + } + + // Check if the parameter is of the expected type + value, ok = val.(T) + if !ok { + // Present but wrong type + err = fmt.Errorf("parameter %s is not of type %T, is %T", p, value, val) + ok = true // Set ok to true because the parameter *was* present, even if wrong type + return + } + + // Present and correct type + ok = true + return +} + +// isAcceptedError checks if the error is an accepted error. +func isAcceptedError(err error) bool { + var acceptedError *github.AcceptedError + return errors.As(err, &acceptedError) +} + +// RequiredParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func RequiredParam[T comparable](args map[string]any, p string) (T, error) { + var zero T + + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return zero, fmt.Errorf("missing required parameter: %s", p) + } + + // Check if the parameter is of the expected type + val, ok := args[p].(T) + if !ok { + return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) + } + + if val == zero { + return zero, fmt.Errorf("missing required parameter: %s", p) + } + + return val, nil +} + +// RequiredInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func RequiredInt(args map[string]any, p string) (int, error) { + v, err := RequiredParam[float64](args, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// RequiredBigInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type (float64). +// 3. Checks if the parameter is not empty, i.e: non-zero value. +// 4. Validates that the float64 value can be safely converted to int64 without truncation. +func RequiredBigInt(args map[string]any, p string) (int64, error) { + v, err := RequiredParam[float64](args, p) + if err != nil { + return 0, err + } + + result := int64(v) + // Check if converting back produces the same value to avoid silent truncation + if float64(result) != v { + return 0, fmt.Errorf("parameter %s value %f is too large to fit in int64", p, v) + } + return result, nil +} + +// OptionalParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalParam[T any](args map[string]any, p string) (T, error) { + var zero T + + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return zero, nil + } + + // Check if the parameter is of the expected type + if _, ok := args[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, args[p]) + } + + return args[p].(T), nil +} + +// OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalIntParam(args map[string]any, p string) (int, error) { + v, err := OptionalParam[float64](args, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalIntParam, but it also takes a default value. +func OptionalIntParamWithDefault(args map[string]any, p string, d int) (int, error) { + v, err := OptionalIntParam(args, p) + if err != nil { + return 0, err + } + if v == 0 { + return d, nil + } + return v, nil +} + +// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalBoolParam, but it also takes a default value. +func OptionalBoolParamWithDefault(args map[string]any, p string, d bool) (bool, error) { + _, ok := args[p] + v, err := OptionalParam[bool](args, p) + if err != nil { + return false, err + } + if !ok { + return d, nil + } + return v, nil +} + +// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, iterates the elements and checks each is a string +func OptionalStringArrayParam(args map[string]any, p string) ([]string, error) { + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return []string{}, nil + } + + switch v := args[p].(type) { + case nil: + return []string{}, nil + case []string: + return v, nil + case []any: + strSlice := make([]string, len(v)) + for i, v := range v { + s, ok := v.(string) + if !ok { + return []string{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) + } + strSlice[i] = s + } + return strSlice, nil + default: + return []string{}, fmt.Errorf("parameter %s could not be coerced to []string, is %T", p, args[p]) + } +} + +func convertStringSliceToBigIntSlice(s []string) ([]int64, error) { + int64Slice := make([]int64, len(s)) + for i, str := range s { + val, err := convertStringToBigInt(str, 0) + if err != nil { + return nil, fmt.Errorf("failed to convert element %d (%s) to int64: %w", i, str, err) + } + int64Slice[i] = val + } + return int64Slice, nil +} + +func convertStringToBigInt(s string, def int64) (int64, error) { + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return def, fmt.Errorf("failed to convert string %s to int64: %w", s, err) + } + return v, nil +} + +// OptionalBigIntArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns an empty slice +// 2. If it is present, iterates the elements, checks each is a string, and converts them to int64 values +func OptionalBigIntArrayParam(args map[string]any, p string) ([]int64, error) { + // Check if the parameter is present in the request + if _, ok := args[p]; !ok { + return []int64{}, nil + } + + switch v := args[p].(type) { + case nil: + return []int64{}, nil + case []string: + return convertStringSliceToBigIntSlice(v) + case []any: + int64Slice := make([]int64, len(v)) + for i, v := range v { + s, ok := v.(string) + if !ok { + return []int64{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) + } + val, err := convertStringToBigInt(s, 0) + if err != nil { + return []int64{}, fmt.Errorf("parameter %s: failed to convert element %d (%s) to int64: %w", p, i, s, err) + } + int64Slice[i] = val + } + return int64Slice, nil + default: + return []int64{}, fmt.Errorf("parameter %s could not be coerced to []int64, is %T", p, args[p]) + } +} + +// WithPagination adds REST API pagination parameters to a tool. +// https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api +func WithPagination(schema *jsonschema.Schema) *jsonschema.Schema { + schema.Properties["page"] = &jsonschema.Schema{ + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + } + + schema.Properties["perPage"] = &jsonschema.Schema{ + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + } + + return schema +} + +// WithUnifiedPagination adds REST API pagination parameters to a tool. +// GraphQL tools will use this and convert page/perPage to GraphQL cursor parameters internally. +func WithUnifiedPagination(schema *jsonschema.Schema) *jsonschema.Schema { + schema.Properties["page"] = &jsonschema.Schema{ + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + } + + schema.Properties["perPage"] = &jsonschema.Schema{ + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + } + + schema.Properties["after"] = &jsonschema.Schema{ + Type: "string", + Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + } + + return schema +} + +// WithCursorPagination adds only cursor-based pagination parameters to a tool (no page parameter). +func WithCursorPagination(schema *jsonschema.Schema) *jsonschema.Schema { + schema.Properties["perPage"] = &jsonschema.Schema{ + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + } + + schema.Properties["after"] = &jsonschema.Schema{ + Type: "string", + Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + } + + return schema +} + +type PaginationParams struct { + Page int + PerPage int + After string +} + +// OptionalPaginationParams returns the "page", "perPage", and "after" parameters from the request, +// or their default values if not present, "page" default is 1, "perPage" default is 30. +// In future, we may want to make the default values configurable, or even have this +// function returned from `withPagination`, where the defaults are provided alongside +// the min/max values. +func OptionalPaginationParams(args map[string]any) (PaginationParams, error) { + page, err := OptionalIntParamWithDefault(args, "page", 1) + if err != nil { + return PaginationParams{}, err + } + perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) + if err != nil { + return PaginationParams{}, err + } + after, err := OptionalParam[string](args, "after") + if err != nil { + return PaginationParams{}, err + } + return PaginationParams{ + Page: page, + PerPage: perPage, + After: after, + }, nil +} + +// OptionalCursorPaginationParams returns the "perPage" and "after" parameters from the request, +// without the "page" parameter, suitable for cursor-based pagination only. +func OptionalCursorPaginationParams(args map[string]any) (CursorPaginationParams, error) { + perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) + if err != nil { + return CursorPaginationParams{}, err + } + after, err := OptionalParam[string](args, "after") + if err != nil { + return CursorPaginationParams{}, err + } + return CursorPaginationParams{ + PerPage: perPage, + After: after, + }, nil +} + +type CursorPaginationParams struct { + PerPage int + After string +} + +// ToGraphQLParams converts cursor pagination parameters to GraphQL-specific parameters. +func (p CursorPaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { + if p.PerPage > 100 { + return nil, fmt.Errorf("perPage value %d exceeds maximum of 100", p.PerPage) + } + if p.PerPage < 0 { + return nil, fmt.Errorf("perPage value %d cannot be negative", p.PerPage) + } + first := int32(p.PerPage) + + var after *string + if p.After != "" { + after = &p.After + } + + return &GraphQLPaginationParams{ + First: &first, + After: after, + }, nil +} + +type GraphQLPaginationParams struct { + First *int32 + After *string +} + +// ToGraphQLParams converts REST API pagination parameters to GraphQL-specific parameters. +// This converts page/perPage to first parameter for GraphQL queries. +// If After is provided, it takes precedence over page-based pagination. +func (p PaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { + // Convert to CursorPaginationParams and delegate to avoid duplication + cursor := CursorPaginationParams{ + PerPage: p.PerPage, + After: p.After, + } + return cursor.ToGraphQLParams() +} diff --git a/pkg/github/params_test.go b/pkg/github/params_test.go new file mode 100644 index 000000000..9d7cfe432 --- /dev/null +++ b/pkg/github/params_test.go @@ -0,0 +1,503 @@ +package github + +import ( + "fmt" + "testing" + + "github.com/google/go-github/v79/github" + "github.com/stretchr/testify/assert" +) + +func Test_IsAcceptedError(t *testing.T) { + tests := []struct { + name string + err error + expectAccepted bool + }{ + { + name: "github AcceptedError", + err: &github.AcceptedError{}, + expectAccepted: true, + }, + { + name: "regular error", + err: fmt.Errorf("some other error"), + expectAccepted: false, + }, + { + name: "nil error", + err: nil, + expectAccepted: false, + }, + { + name: "wrapped AcceptedError", + err: fmt.Errorf("wrapped: %w", &github.AcceptedError{}), + expectAccepted: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isAcceptedError(tc.err) + assert.Equal(t, tc.expectAccepted, result) + }) + } +} + +func Test_RequiredStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := RequiredParam[string](tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalStringParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected string + expectError bool + }{ + { + name: "valid string parameter", + params: map[string]interface{}{"name": "test-value"}, + paramName: "name", + expected: "test-value", + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "empty string parameter", + params: map[string]interface{}{"name": ""}, + paramName: "name", + expected: "", + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"name": 123}, + paramName: "name", + expected: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalParam[string](tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_RequiredInt(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := RequiredInt(tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} +func Test_OptionalIntParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalIntParam(tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalNumberParamWithDefault(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + defaultVal int + expected int + expectError bool + }{ + { + name: "valid number parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + defaultVal: 10, + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "zero value", + params: map[string]interface{}{"count": float64(0)}, + paramName: "count", + defaultVal: 10, + expected: 10, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + defaultVal: 10, + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalIntParamWithDefault(tc.params, tc.paramName, tc.defaultVal) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalBooleanParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected bool + expectError bool + }{ + { + name: "true value", + params: map[string]interface{}{"flag": true}, + paramName: "flag", + expected: true, + expectError: false, + }, + { + name: "false value", + params: map[string]interface{}{"flag": false}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "flag", + expected: false, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"flag": "not-a-boolean"}, + paramName: "flag", + expected: false, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalParam[bool](tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func TestOptionalStringArrayParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []string + expectError bool + }{ + { + name: "parameter not in request", + params: map[string]any{}, + paramName: "flag", + expected: []string{}, + expectError: false, + }, + { + name: "valid any array parameter", + params: map[string]any{ + "flag": []any{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "valid string array parameter", + params: map[string]any{ + "flag": []string{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]any{ + "flag": 1, + }, + paramName: "flag", + expected: []string{}, + expectError: true, + }, + { + name: "wrong slice type parameter", + params: map[string]any{ + "flag": []any{"foo", 2}, + }, + paramName: "flag", + expected: []string{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalStringArrayParam(tc.params, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func TestOptionalPaginationParams(t *testing.T) { + tests := []struct { + name string + params map[string]any + expected PaginationParams + expectError bool + }{ + { + name: "no pagination parameters, default values", + params: map[string]any{}, + expected: PaginationParams{ + Page: 1, + PerPage: 30, + }, + expectError: false, + }, + { + name: "page parameter, default perPage", + params: map[string]any{ + "page": float64(2), + }, + expected: PaginationParams{ + Page: 2, + PerPage: 30, + }, + expectError: false, + }, + { + name: "perPage parameter, default page", + params: map[string]any{ + "perPage": float64(50), + }, + expected: PaginationParams{ + Page: 1, + PerPage: 50, + }, + expectError: false, + }, + { + name: "page and perPage parameters", + params: map[string]any{ + "page": float64(2), + "perPage": float64(50), + }, + expected: PaginationParams{ + Page: 2, + PerPage: 50, + }, + expectError: false, + }, + { + name: "invalid page parameter", + params: map[string]any{ + "page": "not-a-number", + }, + expected: PaginationParams{}, + expectError: true, + }, + { + name: "invalid perPage parameter", + params: map[string]any{ + "perPage": "not-a-number", + }, + expected: PaginationParams{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := OptionalPaginationParams(tc.params) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 308d2eb8b..a11fe29a5 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -15,7 +15,6 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/inventory" - "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/octicons" "github.com/github/github-mcp-server/pkg/sanitize" "github.com/github/github-mcp-server/pkg/scopes" @@ -101,7 +100,7 @@ Possible options: switch method { case "get": - result, err := GetPullRequest(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) + result, err := GetPullRequest(ctx, client, deps, owner, repo, pullNumber) return result, nil, err case "get_diff": result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) @@ -121,13 +120,13 @@ Possible options: if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - result, err := GetPullRequestReviewComments(ctx, gqlClient, deps.GetRepoAccessCache(), owner, repo, pullNumber, cursorPagination, deps.GetFlags()) + result, err := GetPullRequestReviewComments(ctx, gqlClient, deps, owner, repo, pullNumber, cursorPagination) return result, nil, err case "get_reviews": - result, err := GetPullRequestReviews(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) + result, err := GetPullRequestReviews(ctx, client, deps, owner, repo, pullNumber) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) + result, err := GetIssueComments(ctx, client, deps, owner, repo, pullNumber, pagination) return result, nil, err default: return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil @@ -135,7 +134,13 @@ Possible options: }) } -func GetPullRequest(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { +func GetPullRequest(ctx context.Context, client *github.Client, deps ToolDependencies, owner, repo string, pullNumber int) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + ff := deps.GetFlags(ctx) + pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, @@ -340,7 +345,13 @@ type pageInfoFragment struct { EndCursor githubv4.String } -func GetPullRequestReviewComments(ctx context.Context, gqlClient *githubv4.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, pagination CursorPaginationParams, ff FeatureFlags) (*mcp.CallToolResult, error) { +func GetPullRequestReviewComments(ctx context.Context, gqlClient *githubv4.Client, deps ToolDependencies, owner, repo string, pullNumber int, pagination CursorPaginationParams) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + ff := deps.GetFlags(ctx) + // Convert pagination parameters to GraphQL format gqlParams, err := pagination.ToGraphQLParams() if err != nil { @@ -421,7 +432,13 @@ func GetPullRequestReviewComments(ctx context.Context, gqlClient *githubv4.Clien return utils.NewToolResultText(string(r)), nil } -func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { +func GetPullRequestReviews(ctx context.Context, client *github.Client, deps ToolDependencies, owner, repo string, pullNumber int) (*mcp.CallToolResult, error) { + cache, err := deps.GetRepoAccessCache(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get repo access cache: %w", err) + } + ff := deps.GetFlags(ctx) + reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, diff --git a/pkg/github/server.go b/pkg/github/server.go index 8248da58f..9a602e153 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -3,433 +3,203 @@ package github import ( "context" "encoding/json" - "errors" "fmt" - "strconv" + "log/slog" "strings" + "time" + gherrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/octicons" + "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" - "github.com/google/go-github/v79/github" - "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" ) -// NewServer creates a new GitHub MCP server with the specified GH client and logger. +type MCPServerConfig struct { + // Version of the server + Version string -func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { - if opts == nil { - opts = &mcp.ServerOptions{} - } + // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) + Host string - // Create a new MCP server - s := mcp.NewServer(&mcp.Implementation{ - Name: "github-mcp-server", - Title: "GitHub MCP Server", - Version: version, - Icons: octicons.Icons("mark-github"), - }, opts) + // GitHub Token to authenticate with the GitHub API + Token string - return s -} + // EnabledToolsets is a list of toolsets to enable + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#tool-configuration + EnabledToolsets []string -func CompletionsHandler(getClient GetClientFn) func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { - return func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { - switch req.Params.Ref.Type { - case "ref/resource": - if strings.HasPrefix(req.Params.Ref.URI, "repo://") { - return RepositoryResourceCompletionHandler(getClient)(ctx, req) - } - return nil, fmt.Errorf("unsupported resource URI: %s", req.Params.Ref.URI) - case "ref/prompt": - return nil, nil - default: - return nil, fmt.Errorf("unsupported ref type: %s", req.Params.Ref.Type) - } - } -} + // EnabledTools is a list of specific tools to enable (additive to toolsets) + // When specified, these tools are registered in addition to any specified toolset tools + EnabledTools []string -// OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. -// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong. -func OptionalParamOK[T any, A map[string]any](args A, p string) (value T, ok bool, err error) { - // Check if the parameter is present in the request - val, exists := args[p] - if !exists { - // Not present, return zero value, false, no error - return - } + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string - // Check if the parameter is of the expected type - value, ok = val.(T) - if !ok { - // Present but wrong type - err = fmt.Errorf("parameter %s is not of type %T, is %T", p, value, val) - ok = true // Set ok to true because the parameter *was* present, even if wrong type - return - } + // Whether to enable dynamic toolsets + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery + DynamicToolsets bool - // Present and correct type - ok = true - return -} + // ReadOnly indicates if we should only offer read-only tools + ReadOnly bool -// isAcceptedError checks if the error is an accepted error. -func isAcceptedError(err error) bool { - var acceptedError *github.AcceptedError - return errors.As(err, &acceptedError) -} - -// RequiredParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request. -// 2. Checks if the parameter is of the expected type. -// 3. Checks if the parameter is not empty, i.e: non-zero value -func RequiredParam[T comparable](args map[string]any, p string) (T, error) { - var zero T - - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return zero, fmt.Errorf("missing required parameter: %s", p) - } + // Translator provides translated text for the server tooling + Translator translations.TranslationHelperFunc - // Check if the parameter is of the expected type - val, ok := args[p].(T) - if !ok { - return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) - } + // Content window size + ContentWindowSize int - if val == zero { - return zero, fmt.Errorf("missing required parameter: %s", p) - } + // LockdownMode indicates if we should enable lockdown mode + LockdownMode bool - return val, nil -} + // InsidersMode indicates if we should enable experimental features + InsidersMode bool -// RequiredInt is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request. -// 2. Checks if the parameter is of the expected type. -// 3. Checks if the parameter is not empty, i.e: non-zero value -func RequiredInt(args map[string]any, p string) (int, error) { - v, err := RequiredParam[float64](args, p) - if err != nil { - return 0, err - } - return int(v), nil -} + // Logger is used for logging within the server + Logger *slog.Logger + // RepoAccessTTL overrides the default TTL for repository access cache entries. + RepoAccessTTL *time.Duration -// RequiredBigInt is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request. -// 2. Checks if the parameter is of the expected type (float64). -// 3. Checks if the parameter is not empty, i.e: non-zero value. -// 4. Validates that the float64 value can be safely converted to int64 without truncation. -func RequiredBigInt(args map[string]any, p string) (int64, error) { - v, err := RequiredParam[float64](args, p) - if err != nil { - return 0, err - } + // TokenScopes contains the OAuth scopes available to the token. + // When non-nil, tools requiring scopes not in this list will be hidden. + // This is used for PAT scope filtering where we can't issue scope challenges. + TokenScopes []string - result := int64(v) - // Check if converting back produces the same value to avoid silent truncation - if float64(result) != v { - return 0, fmt.Errorf("parameter %s value %f is too large to fit in int64", p, v) - } - return result, nil + // Additional server options to apply + ServerOptions []MCPServerOption } -// OptionalParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns its zero-value -// 2. If it is present, it checks if the parameter is of the expected type and returns it -func OptionalParam[T any](args map[string]any, p string) (T, error) { - var zero T +type MCPServerOption func(*mcp.ServerOptions) - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return zero, nil +func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependencies, inv *inventory.Inventory) (*mcp.Server, error) { + // Create the MCP server + serverOpts := &mcp.ServerOptions{ + Instructions: inv.Instructions(), + Logger: cfg.Logger, + CompletionHandler: CompletionsHandler(deps.GetClient), } - // Check if the parameter is of the expected type - if _, ok := args[p].(T); !ok { - return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, args[p]) + // Apply any additional server options + for _, o := range cfg.ServerOptions { + o(serverOpts) } - return args[p].(T), nil -} - -// OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns its zero-value -// 2. If it is present, it checks if the parameter is of the expected type and returns it -func OptionalIntParam(args map[string]any, p string) (int, error) { - v, err := OptionalParam[float64](args, p) - if err != nil { - return 0, err - } - return int(v), nil -} - -// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request -// similar to optionalIntParam, but it also takes a default value. -func OptionalIntParamWithDefault(args map[string]any, p string, d int) (int, error) { - v, err := OptionalIntParam(args, p) - if err != nil { - return 0, err - } - if v == 0 { - return d, nil - } - return v, nil -} - -// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request -// similar to optionalBoolParam, but it also takes a default value. -func OptionalBoolParamWithDefault(args map[string]any, p string, d bool) (bool, error) { - _, ok := args[p] - v, err := OptionalParam[bool](args, p) - if err != nil { - return false, err - } - if !ok { - return d, nil - } - return v, nil -} - -// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns its zero-value -// 2. If it is present, iterates the elements and checks each is a string -func OptionalStringArrayParam(args map[string]any, p string) ([]string, error) { - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return []string{}, nil - } - - switch v := args[p].(type) { - case nil: - return []string{}, nil - case []string: - return v, nil - case []any: - strSlice := make([]string, len(v)) - for i, v := range v { - s, ok := v.(string) - if !ok { - return []string{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) - } - strSlice[i] = s + // In dynamic mode, explicitly advertise capabilities since tools/resources/prompts + // may be enabled at runtime even if none are registered initially. + if cfg.DynamicToolsets { + serverOpts.Capabilities = &mcp.ServerCapabilities{ + Tools: &mcp.ToolCapabilities{}, + Resources: &mcp.ResourceCapabilities{}, + Prompts: &mcp.PromptCapabilities{}, } - return strSlice, nil - default: - return []string{}, fmt.Errorf("parameter %s could not be coerced to []string, is %T", p, args[p]) } -} - -func convertStringSliceToBigIntSlice(s []string) ([]int64, error) { - int64Slice := make([]int64, len(s)) - for i, str := range s { - val, err := convertStringToBigInt(str, 0) - if err != nil { - return nil, fmt.Errorf("failed to convert element %d (%s) to int64: %w", i, str, err) - } - int64Slice[i] = val - } - return int64Slice, nil -} -func convertStringToBigInt(s string, def int64) (int64, error) { - v, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return def, fmt.Errorf("failed to convert string %s to int64: %w", s, err) - } - return v, nil -} + ghServer := NewServer(cfg.Version, serverOpts) -// OptionalBigIntArrayParam is a helper function that can be used to fetch a requested parameter from the request. -// It does the following checks: -// 1. Checks if the parameter is present in the request, if not, it returns an empty slice -// 2. If it is present, iterates the elements, checks each is a string, and converts them to int64 values -func OptionalBigIntArrayParam(args map[string]any, p string) ([]int64, error) { - // Check if the parameter is present in the request - if _, ok := args[p]; !ok { - return []int64{}, nil - } + // Add middlewares + ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) + ghServer.AddReceivingMiddleware(InjectDepsMiddleware(deps)) - switch v := args[p].(type) { - case nil: - return []int64{}, nil - case []string: - return convertStringSliceToBigIntSlice(v) - case []any: - int64Slice := make([]int64, len(v)) - for i, v := range v { - s, ok := v.(string) - if !ok { - return []int64{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) - } - val, err := convertStringToBigInt(s, 0) - if err != nil { - return []int64{}, fmt.Errorf("parameter %s: failed to convert element %d (%s) to int64: %w", p, i, s, err) - } - int64Slice[i] = val - } - return int64Slice, nil - default: - return []int64{}, fmt.Errorf("parameter %s could not be coerced to []int64, is %T", p, args[p]) + if unrecognized := inv.UnrecognizedToolsets(); len(unrecognized) > 0 { + cfg.Logger.Warn("Warning: unrecognized toolsets ignored", "toolsets", strings.Join(unrecognized, ", ")) } -} -// WithPagination adds REST API pagination parameters to a tool. -// https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api -func WithPagination(schema *jsonschema.Schema) *jsonschema.Schema { - schema.Properties["page"] = &jsonschema.Schema{ - Type: "number", - Description: "Page number for pagination (min 1)", - Minimum: jsonschema.Ptr(1.0), - } + // Register GitHub tools/resources/prompts from the inventory. + // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets + // is empty - users enable toolsets at runtime via the dynamic tools below (but can + // enable toolsets or tools explicitly that do need registration). + inv.RegisterAll(ctx, ghServer, deps) - schema.Properties["perPage"] = &jsonschema.Schema{ - Type: "number", - Description: "Results per page for pagination (min 1, max 100)", - Minimum: jsonschema.Ptr(1.0), - Maximum: jsonschema.Ptr(100.0), + // Register dynamic toolset management tools (enable/disable) - these are separate + // meta-tools that control the inventory, not part of the inventory itself + if cfg.DynamicToolsets { + registerDynamicTools(ghServer, inv, deps, cfg.Translator) } - return schema + return ghServer, nil } -// WithUnifiedPagination adds REST API pagination parameters to a tool. -// GraphQL tools will use this and convert page/perPage to GraphQL cursor parameters internally. -func WithUnifiedPagination(schema *jsonschema.Schema) *jsonschema.Schema { - schema.Properties["page"] = &jsonschema.Schema{ - Type: "number", - Description: "Page number for pagination (min 1)", - Minimum: jsonschema.Ptr(1.0), +// registerDynamicTools adds the dynamic toolset enable/disable tools to the server. +func registerDynamicTools(server *mcp.Server, inventory *inventory.Inventory, deps ToolDependencies, t translations.TranslationHelperFunc) { + dynamicDeps := DynamicToolDependencies{ + Server: server, + Inventory: inventory, + ToolDeps: deps, + T: t, } - - schema.Properties["perPage"] = &jsonschema.Schema{ - Type: "number", - Description: "Results per page for pagination (min 1, max 100)", - Minimum: jsonschema.Ptr(1.0), - Maximum: jsonschema.Ptr(100.0), + for _, tool := range DynamicTools(inventory) { + tool.RegisterFunc(server, dynamicDeps) } - - schema.Properties["after"] = &jsonschema.Schema{ - Type: "string", - Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", - } - - return schema } -// WithCursorPagination adds only cursor-based pagination parameters to a tool (no page parameter). -func WithCursorPagination(schema *jsonschema.Schema) *jsonschema.Schema { - schema.Properties["perPage"] = &jsonschema.Schema{ - Type: "number", - Description: "Results per page for pagination (min 1, max 100)", - Minimum: jsonschema.Ptr(1.0), - Maximum: jsonschema.Ptr(100.0), +// ResolvedEnabledToolsets determines which toolsets should be enabled based on config. +// Returns nil for "use defaults", empty slice for "none", or explicit list. +func ResolvedEnabledToolsets(dynamicToolsets bool, enabledToolsets []string, enabledTools []string) []string { + // In dynamic mode, remove "all" and "default" since users enable toolsets on demand + if dynamicToolsets && enabledToolsets != nil { + enabledToolsets = RemoveToolset(enabledToolsets, string(ToolsetMetadataAll.ID)) + enabledToolsets = RemoveToolset(enabledToolsets, string(ToolsetMetadataDefault.ID)) } - schema.Properties["after"] = &jsonschema.Schema{ - Type: "string", - Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + if enabledToolsets != nil { + return enabledToolsets } - - return schema -} - -type PaginationParams struct { - Page int - PerPage int - After string -} - -// OptionalPaginationParams returns the "page", "perPage", and "after" parameters from the request, -// or their default values if not present, "page" default is 1, "perPage" default is 30. -// In future, we may want to make the default values configurable, or even have this -// function returned from `withPagination`, where the defaults are provided alongside -// the min/max values. -func OptionalPaginationParams(args map[string]any) (PaginationParams, error) { - page, err := OptionalIntParamWithDefault(args, "page", 1) - if err != nil { - return PaginationParams{}, err - } - perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) - if err != nil { - return PaginationParams{}, err - } - after, err := OptionalParam[string](args, "after") - if err != nil { - return PaginationParams{}, err - } - return PaginationParams{ - Page: page, - PerPage: perPage, - After: after, - }, nil -} - -// OptionalCursorPaginationParams returns the "perPage" and "after" parameters from the request, -// without the "page" parameter, suitable for cursor-based pagination only. -func OptionalCursorPaginationParams(args map[string]any) (CursorPaginationParams, error) { - perPage, err := OptionalIntParamWithDefault(args, "perPage", 30) - if err != nil { - return CursorPaginationParams{}, err + if dynamicToolsets { + // Dynamic mode with no toolsets specified: start empty so users enable on demand + return []string{} } - after, err := OptionalParam[string](args, "after") - if err != nil { - return CursorPaginationParams{}, err + if len(enabledTools) > 0 { + // When specific tools are requested but no toolsets, don't use default toolsets + // This matches the original behavior: --tools=X alone registers only X + return []string{} } - return CursorPaginationParams{ - PerPage: perPage, - After: after, - }, nil -} -type CursorPaginationParams struct { - PerPage int - After string + // nil means "use defaults" in WithToolsets + return nil } -// ToGraphQLParams converts cursor pagination parameters to GraphQL-specific parameters. -func (p CursorPaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { - if p.PerPage > 100 { - return nil, fmt.Errorf("perPage value %d exceeds maximum of 100", p.PerPage) - } - if p.PerPage < 0 { - return nil, fmt.Errorf("perPage value %d cannot be negative", p.PerPage) +func addGitHubAPIErrorToContext(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { + // Ensure the context is cleared of any previous errors + // as context isn't propagated through middleware + ctx = gherrors.ContextWithGitHubErrors(ctx) + return next(ctx, method, req) } - first := int32(p.PerPage) +} - var after *string - if p.After != "" { - after = &p.After +// NewServer creates a new GitHub MCP server with the specified GH client and logger. +func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { + if opts == nil { + opts = &mcp.ServerOptions{} } - return &GraphQLPaginationParams{ - First: &first, - After: after, - }, nil -} + // Create a new MCP server + s := mcp.NewServer(&mcp.Implementation{ + Name: "github-mcp-server", + Title: "GitHub MCP Server", + Version: version, + Icons: octicons.Icons("mark-github"), + }, opts) -type GraphQLPaginationParams struct { - First *int32 - After *string + return s } -// ToGraphQLParams converts REST API pagination parameters to GraphQL-specific parameters. -// This converts page/perPage to first parameter for GraphQL queries. -// If After is provided, it takes precedence over page-based pagination. -func (p PaginationParams) ToGraphQLParams() (*GraphQLPaginationParams, error) { - // Convert to CursorPaginationParams and delegate to avoid duplication - cursor := CursorPaginationParams{ - PerPage: p.PerPage, - After: p.After, +func CompletionsHandler(getClient GetClientFn) func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { + return func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { + switch req.Params.Ref.Type { + case "ref/resource": + if strings.HasPrefix(req.Params.Ref.URI, "repo://") { + return RepositoryResourceCompletionHandler(getClient)(ctx, req) + } + return nil, fmt.Errorf("unsupported resource URI: %s", req.Params.Ref.URI) + case "ref/prompt": + return nil, nil + default: + return nil, fmt.Errorf("unsupported ref type: %s", req.Params.Ref.Type) + } } - return cursor.ToGraphQLParams() } func MarshalledTextResult(v any) *mcp.CallToolResult { diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index f4ae5f831..f21752b27 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -12,15 +12,16 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/translations" - "github.com/google/go-github/v79/github" + gogithub "github.com/google/go-github/v79/github" "github.com/shurcooL/githubv4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // stubDeps is a test helper that implements ToolDependencies with configurable behavior. // Use this when you need to test error paths or when you need closure-based client creation. type stubDeps struct { - clientFn func(context.Context) (*github.Client, error) + clientFn func(context.Context) (*gogithub.Client, error) gqlClientFn func(context.Context) (*githubv4.Client, error) rawClientFn func(context.Context) (*raw.Client, error) @@ -30,7 +31,7 @@ type stubDeps struct { contentWindowSize int } -func (s stubDeps) GetClient(ctx context.Context) (*github.Client, error) { +func (s stubDeps) GetClient(ctx context.Context) (*gogithub.Client, error) { if s.clientFn != nil { return s.clientFn(ctx) } @@ -51,21 +52,23 @@ func (s stubDeps) GetRawClient(ctx context.Context) (*raw.Client, error) { return nil, nil } -func (s stubDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return s.repoAccessCache } +func (s stubDeps) GetRepoAccessCache(_ context.Context) (*lockdown.RepoAccessCache, error) { + return s.repoAccessCache, nil +} func (s stubDeps) GetT() translations.TranslationHelperFunc { return s.t } -func (s stubDeps) GetFlags() FeatureFlags { return s.flags } +func (s stubDeps) GetFlags(_ context.Context) FeatureFlags { return s.flags } func (s stubDeps) GetContentWindowSize() int { return s.contentWindowSize } func (s stubDeps) IsFeatureEnabled(_ context.Context, _ string) bool { return false } // Helper functions to create stub client functions for error testing -func stubClientFnFromHTTP(httpClient *http.Client) func(context.Context) (*github.Client, error) { - return func(_ context.Context) (*github.Client, error) { - return github.NewClient(httpClient), nil +func stubClientFnFromHTTP(httpClient *http.Client) func(context.Context) (*gogithub.Client, error) { + return func(_ context.Context) (*gogithub.Client, error) { + return gogithub.NewClient(httpClient), nil } } -func stubClientFnErr(errMsg string) func(context.Context) (*github.Client, error) { - return func(_ context.Context) (*github.Client, error) { +func stubClientFnErr(errMsg string) func(context.Context) (*gogithub.Client, error) { + return func(_ context.Context) (*gogithub.Client, error) { return nil, errors.New(errMsg) } } @@ -90,7 +93,7 @@ func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { func badRequestHandler(msg string) http.HandlerFunc { return func(w http.ResponseWriter, _ *http.Request) { - structuredErrorResponse := github.ErrorResponse{ + structuredErrorResponse := gogithub.ErrorResponse{ Message: msg, } @@ -103,496 +106,116 @@ func badRequestHandler(msg string) http.HandlerFunc { } } -func Test_IsAcceptedError(t *testing.T) { - tests := []struct { - name string - err error - expectAccepted bool - }{ - { - name: "github AcceptedError", - err: &github.AcceptedError{}, - expectAccepted: true, - }, - { - name: "regular error", - err: fmt.Errorf("some other error"), - expectAccepted: false, - }, - { - name: "nil error", - err: nil, - expectAccepted: false, - }, - { - name: "wrapped AcceptedError", - err: fmt.Errorf("wrapped: %w", &github.AcceptedError{}), - expectAccepted: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := isAcceptedError(tc.err) - assert.Equal(t, tc.expectAccepted, result) - }) - } -} +// TestNewMCPServer_CreatesSuccessfully verifies that the server can be created +// with the deps injection middleware properly configured. +func TestNewMCPServer_CreatesSuccessfully(t *testing.T) { + t.Parallel() -func Test_RequiredStringParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected string - expectError bool - }{ - { - name: "valid string parameter", - params: map[string]interface{}{"name": "test-value"}, - paramName: "name", - expected: "test-value", - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "name", - expected: "", - expectError: true, - }, - { - name: "empty string parameter", - params: map[string]interface{}{"name": ""}, - paramName: "name", - expected: "", - expectError: true, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"name": 123}, - paramName: "name", - expected: "", - expectError: true, - }, + // Create a minimal server configuration + cfg := MCPServerConfig{ + Version: "test", + Host: "", // defaults to github.com + Token: "test-token", + EnabledToolsets: []string{"context"}, + ReadOnly: false, + Translator: translations.NullTranslationHelper, + ContentWindowSize: 5000, + LockdownMode: false, + InsidersMode: false, } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := RequiredParam[string](tc.params, tc.paramName) + deps := stubDeps{} - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} + // Build inventory + inv, err := NewInventory(cfg.Translator). + WithDeprecatedAliases(DeprecatedToolAliases). + WithToolsets(cfg.EnabledToolsets). + Build() -func Test_OptionalStringParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected string - expectError bool - }{ - { - name: "valid string parameter", - params: map[string]interface{}{"name": "test-value"}, - paramName: "name", - expected: "test-value", - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "name", - expected: "", - expectError: false, - }, - { - name: "empty string parameter", - params: map[string]interface{}{"name": ""}, - paramName: "name", - expected: "", - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"name": 123}, - paramName: "name", - expected: "", - expectError: true, - }, - } + require.NoError(t, err, "expected inventory build to succeed") - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalParam[string](tc.params, tc.paramName) + // Create the server + server, err := NewMCPServer(context.Background(), &cfg, deps, inv) + require.NoError(t, err, "expected server creation to succeed") + require.NotNil(t, server, "expected server to be non-nil") - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } + // The fact that the server was created successfully indicates that: + // 1. The deps injection middleware is properly added + // 2. Tools can be registered without panicking + // + // If the middleware wasn't properly added, tool calls would panic with + // "ToolDependencies not found in context" when executed. + // + // The actual middleware functionality and tool execution with ContextWithDeps + // is already tested in pkg/github/*_test.go. } -func Test_RequiredInt(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected int - expectError bool - }{ - { - name: "valid number parameter", - params: map[string]interface{}{"count": float64(42)}, - paramName: "count", - expected: 42, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "count", - expected: 0, - expectError: true, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"count": "not-a-number"}, - paramName: "count", - expected: 0, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := RequiredInt(tc.params, tc.paramName) +// TestResolveEnabledToolsets verifies the toolset resolution logic. +func TestResolveEnabledToolsets(t *testing.T) { + t.Parallel() - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} -func Test_OptionalIntParam(t *testing.T) { tests := []struct { - name string - params map[string]interface{} - paramName string - expected int - expectError bool - }{ - { - name: "valid number parameter", - params: map[string]interface{}{"count": float64(42)}, - paramName: "count", - expected: 42, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "count", - expected: 0, - expectError: false, - }, - { - name: "zero value", - params: map[string]interface{}{"count": float64(0)}, - paramName: "count", - expected: 0, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"count": "not-a-number"}, - paramName: "count", - expected: 0, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalIntParam(tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func Test_OptionalNumberParamWithDefault(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - defaultVal int - expected int - expectError bool - }{ - { - name: "valid number parameter", - params: map[string]interface{}{"count": float64(42)}, - paramName: "count", - defaultVal: 10, - expected: 42, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "count", - defaultVal: 10, - expected: 10, - expectError: false, - }, - { - name: "zero value", - params: map[string]interface{}{"count": float64(0)}, - paramName: "count", - defaultVal: 10, - expected: 10, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"count": "not-a-number"}, - paramName: "count", - defaultVal: 10, - expected: 0, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalIntParamWithDefault(tc.params, tc.paramName, tc.defaultVal) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func Test_OptionalBooleanParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected bool - expectError bool - }{ - { - name: "true value", - params: map[string]interface{}{"flag": true}, - paramName: "flag", - expected: true, - expectError: false, - }, - { - name: "false value", - params: map[string]interface{}{"flag": false}, - paramName: "flag", - expected: false, - expectError: false, - }, - { - name: "missing parameter", - params: map[string]interface{}{}, - paramName: "flag", - expected: false, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]interface{}{"flag": "not-a-boolean"}, - paramName: "flag", - expected: false, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalParam[bool](tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func TestOptionalStringArrayParam(t *testing.T) { - tests := []struct { - name string - params map[string]interface{} - paramName string - expected []string - expectError bool - }{ - { - name: "parameter not in request", - params: map[string]any{}, - paramName: "flag", - expected: []string{}, - expectError: false, - }, - { - name: "valid any array parameter", - params: map[string]any{ - "flag": []any{"v1", "v2"}, - }, - paramName: "flag", - expected: []string{"v1", "v2"}, - expectError: false, - }, - { - name: "valid string array parameter", - params: map[string]any{ - "flag": []string{"v1", "v2"}, - }, - paramName: "flag", - expected: []string{"v1", "v2"}, - expectError: false, - }, - { - name: "wrong type parameter", - params: map[string]any{ - "flag": 1, - }, - paramName: "flag", - expected: []string{}, - expectError: true, - }, - { - name: "wrong slice type parameter", - params: map[string]any{ - "flag": []any{"foo", 2}, - }, - paramName: "flag", - expected: []string{}, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := OptionalStringArrayParam(tc.params, tc.paramName) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } - }) - } -} - -func TestOptionalPaginationParams(t *testing.T) { - tests := []struct { - name string - params map[string]any - expected PaginationParams - expectError bool + name string + cfg MCPServerConfig + expectedResult []string }{ { - name: "no pagination parameters, default values", - params: map[string]any{}, - expected: PaginationParams{ - Page: 1, - PerPage: 30, + name: "nil toolsets without dynamic mode and no tools - use defaults", + cfg: MCPServerConfig{ + EnabledToolsets: nil, + DynamicToolsets: false, + EnabledTools: nil, }, - expectError: false, + expectedResult: nil, // nil means "use defaults" }, { - name: "page parameter, default perPage", - params: map[string]any{ - "page": float64(2), + name: "nil toolsets with dynamic mode - start empty", + cfg: MCPServerConfig{ + EnabledToolsets: nil, + DynamicToolsets: true, + EnabledTools: nil, }, - expected: PaginationParams{ - Page: 2, - PerPage: 30, - }, - expectError: false, + expectedResult: []string{}, // empty slice means no toolsets }, { - name: "perPage parameter, default page", - params: map[string]any{ - "perPage": float64(50), - }, - expected: PaginationParams{ - Page: 1, - PerPage: 50, + name: "explicit toolsets", + cfg: MCPServerConfig{ + EnabledToolsets: []string{"repos", "issues"}, + DynamicToolsets: false, }, - expectError: false, + expectedResult: []string{"repos", "issues"}, }, { - name: "page and perPage parameters", - params: map[string]any{ - "page": float64(2), - "perPage": float64(50), + name: "empty toolsets - disable all", + cfg: MCPServerConfig{ + EnabledToolsets: []string{}, + DynamicToolsets: false, }, - expected: PaginationParams{ - Page: 2, - PerPage: 50, - }, - expectError: false, + expectedResult: []string{}, // empty slice means no toolsets }, { - name: "invalid page parameter", - params: map[string]any{ - "page": "not-a-number", + name: "specific tools without toolsets - no default toolsets", + cfg: MCPServerConfig{ + EnabledToolsets: nil, + DynamicToolsets: false, + EnabledTools: []string{"get_me"}, }, - expected: PaginationParams{}, - expectError: true, + expectedResult: []string{}, // empty slice when tools specified but no toolsets }, { - name: "invalid perPage parameter", - params: map[string]any{ - "perPage": "not-a-number", + name: "dynamic mode with explicit toolsets removes all and default", + cfg: MCPServerConfig{ + EnabledToolsets: []string{"all", "repos"}, + DynamicToolsets: true, }, - expected: PaginationParams{}, - expectError: true, + expectedResult: []string{"repos"}, // "all" is removed in dynamic mode }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := OptionalPaginationParams(tc.params) - - if tc.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.expected, result) - } + result := ResolvedEnabledToolsets(tc.cfg.DynamicToolsets, tc.cfg.EnabledToolsets, tc.cfg.EnabledTools) + assert.Equal(t, tc.expectedResult, result) }) } } diff --git a/pkg/http/handler.go b/pkg/http/handler.go new file mode 100644 index 000000000..df0b819fc --- /dev/null +++ b/pkg/http/handler.go @@ -0,0 +1,287 @@ +package http + +import ( + "context" + "log/slog" + "net/http" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/middleware" + "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/scopes" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/go-chi/chi/v5" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type InventoryFactoryFunc func(r *http.Request) (*inventory.Inventory, error) +type GitHubMCPServerFactoryFunc func(r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) + +type Handler struct { + ctx context.Context + config *ServerConfig + deps github.ToolDependencies + logger *slog.Logger + apiHosts utils.APIHostResolver + t translations.TranslationHelperFunc + githubMcpServerFactory GitHubMCPServerFactoryFunc + inventoryFactoryFunc InventoryFactoryFunc + oauthCfg *oauth.Config + scopeFetcher scopes.FetcherInterface +} + +type HandlerOptions struct { + GitHubMcpServerFactory GitHubMCPServerFactoryFunc + InventoryFactory InventoryFactoryFunc + OAuthConfig *oauth.Config + ScopeFetcher scopes.FetcherInterface + FeatureChecker inventory.FeatureFlagChecker +} + +type HandlerOption func(*HandlerOptions) + +func WithScopeFetcher(f scopes.FetcherInterface) HandlerOption { + return func(o *HandlerOptions) { + o.ScopeFetcher = f + } +} + +func WithGitHubMCPServerFactory(f GitHubMCPServerFactoryFunc) HandlerOption { + return func(o *HandlerOptions) { + o.GitHubMcpServerFactory = f + } +} + +func WithInventoryFactory(f InventoryFactoryFunc) HandlerOption { + return func(o *HandlerOptions) { + o.InventoryFactory = f + } +} + +func WithOAuthConfig(cfg *oauth.Config) HandlerOption { + return func(o *HandlerOptions) { + o.OAuthConfig = cfg + } +} + +func WithFeatureChecker(checker inventory.FeatureFlagChecker) HandlerOption { + return func(o *HandlerOptions) { + o.FeatureChecker = checker + } +} + +func NewHTTPMcpHandler( + ctx context.Context, + cfg *ServerConfig, + deps github.ToolDependencies, + t translations.TranslationHelperFunc, + logger *slog.Logger, + apiHost utils.APIHostResolver, + options ...HandlerOption) *Handler { + opts := &HandlerOptions{} + for _, o := range options { + o(opts) + } + + githubMcpServerFactory := opts.GitHubMcpServerFactory + if githubMcpServerFactory == nil { + githubMcpServerFactory = DefaultGitHubMCPServerFactory + } + + scopeFetcher := opts.ScopeFetcher + if scopeFetcher == nil { + scopeFetcher = scopes.NewFetcher(apiHost, scopes.FetcherOptions{}) + } + + inventoryFactory := opts.InventoryFactory + if inventoryFactory == nil { + inventoryFactory = DefaultInventoryFactory(cfg, t, opts.FeatureChecker, scopeFetcher) + } + + return &Handler{ + ctx: ctx, + config: cfg, + deps: deps, + logger: logger, + apiHosts: apiHost, + t: t, + githubMcpServerFactory: githubMcpServerFactory, + inventoryFactoryFunc: inventoryFactory, + oauthCfg: opts.OAuthConfig, + scopeFetcher: scopeFetcher, + } +} + +func (h *Handler) RegisterMiddleware(r chi.Router) { + r.Use( + middleware.ExtractUserToken(h.oauthCfg), + middleware.WithRequestConfig, + middleware.WithMCPParse(), + middleware.WithPATScopes(h.logger, h.scopeFetcher), + ) + + if h.config.ScopeChallenge { + r.Use(middleware.WithScopeChallenge(h.oauthCfg, h.scopeFetcher)) + } +} + +// RegisterRoutes registers the routes for the MCP server +// URL-based values take precedence over header-based values +func (h *Handler) RegisterRoutes(r chi.Router) { + // Base routes + r.Mount("/", h) + r.With(withReadonly).Mount("/readonly", h) + r.With(withInsiders).Mount("/insiders", h) + r.With(withReadonly, withInsiders).Mount("/readonly/insiders", h) + + // Toolset routes + r.With(withToolset).Mount("/x/{toolset}", h) + r.With(withToolset, withReadonly).Mount("/x/{toolset}/readonly", h) + r.With(withToolset, withInsiders).Mount("/x/{toolset}/insiders", h) + r.With(withToolset, withReadonly, withInsiders).Mount("/x/{toolset}/readonly/insiders", h) +} + +// withReadonly is middleware that sets readonly mode in the request context +func withReadonly(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := ghcontext.WithReadonly(r.Context(), true) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// withToolset is middleware that extracts the toolset from the URL and sets it in the request context +func withToolset(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + toolset := chi.URLParam(r, "toolset") + ctx := ghcontext.WithToolsets(r.Context(), []string{toolset}) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// withInsiders is middleware that sets insiders mode in the request context +func withInsiders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := ghcontext.WithInsidersMode(r.Context(), true) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + inv, err := h.inventoryFactoryFunc(r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + invToUse := inv + if methodInfo, ok := ghcontext.MCPMethod(r.Context()); ok && methodInfo != nil { + invToUse = inv.ForMCPRequest(methodInfo.Method, methodInfo.ItemName) + } + + ghServer, err := h.githubMcpServerFactory(r, h.deps, invToUse, &github.MCPServerConfig{ + Version: h.config.Version, + Translator: h.t, + ContentWindowSize: h.config.ContentWindowSize, + Logger: h.logger, + RepoAccessTTL: h.config.RepoAccessCacheTTL, + // Explicitly set empty capabilities. inv.ForMCPRequest currently returns nothing for Initialize. + ServerOptions: []github.MCPServerOption{ + func(so *mcp.ServerOptions) { + so.Capabilities = &mcp.ServerCapabilities{ + Tools: &mcp.ToolCapabilities{}, + Resources: &mcp.ResourceCapabilities{}, + Prompts: &mcp.PromptCapabilities{}, + } + }, + }, + }) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + mcpHandler := mcp.NewStreamableHTTPHandler(func(_ *http.Request) *mcp.Server { + return ghServer + }, &mcp.StreamableHTTPOptions{ + Stateless: true, + }) + + mcpHandler.ServeHTTP(w, r) +} + +func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) { + return github.NewMCPServer(r.Context(), cfg, deps, inventory) +} + +// DefaultInventoryFactory creates the default inventory factory for HTTP mode +func DefaultInventoryFactory(_ *ServerConfig, t translations.TranslationHelperFunc, featureChecker inventory.FeatureFlagChecker, scopeFetcher scopes.FetcherInterface) InventoryFactoryFunc { + return func(r *http.Request) (*inventory.Inventory, error) { + b := github.NewInventory(t). + WithDeprecatedAliases(github.DeprecatedToolAliases). + WithFeatureChecker(featureChecker) + + b = InventoryFiltersForRequest(r, b) + b = PATScopeFilter(b, r, scopeFetcher) + + b.WithServerInstructions() + + return b.Build() + } +} + +// InventoryFiltersForRequest applies filters to the inventory builder +// based on the request context and headers +func InventoryFiltersForRequest(r *http.Request, builder *inventory.Builder) *inventory.Builder { + ctx := r.Context() + + if ghcontext.IsReadonly(ctx) { + builder = builder.WithReadOnly(true) + } + + toolsets := ghcontext.GetToolsets(ctx) + tools := ghcontext.GetTools(ctx) + + if len(toolsets) > 0 { + builder = builder.WithToolsets(github.ResolvedEnabledToolsets(false, toolsets, tools)) // No dynamic toolsets in HTTP mode + } + + if len(tools) > 0 { + if len(toolsets) == 0 { + builder = builder.WithToolsets([]string{}) + } + builder = builder.WithTools(github.CleanTools(tools)) + } + + return builder +} + +func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.FetcherInterface) *inventory.Builder { + ctx := r.Context() + + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok || tokenInfo == nil { + return b + } + + // Scopes should have already been fetched by the WithPATScopes middleware. + // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. + // Fine-grained PATs and other token types don't support this, so we skip filtering. + if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { + if tokenInfo.ScopesFetched { + return b.WithFilter(github.CreateToolScopeFilter(tokenInfo.Scopes)) + } + + scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token) + if err != nil { + return b + } + + return b.WithFilter(github.CreateToolScopeFilter(scopesList)) + } + + return b +} diff --git a/pkg/http/handler_test.go b/pkg/http/handler_test.go new file mode 100644 index 000000000..c92075569 --- /dev/null +++ b/pkg/http/handler_test.go @@ -0,0 +1,348 @@ +package http + +import ( + "context" + "log/slog" + "net/http" + "net/http/httptest" + "sort" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/scopes" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/go-chi/chi/v5" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mockTool(name, toolsetID string, readOnly bool) inventory.ServerTool { + return inventory.ServerTool{ + Tool: mcp.Tool{ + Name: name, + Annotations: &mcp.ToolAnnotations{ReadOnlyHint: readOnly}, + }, + Toolset: inventory.ToolsetMetadata{ + ID: inventory.ToolsetID(toolsetID), + Description: "Test: " + toolsetID, + }, + } +} + +type allScopesFetcher struct{} + +func (f allScopesFetcher) FetchTokenScopes(_ context.Context, _ string) ([]string, error) { + return []string{ + string(scopes.Repo), + string(scopes.WriteOrg), + string(scopes.User), + string(scopes.Gist), + string(scopes.Notifications), + }, nil +} + +var _ scopes.FetcherInterface = allScopesFetcher{} + +func mockToolWithFeatureFlag(name, toolsetID string, readOnly bool, enableFlag, disableFlag string) inventory.ServerTool { + tool := mockTool(name, toolsetID, readOnly) + tool.FeatureFlagEnable = enableFlag + tool.FeatureFlagDisable = disableFlag + return tool +} + +func TestInventoryFiltersForRequest(t *testing.T) { + tools := []inventory.ServerTool{ + mockTool("get_file_contents", "repos", true), + mockTool("create_repository", "repos", false), + mockTool("list_issues", "issues", true), + mockTool("issue_write", "issues", false), + } + + tests := []struct { + name string + contextSetup func(context.Context) context.Context + expectedTools []string + }{ + { + name: "no filters applies defaults", + contextSetup: func(ctx context.Context) context.Context { return ctx }, + expectedTools: []string{"get_file_contents", "create_repository", "list_issues", "issue_write"}, + }, + { + name: "readonly from context filters write tools", + contextSetup: func(ctx context.Context) context.Context { + return ghcontext.WithReadonly(ctx, true) + }, + expectedTools: []string{"get_file_contents", "list_issues"}, + }, + { + name: "toolset from context filters to toolset", + contextSetup: func(ctx context.Context) context.Context { + return ghcontext.WithToolsets(ctx, []string{"repos"}) + }, + expectedTools: []string{"get_file_contents", "create_repository"}, + }, + { + name: "tools alone clears default toolsets", + contextSetup: func(ctx context.Context) context.Context { + return ghcontext.WithTools(ctx, []string{"list_issues"}) + }, + expectedTools: []string{"list_issues"}, + }, + { + name: "tools are additive with toolsets", + contextSetup: func(ctx context.Context) context.Context { + ctx = ghcontext.WithToolsets(ctx, []string{"repos"}) + ctx = ghcontext.WithTools(ctx, []string{"list_issues"}) + return ctx + }, + expectedTools: []string{"get_file_contents", "create_repository", "list_issues"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req = req.WithContext(tt.contextSetup(req.Context())) + + builder := inventory.NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}) + + builder = InventoryFiltersForRequest(req, builder) + inv, err := builder.Build() + require.NoError(t, err) + + available := inv.AvailableTools(context.Background()) + toolNames := make([]string, len(available)) + for i, tool := range available { + toolNames[i] = tool.Tool.Name + } + + assert.ElementsMatch(t, tt.expectedTools, toolNames) + }) + } +} + +// testTools returns a set of mock tools across different toolsets with mixed read-only/write capabilities +func testTools() []inventory.ServerTool { + return []inventory.ServerTool{ + mockTool("get_file_contents", "repos", true), + mockTool("create_repository", "repos", false), + mockTool("list_issues", "issues", true), + mockTool("create_issue", "issues", false), + mockTool("list_pull_requests", "pull_requests", true), + mockTool("create_pull_request", "pull_requests", false), + // Feature-flagged tools for testing X-MCP-Features header + mockToolWithFeatureFlag("needs_holdback", "repos", true, "mcp_holdback_consolidated_projects", ""), + mockToolWithFeatureFlag("hidden_by_holdback", "repos", true, "", "mcp_holdback_consolidated_projects"), + } +} + +// extractToolNames extracts tool names from an inventory +func extractToolNames(ctx context.Context, inv *inventory.Inventory) []string { + available := inv.AvailableTools(ctx) + names := make([]string, len(available)) + for i, tool := range available { + names[i] = tool.Tool.Name + } + sort.Strings(names) + return names +} + +func TestHTTPHandlerRoutes(t *testing.T) { + tools := testTools() + + tests := []struct { + name string + path string + headers map[string]string + expectedTools []string + }{ + { + name: "root path returns all tools", + path: "/", + expectedTools: []string{"get_file_contents", "create_repository", "list_issues", "create_issue", "list_pull_requests", "create_pull_request", "hidden_by_holdback"}, + }, + { + name: "readonly path filters write tools", + path: "/readonly", + expectedTools: []string{"get_file_contents", "list_issues", "list_pull_requests", "hidden_by_holdback"}, + }, + { + name: "toolset path filters to toolset", + path: "/x/repos", + expectedTools: []string{"get_file_contents", "create_repository", "hidden_by_holdback"}, + }, + { + name: "toolset path with issues", + path: "/x/issues", + expectedTools: []string{"list_issues", "create_issue"}, + }, + { + name: "toolset readonly path filters to readonly tools in toolset", + path: "/x/repos/readonly", + expectedTools: []string{"get_file_contents", "hidden_by_holdback"}, + }, + { + name: "toolset readonly path with issues", + path: "/x/issues/readonly", + expectedTools: []string{"list_issues"}, + }, + { + name: "X-MCP-Tools header filters to specific tools", + path: "/", + headers: map[string]string{ + headers.MCPToolsHeader: "list_issues", + }, + expectedTools: []string{"list_issues"}, + }, + { + name: "X-MCP-Tools header with multiple tools", + path: "/", + headers: map[string]string{ + headers.MCPToolsHeader: "list_issues,get_file_contents", + }, + expectedTools: []string{"list_issues", "get_file_contents"}, + }, + { + name: "X-MCP-Tools header does not expose extra tools", + path: "/", + headers: map[string]string{ + headers.MCPToolsHeader: "list_issues", + }, + expectedTools: []string{"list_issues"}, + }, + { + name: "X-MCP-Readonly header filters write tools", + path: "/", + headers: map[string]string{ + headers.MCPReadOnlyHeader: "true", + }, + expectedTools: []string{"get_file_contents", "list_issues", "list_pull_requests", "hidden_by_holdback"}, + }, + { + name: "X-MCP-Toolsets header filters to toolset", + path: "/", + headers: map[string]string{ + headers.MCPToolsetsHeader: "repos", + }, + expectedTools: []string{"get_file_contents", "create_repository", "hidden_by_holdback"}, + }, + { + name: "URL toolset takes precedence over header toolset", + path: "/x/issues", + headers: map[string]string{ + headers.MCPToolsetsHeader: "repos", + }, + expectedTools: []string{"list_issues", "create_issue"}, + }, + { + name: "URL readonly takes precedence over header", + path: "/readonly", + headers: map[string]string{ + headers.MCPReadOnlyHeader: "false", + }, + expectedTools: []string{"get_file_contents", "list_issues", "list_pull_requests", "hidden_by_holdback"}, + }, + { + name: "X-MCP-Features header enables flagged tool", + path: "/", + headers: map[string]string{ + headers.MCPFeaturesHeader: "mcp_holdback_consolidated_projects", + }, + expectedTools: []string{"get_file_contents", "create_repository", "list_issues", "create_issue", "list_pull_requests", "create_pull_request", "needs_holdback"}, + }, + { + name: "X-MCP-Features header with unknown flag is ignored", + path: "/", + headers: map[string]string{ + headers.MCPFeaturesHeader: "unknown_flag", + }, + expectedTools: []string{"get_file_contents", "create_repository", "list_issues", "create_issue", "list_pull_requests", "create_pull_request", "hidden_by_holdback"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedInventory *inventory.Inventory + var capturedCtx context.Context + + // Create feature checker that reads from context (same as production) + featureChecker := createHTTPFeatureChecker() + + apiHost, err := utils.NewAPIHost("https://api.github.com") + require.NoError(t, err) + + // Create inventory factory that captures the built inventory + inventoryFactory := func(r *http.Request) (*inventory.Inventory, error) { + capturedCtx = r.Context() + builder := inventory.NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFeatureChecker(featureChecker) + builder = InventoryFiltersForRequest(r, builder) + inv, err := builder.Build() + if err != nil { + return nil, err + } + capturedInventory = inv + return inv, nil + } + + // Create mock MCP server factory that just returns a minimal server + mcpServerFactory := func(_ *http.Request, _ github.ToolDependencies, _ *inventory.Inventory, _ *github.MCPServerConfig) (*mcp.Server, error) { + return mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil), nil + } + + allScopesFetcher := allScopesFetcher{} + + // Create handler with our factories + handler := NewHTTPMcpHandler( + context.Background(), + &ServerConfig{Version: "test"}, + nil, // deps not needed for this test + translations.NullTranslationHelper, + slog.Default(), + apiHost, + WithInventoryFactory(inventoryFactory), + WithGitHubMCPServerFactory(mcpServerFactory), + WithScopeFetcher(allScopesFetcher), + ) + + // Create router and register routes + r := chi.NewRouter() + handler.RegisterMiddleware(r) + handler.RegisterRoutes(r) + + // Create request + req := httptest.NewRequest(http.MethodPost, tt.path, nil) + + // Ensure we're setting Authorization header for token context + req.Header.Set(headers.AuthorizationHeader, "Bearer ghp_testtoken") + + for k, v := range tt.headers { + req.Header.Set(k, v) + } + + // Execute request + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // Verify the inventory was captured and has the expected tools + require.NotNil(t, capturedInventory, "inventory should have been created") + + toolNames := extractToolNames(capturedCtx, capturedInventory) + expectedSorted := make([]string, len(tt.expectedTools)) + copy(expectedSorted, tt.expectedTools) + sort.Strings(expectedSorted) + + assert.Equal(t, expectedSorted, toolNames, "tools should match expected") + }) + } +} diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go new file mode 100644 index 000000000..bbc46b43f --- /dev/null +++ b/pkg/http/headers/headers.go @@ -0,0 +1,53 @@ +package headers + +const ( + // AuthorizationHeader is a standard HTTP Header. + AuthorizationHeader = "Authorization" + // ContentTypeHeader is a standard HTTP Header. + ContentTypeHeader = "Content-Type" + // AcceptHeader is a standard HTTP Header. + AcceptHeader = "Accept" + // UserAgentHeader is a standard HTTP Header. + UserAgentHeader = "User-Agent" + + // ContentTypeJSON is the standard MIME type for JSON. + ContentTypeJSON = "application/json" + // ContentTypeEventStream is the standard MIME type for Event Streams. + ContentTypeEventStream = "text/event-stream" + + // ForwardedForHeader is a standard HTTP Header used to forward the originating IP address of a client. + ForwardedForHeader = "X-Forwarded-For" + + // RealIPHeader is a standard HTTP Header used to indicate the real IP address of the client. + RealIPHeader = "X-Real-IP" + + // ForwardedHostHeader is a standard HTTP Header for preserving the original Host header when proxying. + ForwardedHostHeader = "X-Forwarded-Host" + // ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying. + ForwardedProtoHeader = "X-Forwarded-Proto" + + // RequestHmacHeader is used to authenticate requests to the Raw API. + RequestHmacHeader = "Request-Hmac" + + // MCP-specific headers. + + // MCPReadOnlyHeader indicates whether the MCP is in read-only mode. + MCPReadOnlyHeader = "X-MCP-Readonly" + // MCPToolsetsHeader is a comma-separated list of MCP toolsets that the request is for. + MCPToolsetsHeader = "X-MCP-Toolsets" + // MCPToolsHeader is a comma-separated list of MCP tools that the request is for. + MCPToolsHeader = "X-MCP-Tools" + // MCPLockdownHeader indicates whether lockdown mode is enabled. + MCPLockdownHeader = "X-MCP-Lockdown" + // MCPInsidersHeader indicates whether insiders mode is enabled for early access features. + MCPInsidersHeader = "X-MCP-Insiders" + // MCPFeaturesHeader is a comma-separated list of feature flags to enable. + MCPFeaturesHeader = "X-MCP-Features" + + // GitHub-specific headers. + + // GraphQLFeaturesHeader is a comma-separated list of GraphQL feature flags to enable for GraphQL requests. + GraphQLFeaturesHeader = "GraphQL-Features" + // GitHubAPIVersionHeader is the header used to specify the GitHub API version. + GitHubAPIVersionHeader = "X-GitHub-Api-Version" +) diff --git a/pkg/http/headers/parse.go b/pkg/http/headers/parse.go new file mode 100644 index 000000000..2b5eddacd --- /dev/null +++ b/pkg/http/headers/parse.go @@ -0,0 +1,21 @@ +package headers + +import "strings" + +// ParseCommaSeparated splits a header value by comma, trims whitespace, +// and filters out empty values +func ParseCommaSeparated(value string) []string { + if value == "" { + return []string{} + } + + parts := strings.Split(value, ",") + result := make([]string, 0, len(parts)) + for _, p := range parts { + trimmed := strings.TrimSpace(p) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} diff --git a/pkg/http/headers/parse_test.go b/pkg/http/headers/parse_test.go new file mode 100644 index 000000000..d8b55a696 --- /dev/null +++ b/pkg/http/headers/parse_test.go @@ -0,0 +1,58 @@ +package headers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseCommaSeparated(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "empty string", + input: "", + expected: []string{}, + }, + { + name: "single value", + input: "foo", + expected: []string{"foo"}, + }, + { + name: "multiple values", + input: "foo,bar,baz", + expected: []string{"foo", "bar", "baz"}, + }, + { + name: "whitespace trimmed", + input: " foo , bar , baz ", + expected: []string{"foo", "bar", "baz"}, + }, + { + name: "empty values filtered", + input: "foo,,bar,", + expected: []string{"foo", "bar"}, + }, + { + name: "only commas", + input: ",,,", + expected: []string{}, + }, + { + name: "whitespace only values filtered", + input: "foo, ,bar", + expected: []string{"foo", "bar"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseCommaSeparated(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/http/mark/mark.go b/pkg/http/mark/mark.go new file mode 100644 index 000000000..859a30923 --- /dev/null +++ b/pkg/http/mark/mark.go @@ -0,0 +1,65 @@ +// Package mark provides a mechanism for tagging errors with a well-known error value. +package mark + +import "errors" + +// This list of errors is not exhaustive, but is a good starting point for most +// applications. Feel free to add more as needed, but don't go overboard. +// Remember, the specific types of errors are only important so far as someone +// calling your code might want to write logic to handle each type of error +// differently. +// +// Do not add application-specific errors to this list. Instead, just define +// your own package with your own application-specific errors, and use this +// package to mark errors with them. The errors in this package are not special, +// they're just plain old errors. +// +// Not all errors need to be marked. An error that is not marked should be +// treated as an unexpected error that cannot be handled by calling code. This +// is often the case for network errors or logic errors. +var ( + ErrNotFound = errors.New("not found") + ErrAlreadyExists = errors.New("already exists") + ErrBadRequest = errors.New("bad request") + ErrUnauthorized = errors.New("unauthorized") + ErrCancelled = errors.New("request cancelled") + ErrUnavailable = errors.New("unavailable") + ErrTimedout = errors.New("request timed out") + ErrTooLarge = errors.New("request is too large") + ErrTooManyRequests = errors.New("too many requests") + ErrForbidden = errors.New("forbidden") +) + +// With wraps err with another error that will return true from errors.Is and +// errors.As for both err and markErr, and anything either may wrap. +func With(err, markErr error) error { + if err == nil { + return nil + } + return marked{wrapped: err, mark: markErr} +} + +type marked struct { + wrapped error + mark error +} + +func (f marked) Is(target error) bool { + // if this is false, errors.Is will call unwrap and retry on the wrapped + // error. + return errors.Is(f.mark, target) +} + +func (f marked) As(target any) bool { + // if this is false, errors.As will call unwrap and retry on the wrapped + // error. + return errors.As(f.mark, target) +} + +func (f marked) Unwrap() error { + return f.wrapped +} + +func (f marked) Error() string { + return f.mark.Error() + ": " + f.wrapped.Error() +} diff --git a/pkg/http/middleware/mcp_parse.go b/pkg/http/middleware/mcp_parse.go new file mode 100644 index 000000000..c82616b27 --- /dev/null +++ b/pkg/http/middleware/mcp_parse.go @@ -0,0 +1,126 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + + ghcontext "github.com/github/github-mcp-server/pkg/context" +) + +// mcpJSONRPCRequest represents the structure of an MCP JSON-RPC request. +// We only parse the fields needed for routing and optimization. +type mcpJSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params struct { + // For tools/call + Name string `json:"name,omitempty"` + Arguments json.RawMessage `json:"arguments,omitempty"` + // For prompts/get + // Name is shared with tools/call + // For resources/read + URI string `json:"uri,omitempty"` + } `json:"params"` +} + +// WithMCPParse creates a middleware that parses MCP JSON-RPC requests early in the +// request lifecycle and stores the parsed information in the request context. +// This enables: +// - Registry filtering via ForMCPRequest (only register needed tools/resources/prompts) +// - Avoiding duplicate JSON parsing in downstream middlewares +// - Access to owner/repo for secret-scanning middleware +// +// The middleware reads the request body, parses it, restores the body for downstream +// handlers, and stores the parsed MCPMethodInfo in the request context. +func WithMCPParse() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Skip health check endpoints + if r.URL.Path == "/_ping" { + next.ServeHTTP(w, r) + return + } + + // Only parse POST requests (MCP uses JSON-RPC over POST) + if r.Method != http.MethodPost { + next.ServeHTTP(w, r) + return + } + + // Read the request body + body, err := io.ReadAll(r.Body) + if err != nil { + // Log but continue - don't block requests on parse errors + next.ServeHTTP(w, r) + return + } + + // Restore the body for downstream handlers + r.Body = io.NopCloser(bytes.NewReader(body)) + + // Skip empty bodies + if len(body) == 0 { + next.ServeHTTP(w, r) + return + } + + // Parse the JSON-RPC request + var mcpReq mcpJSONRPCRequest + err = json.Unmarshal(body, &mcpReq) + if err != nil { + // Log but continue - could be a non-MCP request or malformed JSON + next.ServeHTTP(w, r) + return + } + + // Skip if not a valid JSON-RPC 2.0 request + if mcpReq.JSONRPC != "2.0" || mcpReq.Method == "" { + next.ServeHTTP(w, r) + return + } + + // Build the MCPMethodInfo + methodInfo := &ghcontext.MCPMethodInfo{ + Method: mcpReq.Method, + } + + // Extract item name based on method type + + switch mcpReq.Method { + case "tools/call": + methodInfo.ItemName = mcpReq.Params.Name + // Parse arguments if present + if len(mcpReq.Params.Arguments) > 0 { + var args map[string]any + err := json.Unmarshal(mcpReq.Params.Arguments, &args) + if err == nil { + methodInfo.Arguments = args + // Extract owner and repo if present + if owner, ok := args["owner"].(string); ok { + methodInfo.Owner = owner + } + if repo, ok := args["repo"].(string); ok { + methodInfo.Repo = repo + } + } + } + case "prompts/get": + methodInfo.ItemName = mcpReq.Params.Name + case "resources/read": + methodInfo.ItemName = mcpReq.Params.URI + default: + // Whatever + } + + // Store the parsed info in context + ctx = ghcontext.WithMCPMethodInfo(ctx, methodInfo) + + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} diff --git a/pkg/http/middleware/mcp_parse_test.go b/pkg/http/middleware/mcp_parse_test.go new file mode 100644 index 000000000..5a28a30c3 --- /dev/null +++ b/pkg/http/middleware/mcp_parse_test.go @@ -0,0 +1,191 @@ +package middleware + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWithMCPParse(t *testing.T) { + tests := []struct { + name string + method string + path string + body string + expectInfo bool + expectedMethod string + expectedItem string + expectedOwner string + expectedRepo string + expectedArgs map[string]any + }{ + { + name: "health check path is skipped", + method: http.MethodPost, + path: "/_ping", + body: `{"jsonrpc":"2.0","method":"tools/list"}`, + expectInfo: false, + }, + { + name: "GET request is skipped", + method: http.MethodGet, + path: "/mcp", + body: `{"jsonrpc":"2.0","method":"tools/list"}`, + expectInfo: false, + }, + { + name: "empty body is skipped", + method: http.MethodPost, + path: "/mcp", + body: "", + expectInfo: false, + }, + { + name: "invalid JSON is skipped", + method: http.MethodPost, + path: "/mcp", + body: "not valid json", + expectInfo: false, + }, + { + name: "non-JSON-RPC 2.0 is skipped", + method: http.MethodPost, + path: "/mcp", + body: `{"jsonrpc":"1.0","method":"tools/list"}`, + expectInfo: false, + }, + { + name: "empty method is skipped", + method: http.MethodPost, + path: "/mcp", + body: `{"jsonrpc":"2.0","method":""}`, + expectInfo: false, + }, + { + name: "tools/list parses method only", + method: http.MethodPost, + path: "/mcp", + body: `{"jsonrpc":"2.0","method":"tools/list"}`, + expectInfo: true, + expectedMethod: "tools/list", + }, + { + name: "tools/call parses name", + method: http.MethodPost, + path: "/mcp", + body: `{"jsonrpc":"2.0","method":"tools/call","params":{"name":"get_file_contents"}}`, + expectInfo: true, + expectedMethod: "tools/call", + expectedItem: "get_file_contents", + }, + { + name: "tools/call parses owner and repo from arguments", + method: http.MethodPost, + path: "/mcp", + body: `{"jsonrpc":"2.0","method":"tools/call","params":{"name":"get_file_contents","arguments":{"owner":"github","repo":"github-mcp-server","path":"README.md"}}}`, + expectInfo: true, + expectedMethod: "tools/call", + expectedItem: "get_file_contents", + expectedOwner: "github", + expectedRepo: "github-mcp-server", + expectedArgs: map[string]any{"owner": "github", "repo": "github-mcp-server", "path": "README.md"}, + }, + { + name: "tools/call with invalid arguments JSON continues without args", + method: http.MethodPost, + path: "/mcp", + body: `{"jsonrpc":"2.0","method":"tools/call","params":{"name":"get_file_contents","arguments":"not an object"}}`, + expectInfo: true, + expectedMethod: "tools/call", + expectedItem: "get_file_contents", + }, + { + name: "prompts/get parses name", + method: http.MethodPost, + path: "/mcp", + body: `{"jsonrpc":"2.0","method":"prompts/get","params":{"name":"my_prompt"}}`, + expectInfo: true, + expectedMethod: "prompts/get", + expectedItem: "my_prompt", + }, + { + name: "resources/read parses URI as item name", + method: http.MethodPost, + path: "/mcp", + body: `{"jsonrpc":"2.0","method":"resources/read","params":{"uri":"repo://github/github-mcp-server"}}`, + expectInfo: true, + expectedMethod: "resources/read", + expectedItem: "repo://github/github-mcp-server", + }, + { + name: "initialize method parses correctly", + method: http.MethodPost, + path: "/mcp", + body: `{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{}}}`, + expectInfo: true, + expectedMethod: "initialize", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedInfo *ghcontext.MCPMethodInfo + var infoCaptured bool + + // Create a handler that captures the MCPMethodInfo from context + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedInfo, infoCaptured = ghcontext.MCPMethod(r.Context()) + }) + + middleware := WithMCPParse() + handler := middleware(nextHandler) + + req := httptest.NewRequest(tt.method, tt.path, strings.NewReader(tt.body)) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if tt.expectInfo { + require.True(t, infoCaptured, "MCPMethodInfo should be present in context") + require.NotNil(t, capturedInfo) + assert.Equal(t, tt.expectedMethod, capturedInfo.Method) + assert.Equal(t, tt.expectedItem, capturedInfo.ItemName) + assert.Equal(t, tt.expectedOwner, capturedInfo.Owner) + assert.Equal(t, tt.expectedRepo, capturedInfo.Repo) + if tt.expectedArgs != nil { + assert.Equal(t, tt.expectedArgs, capturedInfo.Arguments) + } + } else { + assert.False(t, infoCaptured, "MCPMethodInfo should not be present in context") + } + }) + } +} + +func TestWithMCPParse_BodyRestoration(t *testing.T) { + originalBody := `{"jsonrpc":"2.0","method":"tools/call","params":{"name":"test_tool"}}` + + var capturedBody string + + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + capturedBody = string(body) + }) + + middleware := WithMCPParse() + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodPost, "/mcp", strings.NewReader(originalBody)) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, originalBody, capturedBody, "body should be restored for downstream handlers") +} diff --git a/pkg/http/middleware/pat_scope.go b/pkg/http/middleware/pat_scope.go new file mode 100644 index 000000000..8b77b3d32 --- /dev/null +++ b/pkg/http/middleware/pat_scope.go @@ -0,0 +1,50 @@ +package middleware + +import ( + "log/slog" + "net/http" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/scopes" + "github.com/github/github-mcp-server/pkg/utils" +) + +// WithPATScopes is a middleware that fetches and stores scopes for classic Personal Access Tokens (PATs) in the request context. +func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok || tokenInfo == nil { + logger.Warn("no token info found in context") + next.ServeHTTP(w, r) + return + } + + // Fetch token scopes for scope-based tool filtering (PAT tokens only) + // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. + // Fine-grained PATs and other token types don't support this, so we skip filtering. + if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { + scopesList, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) + if err != nil { + logger.Warn("failed to fetch PAT scopes", "error", err) + next.ServeHTTP(w, r) + return + } + + tokenInfo.Scopes = scopesList + tokenInfo.ScopesFetched = true + + // Store fetched scopes in context for downstream use + ctx := ghcontext.WithTokenInfo(ctx, tokenInfo) + + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} diff --git a/pkg/http/middleware/pat_scope_test.go b/pkg/http/middleware/pat_scope_test.go new file mode 100644 index 000000000..eb472bcf1 --- /dev/null +++ b/pkg/http/middleware/pat_scope_test.go @@ -0,0 +1,187 @@ +package middleware + +import ( + "context" + "errors" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockScopeFetcher is a mock implementation of scopes.FetcherInterface +type mockScopeFetcher struct { + scopes []string + err error +} + +func (m *mockScopeFetcher) FetchTokenScopes(_ context.Context, _ string) ([]string, error) { + return m.scopes, m.err +} + +func TestWithPATScopes(t *testing.T) { + logger := slog.Default() + + tests := []struct { + name string + tokenInfo *ghcontext.TokenInfo + fetcherScopes []string + fetcherErr error + expectScopesFetched bool + expectedScopes []string + expectNextHandlerCalled bool + }{ + { + name: "no token info in context calls next handler", + tokenInfo: nil, + expectScopesFetched: false, + expectedScopes: nil, + expectNextHandlerCalled: true, + }, + { + name: "non-PAT token type skips scope fetching", + tokenInfo: &ghcontext.TokenInfo{ + Token: "gho_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + TokenType: utils.TokenTypeOAuthAccessToken, + }, + expectScopesFetched: false, + expectedScopes: nil, + expectNextHandlerCalled: true, + }, + { + name: "fine-grained PAT skips scope fetching", + tokenInfo: &ghcontext.TokenInfo{ + Token: "github_pat_xxxxxxxxxxxxxxxxxxxxxxx", + TokenType: utils.TokenTypeFineGrainedPersonalAccessToken, + }, + expectScopesFetched: false, + expectedScopes: nil, + expectNextHandlerCalled: true, + }, + { + name: "classic PAT fetches and stores scopes", + tokenInfo: &ghcontext.TokenInfo{ + Token: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + TokenType: utils.TokenTypePersonalAccessToken, + }, + fetcherScopes: []string{"repo", "user", "read:org"}, + expectScopesFetched: true, + expectedScopes: []string{"repo", "user", "read:org"}, + expectNextHandlerCalled: true, + }, + { + name: "classic PAT with empty scopes", + tokenInfo: &ghcontext.TokenInfo{ + Token: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + TokenType: utils.TokenTypePersonalAccessToken, + }, + fetcherScopes: []string{}, + expectScopesFetched: true, + expectedScopes: []string{}, + expectNextHandlerCalled: true, + }, + { + name: "fetcher error calls next handler without scopes", + tokenInfo: &ghcontext.TokenInfo{ + Token: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + TokenType: utils.TokenTypePersonalAccessToken, + }, + fetcherErr: errors.New("network error"), + expectScopesFetched: false, + expectedScopes: nil, + expectNextHandlerCalled: true, + }, + { + name: "old-style PAT (40 hex chars) fetches scopes", + tokenInfo: &ghcontext.TokenInfo{ + Token: "0123456789abcdef0123456789abcdef01234567", + TokenType: utils.TokenTypePersonalAccessToken, + }, + fetcherScopes: []string{"repo"}, + expectScopesFetched: true, + expectedScopes: []string{"repo"}, + expectNextHandlerCalled: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedTokenInfo *ghcontext.TokenInfo + var nextHandlerCalled bool + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextHandlerCalled = true + capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + fetcher := &mockScopeFetcher{ + scopes: tt.fetcherScopes, + err: tt.fetcherErr, + } + + middleware := WithPATScopes(logger, fetcher) + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + // Set up context with token info if provided + if tt.tokenInfo != nil { + ctx := ghcontext.WithTokenInfo(req.Context(), tt.tokenInfo) + req = req.WithContext(ctx) + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectNextHandlerCalled, nextHandlerCalled, "next handler called mismatch") + + if tt.expectNextHandlerCalled && tt.tokenInfo != nil { + require.NotNil(t, capturedTokenInfo, "expected token info in context") + assert.Equal(t, tt.expectScopesFetched, capturedTokenInfo.ScopesFetched) + assert.Equal(t, tt.expectedScopes, capturedTokenInfo.Scopes) + } + }) + } +} + +func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) { + logger := slog.Default() + + var capturedTokenInfo *ghcontext.TokenInfo + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + fetcher := &mockScopeFetcher{ + scopes: []string{"repo", "user"}, + } + + originalTokenInfo := &ghcontext.TokenInfo{ + Token: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + TokenType: utils.TokenTypePersonalAccessToken, + } + + middleware := WithPATScopes(logger, fetcher) + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := ghcontext.WithTokenInfo(req.Context(), originalTokenInfo) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.NotNil(t, capturedTokenInfo) + assert.Equal(t, originalTokenInfo.Token, capturedTokenInfo.Token) + assert.Equal(t, originalTokenInfo.TokenType, capturedTokenInfo.TokenType) + assert.True(t, capturedTokenInfo.ScopesFetched) + assert.Equal(t, []string{"repo", "user"}, capturedTokenInfo.Scopes) +} diff --git a/pkg/http/middleware/request_config.go b/pkg/http/middleware/request_config.go new file mode 100644 index 000000000..5cabe16eb --- /dev/null +++ b/pkg/http/middleware/request_config.go @@ -0,0 +1,59 @@ +package middleware + +import ( + "net/http" + "slices" + "strings" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/headers" +) + +// WithRequestConfig is a middleware that extracts MCP-related headers and sets them in the request context. +// This includes readonly mode, toolsets, tools, lockdown mode, insiders mode, and feature flags. +func WithRequestConfig(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Readonly mode + if relaxedParseBool(r.Header.Get(headers.MCPReadOnlyHeader)) { + ctx = ghcontext.WithReadonly(ctx, true) + } + + // Toolsets + if toolsets := headers.ParseCommaSeparated(r.Header.Get(headers.MCPToolsetsHeader)); len(toolsets) > 0 { + ctx = ghcontext.WithToolsets(ctx, toolsets) + } + + // Tools + if tools := headers.ParseCommaSeparated(r.Header.Get(headers.MCPToolsHeader)); len(tools) > 0 { + ctx = ghcontext.WithTools(ctx, tools) + } + + // Lockdown mode + if relaxedParseBool(r.Header.Get(headers.MCPLockdownHeader)) { + ctx = ghcontext.WithLockdownMode(ctx, true) + } + + // Insiders mode + if relaxedParseBool(r.Header.Get(headers.MCPInsidersHeader)) { + ctx = ghcontext.WithInsidersMode(ctx, true) + } + + // Feature flags + if features := headers.ParseCommaSeparated(r.Header.Get(headers.MCPFeaturesHeader)); len(features) > 0 { + ctx = ghcontext.WithHeaderFeatures(ctx, features) + } + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// relaxedParseBool parses a string into a boolean value, treating various +// common false values or empty strings as false, and everything else as true. +// It is case-insensitive and trims whitespace. +func relaxedParseBool(s string) bool { + s = strings.TrimSpace(strings.ToLower(s)) + falseValues := []string{"", "false", "0", "no", "off", "n", "f"} + return !slices.Contains(falseValues, s) +} diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go new file mode 100644 index 000000000..526797241 --- /dev/null +++ b/pkg/http/middleware/scope_challenge.go @@ -0,0 +1,143 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/scopes" + "github.com/github/github-mcp-server/pkg/utils" +) + +// WithScopeChallenge creates a new middleware that determines if an OAuth request contains sufficient scopes to +// complete the request and returns a scope challenge if not. +func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInterface) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Skip health check endpoints + if r.URL.Path == "/_ping" { + next.ServeHTTP(w, r) + return + } + + // Get user from context + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok { + next.ServeHTTP(w, r) + return + } + + // Only check OAuth tokens - scope challenge allows OAuth apps to request additional scopes + if tokenInfo.TokenType != utils.TokenTypeOAuthAccessToken { + next.ServeHTTP(w, r) + return + } + + // Try to use pre-parsed MCP method info first (performance optimization) + // This avoids re-parsing the JSON body if WithMCPParse middleware ran earlier + var toolName string + if methodInfo, ok := ghcontext.MCPMethod(ctx); ok && methodInfo != nil { + // Only check tools/call requests + if methodInfo.Method != "tools/call" { + next.ServeHTTP(w, r) + return + } + toolName = methodInfo.ItemName + } else { + // Fallback: parse the request body directly + body, err := io.ReadAll(r.Body) + if err != nil { + next.ServeHTTP(w, r) + return + } + r.Body = io.NopCloser(bytes.NewReader(body)) + + var mcpRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params struct { + Name string `json:"name,omitempty"` + Arguments map[string]any `json:"arguments,omitempty"` + } `json:"params"` + } + + err = json.Unmarshal(body, &mcpRequest) + if err != nil { + next.ServeHTTP(w, r) + return + } + + // Only check tools/call requests + if mcpRequest.Method != "tools/call" { + next.ServeHTTP(w, r) + return + } + + toolName = mcpRequest.Params.Name + } + toolScopeInfo, err := scopes.GetToolScopeInfo(toolName) + if err != nil { + next.ServeHTTP(w, r) + return + } + + // If tool not found in scope map, allow the request + if toolScopeInfo == nil { + next.ServeHTTP(w, r) + return + } + + // Get OAuth scopes from GitHub API + activeScopes, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) + if err != nil { + next.ServeHTTP(w, r) + return + } + + // Store active scopes in context for downstream use + tokenInfo.Scopes = activeScopes + tokenInfo.ScopesFetched = true + ctx = ghcontext.WithTokenInfo(ctx, tokenInfo) + r = r.WithContext(ctx) + + // Check if user has the required scopes + if toolScopeInfo.HasAcceptedScope(activeScopes...) { + next.ServeHTTP(w, r) + return + } + + // User lacks required scopes - get the scopes they need + requiredScopes := toolScopeInfo.GetRequiredScopesSlice() + + // Build the resource metadata URL using the shared utility + // GetEffectiveResourcePath returns the original path (e.g., /mcp or /mcp/x/all) + // which is used to construct the well-known OAuth protected resource URL + resourcePath := oauth.ResolveResourcePath(r, oauthCfg) + resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, resourcePath) + + // Build recommended scopes: existing scopes + required scopes + recommendedScopes := make([]string, 0, len(activeScopes)+len(requiredScopes)) + recommendedScopes = append(recommendedScopes, activeScopes...) + recommendedScopes = append(recommendedScopes, requiredScopes...) + + // Build the WWW-Authenticate header value + wwwAuthenticateHeader := fmt.Sprintf(`Bearer error="insufficient_scope", scope=%q, resource_metadata=%q, error_description=%q`, + strings.Join(recommendedScopes, " "), + resourceMetadataURL, + "Additional scopes required: "+strings.Join(requiredScopes, ", "), + ) + + // Send scope challenge response with the superset of existing and required scopes + w.Header().Set("WWW-Authenticate", wwwAuthenticateHeader) + http.Error(w, "Forbidden: insufficient scopes", http.StatusForbidden) + } + return http.HandlerFunc(fn) + } +} diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go new file mode 100644 index 000000000..c362ea201 --- /dev/null +++ b/pkg/http/middleware/token.go @@ -0,0 +1,47 @@ +package middleware + +import ( + "errors" + "fmt" + "net/http" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/utils" +) + +func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenType, token, err := utils.ParseAuthorizationHeader(r) + if err != nil { + // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec + if errors.Is(err, utils.ErrMissingAuthorizationHeader) { + sendAuthChallenge(w, r, oauthCfg) + return + } + // For other auth errors (bad format, unsupported), return 400 + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + ctx := r.Context() + ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{ + Token: token, + TokenType: tokenType, + }) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) + } +} + +// sendAuthChallenge sends a 401 Unauthorized response with WWW-Authenticate header +// containing the OAuth protected resource metadata URL as per RFC 6750 and MCP spec. +func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.Config) { + resourcePath := oauth.ResolveResourcePath(r, oauthCfg) + resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, resourcePath) + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL)) + http.Error(w, "Unauthorized", http.StatusUnauthorized) +} diff --git a/pkg/http/middleware/token_test.go b/pkg/http/middleware/token_test.go new file mode 100644 index 000000000..fa8f0ee98 --- /dev/null +++ b/pkg/http/middleware/token_test.go @@ -0,0 +1,321 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractUserToken(t *testing.T) { + oauthCfg := &oauth.Config{ + BaseURL: "https://example.com", + AuthorizationServer: "https://github.com/login/oauth", + } + + tests := []struct { + name string + authHeader string + expectedStatusCode int + expectedTokenType utils.TokenType + expectedToken string + expectTokenInfo bool + expectWWWAuth bool + }{ + // Missing authorization header + { + name: "missing Authorization header returns 401 with WWW-Authenticate", + authHeader: "", + expectedStatusCode: http.StatusUnauthorized, + expectTokenInfo: false, + expectWWWAuth: true, + }, + // Personal Access Token (classic) - ghp_ prefix + { + name: "personal access token (classic) with Bearer prefix", + authHeader: "Bearer ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypePersonalAccessToken, + expectedToken: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + { + name: "personal access token (classic) with bearer lowercase", + authHeader: "bearer ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypePersonalAccessToken, + expectedToken: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + { + name: "personal access token (classic) without Bearer prefix", + authHeader: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypePersonalAccessToken, + expectedToken: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + // Fine-grained Personal Access Token - github_pat_ prefix + { + name: "fine-grained personal access token with Bearer prefix", + authHeader: "Bearer github_pat_xxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypeFineGrainedPersonalAccessToken, + expectedToken: "github_pat_xxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + { + name: "fine-grained personal access token without Bearer prefix", + authHeader: "github_pat_xxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypeFineGrainedPersonalAccessToken, + expectedToken: "github_pat_xxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + // OAuth Access Token - gho_ prefix + { + name: "OAuth access token with Bearer prefix", + authHeader: "Bearer gho_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypeOAuthAccessToken, + expectedToken: "gho_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + { + name: "OAuth access token without Bearer prefix", + authHeader: "gho_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypeOAuthAccessToken, + expectedToken: "gho_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + // User-to-Server GitHub App Token - ghu_ prefix + { + name: "user-to-server GitHub App token with Bearer prefix", + authHeader: "Bearer ghu_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypeUserToServerGitHubAppToken, + expectedToken: "ghu_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + { + name: "user-to-server GitHub App token without Bearer prefix", + authHeader: "ghu_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypeUserToServerGitHubAppToken, + expectedToken: "ghu_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + // Server-to-Server GitHub App Token (installation token) - ghs_ prefix + { + name: "server-to-server GitHub App token with Bearer prefix", + authHeader: "Bearer ghs_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypeServerToServerGitHubAppToken, + expectedToken: "ghs_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + { + name: "server-to-server GitHub App token without Bearer prefix", + authHeader: "ghs_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypeServerToServerGitHubAppToken, + expectedToken: "ghs_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + expectTokenInfo: true, + }, + // Old-style Personal Access Token (40 hex characters, pre-2021) + { + name: "old-style personal access token (40 hex chars) with Bearer prefix", + authHeader: "Bearer 0123456789abcdef0123456789abcdef01234567", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypePersonalAccessToken, + expectedToken: "0123456789abcdef0123456789abcdef01234567", + expectTokenInfo: true, + }, + { + name: "old-style personal access token (40 hex chars) without Bearer prefix", + authHeader: "0123456789abcdef0123456789abcdef01234567", + expectedStatusCode: http.StatusOK, + expectedTokenType: utils.TokenTypePersonalAccessToken, + expectedToken: "0123456789abcdef0123456789abcdef01234567", + expectTokenInfo: true, + }, + // Error cases + { + name: "unsupported GitHub-Bearer header returns 400", + authHeader: "GitHub-Bearer some_encrypted_token", + expectedStatusCode: http.StatusBadRequest, + expectTokenInfo: false, + }, + { + name: "invalid token format returns 400", + authHeader: "Bearer invalid_token_format", + expectedStatusCode: http.StatusBadRequest, + expectTokenInfo: false, + }, + { + name: "unrecognized prefix returns 400", + authHeader: "Bearer xyz_notavalidprefix", + expectedStatusCode: http.StatusBadRequest, + expectTokenInfo: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedTokenInfo *ghcontext.TokenInfo + var tokenInfoCaptured bool + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedTokenInfo, tokenInfoCaptured = ghcontext.GetTokenInfo(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + middleware := ExtractUserToken(oauthCfg) + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + if tt.authHeader != "" { + req.Header.Set(headers.AuthorizationHeader, tt.authHeader) + } + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatusCode, rr.Code) + + if tt.expectWWWAuth { + wwwAuth := rr.Header().Get("WWW-Authenticate") + assert.NotEmpty(t, wwwAuth, "expected WWW-Authenticate header") + assert.Contains(t, wwwAuth, "Bearer resource_metadata=") + } + + if tt.expectTokenInfo { + require.True(t, tokenInfoCaptured, "expected TokenInfo to be present in context") + require.NotNil(t, capturedTokenInfo) + assert.Equal(t, tt.expectedTokenType, capturedTokenInfo.TokenType) + assert.Equal(t, tt.expectedToken, capturedTokenInfo.Token) + } else { + assert.False(t, tokenInfoCaptured, "expected no TokenInfo in context") + } + }) + } +} + +func TestExtractUserToken_NilOAuthConfig(t *testing.T) { + var capturedTokenInfo *ghcontext.TokenInfo + var tokenInfoCaptured bool + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedTokenInfo, tokenInfoCaptured = ghcontext.GetTokenInfo(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + middleware := ExtractUserToken(nil) + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set(headers.AuthorizationHeader, "Bearer ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + require.True(t, tokenInfoCaptured) + require.NotNil(t, capturedTokenInfo) + assert.Equal(t, utils.TokenTypePersonalAccessToken, capturedTokenInfo.TokenType) +} + +func TestExtractUserToken_MissingAuthHeader_WWWAuthenticateFormat(t *testing.T) { + oauthCfg := &oauth.Config{ + BaseURL: "https://api.example.com", + AuthorizationServer: "https://github.com/login/oauth", + ResourcePath: "/mcp", + } + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware := ExtractUserToken(oauthCfg) + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + // No Authorization header + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) + wwwAuth := rr.Header().Get("WWW-Authenticate") + assert.NotEmpty(t, wwwAuth) + assert.Contains(t, wwwAuth, "Bearer") + assert.Contains(t, wwwAuth, "resource_metadata=") + assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource") +} + +func TestSendAuthChallenge(t *testing.T) { + tests := []struct { + name string + oauthCfg *oauth.Config + requestPath string + expectedContains []string + }{ + { + name: "with base URL configured", + oauthCfg: &oauth.Config{ + BaseURL: "https://mcp.example.com", + }, + requestPath: "/api/test", + expectedContains: []string{ + "Bearer", + "resource_metadata=", + "https://mcp.example.com/.well-known/oauth-protected-resource", + }, + }, + { + name: "with nil config uses request host", + oauthCfg: nil, + requestPath: "/api/test", + expectedContains: []string{ + "Bearer", + "resource_metadata=", + "/.well-known/oauth-protected-resource", + }, + }, + { + name: "with resource path configured", + oauthCfg: &oauth.Config{ + BaseURL: "https://mcp.example.com", + ResourcePath: "/mcp", + }, + requestPath: "/api/test", + expectedContains: []string{ + "Bearer", + "resource_metadata=", + "/mcp", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil) + + sendAuthChallenge(rr, req, tt.oauthCfg) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) + wwwAuth := rr.Header().Get("WWW-Authenticate") + for _, expected := range tt.expectedContains { + assert.Contains(t, wwwAuth, expected) + } + }) + } +} diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go new file mode 100644 index 000000000..ecdcf95ab --- /dev/null +++ b/pkg/http/oauth/oauth.go @@ -0,0 +1,243 @@ +// Package oauth provides OAuth 2.0 Protected Resource Metadata (RFC 9728) support +// for the GitHub MCP Server HTTP mode. +package oauth + +import ( + "fmt" + "net/http" + "strings" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/go-chi/chi/v5" + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +const ( + // OAuthProtectedResourcePrefix is the well-known path prefix for OAuth protected resource metadata. + OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource" + + // DefaultAuthorizationServer is GitHub's OAuth authorization server. + DefaultAuthorizationServer = "https://github.com/login/oauth" +) + +// SupportedScopes lists all OAuth scopes that may be required by MCP tools. +var SupportedScopes = []string{ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace", +} + +// Config holds the OAuth configuration for the MCP server. +type Config struct { + // BaseURL is the publicly accessible URL where this server is hosted. + // This is used to construct the OAuth resource URL. + BaseURL string + + // AuthorizationServer is the OAuth authorization server URL. + // Defaults to GitHub's OAuth server if not specified. + AuthorizationServer string + + // ResourcePath is the externally visible base path for the MCP server (e.g., "/mcp"). + // This is used to restore the original path when a proxy strips a base path before forwarding. + // If empty, requests are treated as already using the external path. + ResourcePath string +} + +// AuthHandler handles OAuth-related HTTP endpoints. +type AuthHandler struct { + cfg *Config +} + +// NewAuthHandler creates a new OAuth auth handler. +func NewAuthHandler(cfg *Config) (*AuthHandler, error) { + if cfg == nil { + cfg = &Config{} + } + + // Default authorization server to GitHub + if cfg.AuthorizationServer == "" { + cfg.AuthorizationServer = DefaultAuthorizationServer + } + + return &AuthHandler{ + cfg: cfg, + }, nil +} + +// routePatterns defines the route patterns for OAuth protected resource metadata. +var routePatterns = []string{ + "", // Root: /.well-known/oauth-protected-resource + "/readonly", // Read-only mode + "/insiders", // Insiders mode + "/x/{toolset}", + "/x/{toolset}/readonly", +} + +// RegisterRoutes registers the OAuth protected resource metadata routes. +func (h *AuthHandler) RegisterRoutes(r chi.Router) { + for _, pattern := range routePatterns { + for _, route := range h.routesForPattern(pattern) { + path := OAuthProtectedResourcePrefix + route + r.Handle(path, h.metadataHandler()) + } + } +} + +func (h *AuthHandler) metadataHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resourcePath := resolveResourcePath( + strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix), + h.cfg.ResourcePath, + ) + resourceURL := h.buildResourceURL(r, resourcePath) + + metadata := &oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{h.cfg.AuthorizationServer}, + ResourceName: "GitHub MCP Server", + ScopesSupported: SupportedScopes, + BearerMethodsSupported: []string{"header"}, + } + + auth.ProtectedResourceMetadataHandler(metadata).ServeHTTP(w, r) + }) +} + +// routesForPattern generates route variants for a given pattern. +// GitHub strips the /mcp prefix before forwarding, so we register both variants: +// - With /mcp prefix: for direct access or when GitHub doesn't strip +// - Without /mcp prefix: for when GitHub has stripped the prefix +func (h *AuthHandler) routesForPattern(pattern string) []string { + basePaths := []string{""} + if basePath := normalizeBasePath(h.cfg.ResourcePath); basePath != "" { + basePaths = append(basePaths, basePath) + } else { + basePaths = append(basePaths, "/mcp") + } + + routes := make([]string, 0, len(basePaths)*2) + for _, basePath := range basePaths { + routes = append(routes, joinRoute(basePath, pattern)) + routes = append(routes, joinRoute(basePath, pattern)+"/") + } + + return routes +} + +// resolveResourcePath returns the externally visible resource path, +// restoring the configured base path when proxies strip it before forwarding. +func resolveResourcePath(path, basePath string) string { + if path == "" { + path = "/" + } + base := normalizeBasePath(basePath) + if base == "" { + return path + } + if path == "/" { + return base + } + if path == base || strings.HasPrefix(path, base+"/") { + return path + } + return base + path +} + +// ResolveResourcePath returns the externally visible resource path for a request. +// Exported for use by middleware. +func ResolveResourcePath(r *http.Request, cfg *Config) string { + basePath := "" + if cfg != nil { + basePath = cfg.ResourcePath + } + return resolveResourcePath(r.URL.Path, basePath) +} + +// buildResourceURL constructs the full resource URL for OAuth metadata. +func (h *AuthHandler) buildResourceURL(r *http.Request, resourcePath string) string { + host, scheme := GetEffectiveHostAndScheme(r, h.cfg) + baseURL := fmt.Sprintf("%s://%s", scheme, host) + if h.cfg.BaseURL != "" { + baseURL = strings.TrimSuffix(h.cfg.BaseURL, "/") + } + if resourcePath == "" { + resourcePath = "/" + } + if !strings.HasPrefix(resourcePath, "/") { + resourcePath = "/" + resourcePath + } + return baseURL + resourcePath +} + +// GetEffectiveHostAndScheme returns the effective host and scheme for a request. +func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { //nolint:revive + if fh := r.Header.Get(headers.ForwardedHostHeader); fh != "" { + host = fh + } else { + host = r.Host + } + if host == "" { + host = "localhost" + } + if fp := r.Header.Get(headers.ForwardedProtoHeader); fp != "" { + scheme = strings.ToLower(fp) + } else { + if r.TLS != nil { + scheme = "https" + } else { + scheme = "http" + } + } + return +} + +// BuildResourceMetadataURL constructs the full URL to the OAuth protected resource metadata endpoint. +func BuildResourceMetadataURL(r *http.Request, cfg *Config, resourcePath string) string { + host, scheme := GetEffectiveHostAndScheme(r, cfg) + suffix := "" + if resourcePath != "" && resourcePath != "/" { + if !strings.HasPrefix(resourcePath, "/") { + suffix = "/" + resourcePath + } else { + suffix = resourcePath + } + } + if cfg != nil && cfg.BaseURL != "" { + return strings.TrimSuffix(cfg.BaseURL, "/") + OAuthProtectedResourcePrefix + suffix + } + return fmt.Sprintf("%s://%s%s%s", scheme, host, OAuthProtectedResourcePrefix, suffix) +} + +func normalizeBasePath(path string) string { + trimmed := strings.TrimSpace(path) + if trimmed == "" || trimmed == "/" { + return "" + } + if !strings.HasPrefix(trimmed, "/") { + trimmed = "/" + trimmed + } + return strings.TrimSuffix(trimmed, "/") +} + +func joinRoute(basePath, pattern string) string { + if basePath == "" { + return pattern + } + if pattern == "" { + return basePath + } + if strings.HasSuffix(basePath, "/") { + return strings.TrimSuffix(basePath, "/") + pattern + } + return basePath + pattern +} diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go new file mode 100644 index 000000000..9133e8331 --- /dev/null +++ b/pkg/http/oauth/oauth_test.go @@ -0,0 +1,615 @@ +package oauth + +import ( + "crypto/tls" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAuthHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + expectedAuthServer string + expectedResourcePath string + }{ + { + name: "nil config uses defaults", + cfg: nil, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "", + }, + { + name: "empty config uses defaults", + cfg: &Config{}, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "", + }, + { + name: "custom authorization server", + cfg: &Config{ + AuthorizationServer: "https://custom.example.com/oauth", + }, + expectedAuthServer: "https://custom.example.com/oauth", + expectedResourcePath: "", + }, + { + name: "custom base URL and resource path", + cfg: &Config{ + BaseURL: "https://example.com", + ResourcePath: "/mcp", + }, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "/mcp", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + require.NotNil(t, handler) + + assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer) + }) + } +} + +func TestGetEffectiveHostAndScheme(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupRequest func() *http.Request + cfg *Config + expectedHost string + expectedScheme string + }{ + { + name: "basic request without forwarding headers", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", // defaults to http + }, + { + name: "request with X-Forwarded-Host header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + return req + }, + cfg: &Config{}, + expectedHost: "public.example.com", + expectedScheme: "http", + }, + { + name: "request with X-Forwarded-Proto header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.Header.Set(headers.ForwardedProtoHeader, "http") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", + }, + { + name: "request with both forwarding headers", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + req.Header.Set(headers.ForwardedProtoHeader, "https") + return req + }, + cfg: &Config{}, + expectedHost: "public.example.com", + expectedScheme: "https", + }, + { + name: "request with TLS", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{} + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", + }, + { + name: "X-Forwarded-Proto takes precedence over TLS", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{} + req.Header.Set(headers.ForwardedProtoHeader, "http") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", + }, + { + name: "scheme is lowercased", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.Header.Set(headers.ForwardedProtoHeader, "HTTPS") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + host, scheme := GetEffectiveHostAndScheme(req, tc.cfg) + + assert.Equal(t, tc.expectedHost, host) + assert.Equal(t, tc.expectedScheme, scheme) + }) + } +} + +func TestResolveResourcePath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + setupRequest func() *http.Request + expectedPath string + }{ + { + name: "no base path uses request path", + cfg: &Config{}, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/x/repos", nil) + }, + expectedPath: "/x/repos", + }, + { + name: "base path restored for root", + cfg: &Config{ + ResourcePath: "/mcp", + }, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/", nil) + }, + expectedPath: "/mcp", + }, + { + name: "base path restored for nested", + cfg: &Config{ + ResourcePath: "/mcp", + }, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/readonly", nil) + }, + expectedPath: "/mcp/readonly", + }, + { + name: "base path preserved when already present", + cfg: &Config{ + ResourcePath: "/mcp", + }, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/mcp/readonly/", nil) + }, + expectedPath: "/mcp/readonly/", + }, + { + name: "custom base path restored", + cfg: &Config{ + ResourcePath: "/api", + }, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/x/repos", nil) + }, + expectedPath: "/api/x/repos", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + path := ResolveResourcePath(req, tc.cfg) + + assert.Equal(t, tc.expectedPath, path) + }) + } +} + +func TestBuildResourceMetadataURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + setupRequest func() *http.Request + resourcePath string + expectedURL string + }{ + { + name: "root path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/", + expectedURL: "http://api.example.com/.well-known/oauth-protected-resource", + }, + { + name: "resource path preserves trailing slash", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp/", + expectedURL: "http://api.example.com/.well-known/oauth-protected-resource/mcp/", + }, + { + name: "with custom resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedURL: "http://api.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "with base URL config", + cfg: &Config{ + BaseURL: "https://custom.example.com", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedURL: "https://custom.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "with forwarded headers", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + req.Header.Set(headers.ForwardedProtoHeader, "https") + return req + }, + resourcePath: "/mcp", + expectedURL: "https://public.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "nil config uses request host", + cfg: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "", + expectedURL: "http://api.example.com/.well-known/oauth-protected-resource", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + url := BuildResourceMetadataURL(req, tc.cfg, tc.resourcePath) + + assert.Equal(t, tc.expectedURL, url) + }) + } +} + +func TestHandleProtectedResource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + path string + host string + method string + expectedStatusCode int + expectedScopes []string + validateResponse func(t *testing.T, body map[string]any) + }{ + { + name: "GET request returns protected resource metadata", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + expectedScopes: SupportedScopes, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "GitHub MCP Server", body["resource_name"]) + assert.Equal(t, "https://api.example.com/", body["resource"]) + + authServers, ok := body["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, authServers, 1) + assert.Equal(t, DefaultAuthorizationServer, authServers[0]) + }, + }, + { + name: "OPTIONS request for CORS preflight", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodOptions, + expectedStatusCode: http.StatusNoContent, + }, + { + name: "path with /mcp suffix", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix + "/mcp", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "https://api.example.com/mcp", body["resource"]) + }, + }, + { + name: "path with /readonly suffix", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix + "/readonly", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "https://api.example.com/readonly", body["resource"]) + }, + }, + { + name: "path with trailing slash", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix + "/mcp/", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "https://api.example.com/mcp/", body["resource"]) + }, + }, + { + name: "custom authorization server in response", + cfg: &Config{ + BaseURL: "https://api.example.com", + AuthorizationServer: "https://custom.auth.example.com/oauth", + }, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + authServers, ok := body["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, authServers, 1) + assert.Equal(t, "https://custom.auth.example.com/oauth", authServers[0]) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(tc.method, tc.path, nil) + req.Host = tc.host + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatusCode, rec.Code) + + // Check CORS headers + assert.Equal(t, "*", rec.Header().Get("Access-Control-Allow-Origin")) + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "GET") + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "OPTIONS") + + if tc.method == http.MethodGet && tc.validateResponse != nil { + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var body map[string]any + err := json.Unmarshal(rec.Body.Bytes(), &body) + require.NoError(t, err) + + tc.validateResponse(t, body) + + // Verify scopes if expected + if tc.expectedScopes != nil { + scopes, ok := body["scopes_supported"].([]any) + require.True(t, ok) + assert.Len(t, scopes, len(tc.expectedScopes)) + } + } + }) + } +} + +func TestRegisterRoutes(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(&Config{ + BaseURL: "https://api.example.com", + }) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + // List of expected routes that should be registered + expectedRoutes := []string{ + OAuthProtectedResourcePrefix, + OAuthProtectedResourcePrefix + "/", + OAuthProtectedResourcePrefix + "/mcp", + OAuthProtectedResourcePrefix + "/mcp/", + OAuthProtectedResourcePrefix + "/readonly", + OAuthProtectedResourcePrefix + "/readonly/", + OAuthProtectedResourcePrefix + "/mcp/readonly", + OAuthProtectedResourcePrefix + "/mcp/readonly/", + OAuthProtectedResourcePrefix + "/x/repos", + OAuthProtectedResourcePrefix + "/mcp/x/repos", + } + + for _, route := range expectedRoutes { + t.Run("route:"+route, func(t *testing.T) { + // Test GET + req := httptest.NewRequest(http.MethodGet, route, nil) + req.Host = "api.example.com" + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "GET %s should return 200", route) + + // Test OPTIONS (CORS preflight) + req = httptest.NewRequest(http.MethodOptions, route, nil) + req.Host = "api.example.com" + rec = httptest.NewRecorder() + router.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNoContent, rec.Code, "OPTIONS %s should return 204", route) + }) + } +} + +func TestSupportedScopes(t *testing.T) { + t.Parallel() + + // Verify all expected scopes are present + expectedScopes := []string{ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace", + } + + assert.Equal(t, expectedScopes, SupportedScopes) +} + +func TestProtectedResourceResponseFormat(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(&Config{ + BaseURL: "https://api.example.com", + }) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(http.MethodGet, OAuthProtectedResourcePrefix, nil) + req.Host = "api.example.com" + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response map[string]any + err = json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Verify all required RFC 9728 fields are present + assert.Contains(t, response, "resource") + assert.Contains(t, response, "authorization_servers") + assert.Contains(t, response, "bearer_methods_supported") + assert.Contains(t, response, "scopes_supported") + + // Verify resource name (optional but we include it) + assert.Contains(t, response, "resource_name") + assert.Equal(t, "GitHub MCP Server", response["resource_name"]) + + // Verify bearer_methods_supported contains "header" + bearerMethods, ok := response["bearer_methods_supported"].([]any) + require.True(t, ok) + assert.Contains(t, bearerMethods, "header") + + // Verify authorization_servers is an array with GitHub OAuth + authServers, ok := response["authorization_servers"].([]any) + require.True(t, ok) + assert.Len(t, authServers, 1) + assert.Equal(t, DefaultAuthorizationServer, authServers[0]) +} + +func TestOAuthProtectedResourcePrefix(t *testing.T) { + t.Parallel() + + // RFC 9728 specifies this well-known path + assert.Equal(t, "/.well-known/oauth-protected-resource", OAuthProtectedResourcePrefix) +} + +func TestDefaultAuthorizationServer(t *testing.T) { + t.Parallel() + + assert.Equal(t, "https://github.com/login/oauth", DefaultAuthorizationServer) +} diff --git a/pkg/http/server.go b/pkg/http/server.go new file mode 100644 index 000000000..7a7ab46de --- /dev/null +++ b/pkg/http/server.go @@ -0,0 +1,224 @@ +package http + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "os/signal" + "slices" + "syscall" + "time" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/scopes" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/go-chi/chi/v5" +) + +// knownFeatureFlags are the feature flags that can be enabled via X-MCP-Features header. +// Only these flags are accepted from headers. +var knownFeatureFlags = []string{ + github.FeatureFlagHoldbackConsolidatedProjects, + github.FeatureFlagHoldbackConsolidatedActions, +} + +type ServerConfig struct { + // Version of the server + Version string + + // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) + Host string + + // Port to listen on (default: 8082) + Port int + + // BaseURL is the publicly accessible URL of this server for OAuth resource metadata. + // If not set, the server will derive the URL from incoming request headers. + BaseURL string + + // ResourcePath is the externally visible base path for this server (e.g., "/mcp"). + // This is used to restore the original path when a proxy strips a base path before forwarding. + ResourcePath string + + // ExportTranslations indicates if we should export translations + // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions + ExportTranslations bool + + // EnableCommandLogging indicates if we should log commands + EnableCommandLogging bool + + // Path to the log file if not stderr + LogFilePath string + + // Content window size + ContentWindowSize int + + // LockdownMode indicates if we should enable lockdown mode + LockdownMode bool + + // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. + RepoAccessCacheTTL *time.Duration + + // ScopeChallenge indicates if we should return OAuth scope challenges, and if we should perform + // tool filtering based on token scopes. + ScopeChallenge bool +} + +func RunHTTPServer(cfg ServerConfig) error { + // Create app context + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + t, dumpTranslations := translations.TranslationHelper() + + var slogHandler slog.Handler + var logOutput io.Writer + if cfg.LogFilePath != "" { + file, err := os.OpenFile(cfg.LogFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err != nil { + return fmt.Errorf("failed to open log file: %w", err) + } + logOutput = file + slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelDebug}) + } else { + logOutput = os.Stderr + slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelInfo}) + } + logger := slog.New(slogHandler) + logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "lockdownEnabled", cfg.LockdownMode) + + apiHost, err := utils.NewAPIHost(cfg.Host) + if err != nil { + return fmt.Errorf("failed to parse API host: %w", err) + } + + repoAccessOpts := []lockdown.RepoAccessOption{ + lockdown.WithLogger(logger.With("component", "lockdown")), + } + if cfg.RepoAccessCacheTTL != nil { + repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessCacheTTL)) + } + + featureChecker := createHTTPFeatureChecker() + + deps := github.NewRequestDeps( + apiHost, + cfg.Version, + cfg.LockdownMode, + repoAccessOpts, + t, + cfg.ContentWindowSize, + featureChecker, + ) + + // Initialize the global tool scope map + err = initGlobalToolScopeMap(t) + if err != nil { + return fmt.Errorf("failed to initialize tool scope map: %w", err) + } + + // Register OAuth protected resource metadata endpoints + oauthCfg := &oauth.Config{ + BaseURL: cfg.BaseURL, + ResourcePath: cfg.ResourcePath, + } + + serverOptions := []HandlerOption{} + if cfg.ScopeChallenge { + scopeFetcher := scopes.NewFetcher(apiHost, scopes.FetcherOptions{}) + serverOptions = append(serverOptions, WithScopeFetcher(scopeFetcher)) + } + + r := chi.NewRouter() + handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...) + oauthHandler, err := oauth.NewAuthHandler(oauthCfg) + if err != nil { + return fmt.Errorf("failed to create OAuth handler: %w", err) + } + + r.Group(func(r chi.Router) { + // Register Middleware First, needs to be before route registration + handler.RegisterMiddleware(r) + + // Register MCP server routes + handler.RegisterRoutes(r) + }) + logger.Info("MCP endpoints registered", "baseURL", cfg.BaseURL) + + r.Group(func(r chi.Router) { + // Register OAuth protected resource metadata endpoints + oauthHandler.RegisterRoutes(r) + }) + logger.Info("OAuth protected resource endpoints registered", "baseURL", cfg.BaseURL) + + addr := fmt.Sprintf(":%d", cfg.Port) + httpSvr := http.Server{ + Addr: addr, + Handler: r, + ReadHeaderTimeout: 60 * time.Second, + } + + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + logger.Info("shutting down server") + if err := httpSvr.Shutdown(shutdownCtx); err != nil { + logger.Error("error during server shutdown", "error", err) + } + }() + + if cfg.ExportTranslations { + // Once server is initialized, all translations are loaded + dumpTranslations() + } + + logger.Info("HTTP server listening", "addr", addr) + if err := httpSvr.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return fmt.Errorf("HTTP server error: %w", err) + } + + logger.Info("server stopped gracefully") + return nil +} + +func initGlobalToolScopeMap(t translations.TranslationHelperFunc) error { + // Build inventory with all tools to extract scope information + inv, err := inventory.NewBuilder(). + SetTools(github.AllTools(t)). + Build() + + if err != nil { + return fmt.Errorf("failed to build inventory for tool scope map: %w", err) + } + + // Initialize the global scope map + scopes.SetToolScopeMapFromInventory(inv) + + return nil +} + +// createHTTPFeatureChecker creates a feature checker that reads header features from context +// and validates them against the knownFeatureFlags whitelist +func createHTTPFeatureChecker() inventory.FeatureFlagChecker { + // Pre-compute whitelist as set for O(1) lookup + knownSet := make(map[string]bool, len(knownFeatureFlags)) + for _, f := range knownFeatureFlags { + knownSet[f] = true + } + + return func(ctx context.Context, flag string) (bool, error) { + if knownSet[flag] && slices.Contains(ghcontext.GetHeaderFeatures(ctx), flag) { + return true, nil + } + return false, nil + } +} diff --git a/pkg/http/transport/bearer.go b/pkg/http/transport/bearer.go new file mode 100644 index 000000000..66922bbda --- /dev/null +++ b/pkg/http/transport/bearer.go @@ -0,0 +1,26 @@ +package transport + +import ( + "net/http" + "strings" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + headers "github.com/github/github-mcp-server/pkg/http/headers" +) + +type BearerAuthTransport struct { + Transport http.RoundTripper + Token string +} + +func (t *BearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set(headers.AuthorizationHeader, "Bearer "+t.Token) + + // Check for GraphQL-Features in context and add header if present + if features := ghcontext.GetGraphQLFeatures(req.Context()); len(features) > 0 { + req.Header.Set(headers.GraphQLFeaturesHeader, strings.Join(features, ", ")) + } + + return t.Transport.RoundTrip(req) +} diff --git a/pkg/github/transport.go b/pkg/http/transport/graphql_features.go similarity index 69% rename from pkg/github/transport.go rename to pkg/http/transport/graphql_features.go index 0a4372b23..7fe9182fc 100644 --- a/pkg/github/transport.go +++ b/pkg/http/transport/graphql_features.go @@ -1,8 +1,11 @@ -package github +package transport import ( "net/http" "strings" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/headers" ) // GraphQLFeaturesTransport is an http.RoundTripper that adds GraphQL-Features @@ -15,14 +18,16 @@ import ( // // Usage: // +// import "github.com/github/github-mcp-server/pkg/http/transport" +// // httpClient := &http.Client{ -// Transport: &github.GraphQLFeaturesTransport{ +// Transport: &transport.GraphQLFeaturesTransport{ // Transport: http.DefaultTransport, // }, // } // gqlClient := githubv4.NewClient(httpClient) // -// Then use withGraphQLFeatures(ctx, "feature_name") when calling GraphQL operations. +// Then use ghcontext.WithGraphQLFeatures(ctx, "feature_name") when calling GraphQL operations. type GraphQLFeaturesTransport struct { // Transport is the underlying HTTP transport. If nil, http.DefaultTransport is used. Transport http.RoundTripper @@ -39,8 +44,8 @@ func (t *GraphQLFeaturesTransport) RoundTrip(req *http.Request) (*http.Response, req = req.Clone(req.Context()) // Check for GraphQL-Features in context and add header if present - if features := GetGraphQLFeatures(req.Context()); len(features) > 0 { - req.Header.Set("GraphQL-Features", strings.Join(features, ", ")) + if features := ghcontext.GetGraphQLFeatures(req.Context()); len(features) > 0 { + req.Header.Set(headers.GraphQLFeaturesHeader, strings.Join(features, ", ")) } return transport.RoundTrip(req) diff --git a/pkg/github/transport_test.go b/pkg/http/transport/graphql_features_test.go similarity index 83% rename from pkg/github/transport_test.go rename to pkg/http/transport/graphql_features_test.go index c98108255..1a0dc4214 100644 --- a/pkg/github/transport_test.go +++ b/pkg/http/transport/graphql_features_test.go @@ -1,4 +1,4 @@ -package github +package transport import ( "context" @@ -6,6 +6,9 @@ import ( "net/http/httptest" "testing" + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -54,8 +57,8 @@ func TestGraphQLFeaturesTransport(t *testing.T) { // Create a test server that captures the request header server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedHeader = r.Header.Get("GraphQL-Features") - headerExists = r.Header.Get("GraphQL-Features") != "" + capturedHeader = r.Header.Get(headers.GraphQLFeaturesHeader) + headerExists = r.Header.Get(headers.GraphQLFeaturesHeader) != "" w.WriteHeader(http.StatusOK) })) defer server.Close() @@ -68,7 +71,7 @@ func TestGraphQLFeaturesTransport(t *testing.T) { // Create a request ctx := context.Background() if tc.features != nil { - ctx = withGraphQLFeatures(ctx, tc.features...) + ctx = ghcontext.WithGraphQLFeatures(ctx, tc.features...) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) @@ -95,7 +98,7 @@ func TestGraphQLFeaturesTransport_NilTransport(t *testing.T) { // Create a test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedHeader = r.Header.Get("GraphQL-Features") + capturedHeader = r.Header.Get(headers.GraphQLFeaturesHeader) w.WriteHeader(http.StatusOK) })) defer server.Close() @@ -106,7 +109,7 @@ func TestGraphQLFeaturesTransport_NilTransport(t *testing.T) { } // Create a request with features - ctx := withGraphQLFeatures(context.Background(), "test_feature") + ctx := ghcontext.WithGraphQLFeatures(context.Background(), "test_feature") req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) require.NoError(t, err) @@ -134,12 +137,12 @@ func TestGraphQLFeaturesTransport_DoesNotMutateOriginalRequest(t *testing.T) { } // Create a request with features - ctx := withGraphQLFeatures(context.Background(), "test_feature") + ctx := ghcontext.WithGraphQLFeatures(context.Background(), "test_feature") req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) require.NoError(t, err) // Store the original header value - originalHeader := req.Header.Get("GraphQL-Features") + originalHeader := req.Header.Get(headers.GraphQLFeaturesHeader) // Execute the request resp, err := transport.RoundTrip(req) @@ -147,5 +150,5 @@ func TestGraphQLFeaturesTransport_DoesNotMutateOriginalRequest(t *testing.T) { defer resp.Body.Close() // Verify the original request was not mutated - assert.Equal(t, originalHeader, req.Header.Get("GraphQL-Features")) + assert.Equal(t, originalHeader, req.Header.Get(headers.GraphQLFeaturesHeader)) } diff --git a/pkg/http/transport/user_agent.go b/pkg/http/transport/user_agent.go new file mode 100644 index 000000000..a489941cc --- /dev/null +++ b/pkg/http/transport/user_agent.go @@ -0,0 +1,18 @@ +package transport + +import ( + "net/http" + + "github.com/github/github-mcp-server/pkg/http/headers" +) + +type UserAgentTransport struct { + Transport http.RoundTripper + Agent string +} + +func (t *UserAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set(headers.UserAgentHeader, t.Agent) + return t.Transport.RoundTrip(req) +} diff --git a/pkg/scopes/fetcher.go b/pkg/scopes/fetcher.go index 48e000179..b37245503 100644 --- a/pkg/scopes/fetcher.go +++ b/pkg/scopes/fetcher.go @@ -7,6 +7,9 @@ import ( "net/url" "strings" "time" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/utils" ) // OAuthScopesHeader is the HTTP response header containing the token's OAuth scopes. @@ -23,28 +26,27 @@ type FetcherOptions struct { // APIHost is the GitHub API host (e.g., "https://api.github.com"). // Defaults to "https://api.github.com" if empty. - APIHost string + APIHost utils.APIHostResolver +} + +type FetcherInterface interface { + FetchTokenScopes(ctx context.Context, token string) ([]string, error) } // Fetcher retrieves token scopes from GitHub's API. // It uses an HTTP HEAD request to minimize bandwidth since we only need headers. type Fetcher struct { client *http.Client - apiHost string + apiHost utils.APIHostResolver } // NewFetcher creates a new scope fetcher with the given options. -func NewFetcher(opts FetcherOptions) *Fetcher { +func NewFetcher(apiHost utils.APIHostResolver, opts FetcherOptions) *Fetcher { client := opts.HTTPClient if client == nil { client = &http.Client{Timeout: DefaultFetchTimeout} } - apiHost := opts.APIHost - if apiHost == "" { - apiHost = "https://api.github.com" - } - return &Fetcher{ client: client, apiHost: apiHost, @@ -61,8 +63,13 @@ func NewFetcher(opts FetcherOptions) *Fetcher { // Note: Fine-grained PATs don't return the X-OAuth-Scopes header, so an empty // slice is returned for those tokens. func (f *Fetcher) FetchTokenScopes(ctx context.Context, token string) ([]string, error) { + apiHostURL, err := f.apiHost.BaseRESTURL(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get API host URL: %w", err) + } + // Use a lightweight endpoint that requires authentication - endpoint, err := url.JoinPath(f.apiHost, "/") + endpoint, err := url.JoinPath(apiHostURL.String(), "/") if err != nil { return nil, fmt.Errorf("failed to construct API URL: %w", err) } @@ -72,9 +79,9 @@ func (f *Fetcher) FetchTokenScopes(ctx context.Context, token string) ([]string, return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Accept", "application/vnd.github+json") - req.Header.Set("X-GitHub-Api-Version", "2022-11-28") + req.Header.Set(headers.AuthorizationHeader, "Bearer "+token) + req.Header.Set(headers.AcceptHeader, "application/vnd.github+json") + req.Header.Set(headers.GitHubAPIVersionHeader, "2022-11-28") resp, err := f.client.Do(req) if err != nil { @@ -115,11 +122,16 @@ func ParseScopeHeader(header string) []string { // FetchTokenScopes is a convenience function that creates a default fetcher // and fetches the token scopes. func FetchTokenScopes(ctx context.Context, token string) ([]string, error) { - return NewFetcher(FetcherOptions{}).FetchTokenScopes(ctx, token) + apiHost, err := utils.NewAPIHost("https://api.github.com/") + if err != nil { + return nil, fmt.Errorf("failed to create default API host: %w", err) + } + + return NewFetcher(apiHost, FetcherOptions{}).FetchTokenScopes(ctx, token) } // FetchTokenScopesWithHost is a convenience function that creates a fetcher // for a specific API host and fetches the token scopes. -func FetchTokenScopesWithHost(ctx context.Context, token, apiHost string) ([]string, error) { - return NewFetcher(FetcherOptions{APIHost: apiHost}).FetchTokenScopes(ctx, token) +func FetchTokenScopesWithHost(ctx context.Context, token string, apiHost utils.APIHostResolver) ([]string, error) { + return NewFetcher(apiHost, FetcherOptions{}).FetchTokenScopes(ctx, token) } diff --git a/pkg/scopes/fetcher_test.go b/pkg/scopes/fetcher_test.go index 13feab5b0..2d887d7a8 100644 --- a/pkg/scopes/fetcher_test.go +++ b/pkg/scopes/fetcher_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/http/httptest" + "net/url" "testing" "time" @@ -11,6 +12,23 @@ import ( "github.com/stretchr/testify/require" ) +type testAPIHostResolver struct { + baseURL string +} + +func (t testAPIHostResolver) BaseRESTURL(_ context.Context) (*url.URL, error) { + return url.Parse(t.baseURL) +} +func (t testAPIHostResolver) GraphqlURL(_ context.Context) (*url.URL, error) { + return nil, nil +} +func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) { + return nil, nil +} +func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) { + return nil, nil +} + func TestParseScopeHeader(t *testing.T) { tests := []struct { name string @@ -146,10 +164,8 @@ func TestFetcher_FetchTokenScopes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { server := httptest.NewServer(tt.handler) defer server.Close() - - fetcher := NewFetcher(FetcherOptions{ - APIHost: server.URL, - }) + apiHost := testAPIHostResolver{baseURL: server.URL} + fetcher := NewFetcher(apiHost, FetcherOptions{}) scopes, err := fetcher.FetchTokenScopes(context.Background(), "test-token") @@ -167,10 +183,13 @@ func TestFetcher_FetchTokenScopes(t *testing.T) { } func TestFetcher_DefaultOptions(t *testing.T) { - fetcher := NewFetcher(FetcherOptions{}) + apiHost := testAPIHostResolver{baseURL: "https://api.github.com"} + fetcher := NewFetcher(apiHost, FetcherOptions{}) // Verify default API host is set - assert.Equal(t, "https://api.github.com", fetcher.apiHost) + apiURL, err := fetcher.apiHost.BaseRESTURL(context.Background()) + require.NoError(t, err) + assert.Equal(t, "https://api.github.com", apiURL.String()) // Verify default HTTP client is set with timeout assert.NotNil(t, fetcher.client) @@ -180,7 +199,8 @@ func TestFetcher_DefaultOptions(t *testing.T) { func TestFetcher_CustomHTTPClient(t *testing.T) { customClient := &http.Client{Timeout: 5 * time.Second} - fetcher := NewFetcher(FetcherOptions{ + apiHost := testAPIHostResolver{baseURL: "https://api.github.com"} + fetcher := NewFetcher(apiHost, FetcherOptions{ HTTPClient: customClient, }) @@ -188,11 +208,12 @@ func TestFetcher_CustomHTTPClient(t *testing.T) { } func TestFetcher_CustomAPIHost(t *testing.T) { - fetcher := NewFetcher(FetcherOptions{ - APIHost: "https://api.github.enterprise.com", - }) + apiHost := testAPIHostResolver{baseURL: "https://api.github.enterprise.com"} + fetcher := NewFetcher(apiHost, FetcherOptions{}) - assert.Equal(t, "https://api.github.enterprise.com", fetcher.apiHost) + apiURL, err := fetcher.apiHost.BaseRESTURL(context.Background()) + require.NoError(t, err) + assert.Equal(t, "https://api.github.enterprise.com", apiURL.String()) } func TestFetcher_ContextCancellation(t *testing.T) { @@ -202,9 +223,8 @@ func TestFetcher_ContextCancellation(t *testing.T) { })) defer server.Close() - fetcher := NewFetcher(FetcherOptions{ - APIHost: server.URL, - }) + apiHost := testAPIHostResolver{baseURL: server.URL} + fetcher := NewFetcher(apiHost, FetcherOptions{}) ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately diff --git a/pkg/scopes/map.go b/pkg/scopes/map.go new file mode 100644 index 000000000..3c9833834 --- /dev/null +++ b/pkg/scopes/map.go @@ -0,0 +1,129 @@ +package scopes + +import "github.com/github/github-mcp-server/pkg/inventory" + +// ToolScopeMap maps tool names to their scope requirements. +type ToolScopeMap map[string]*ToolScopeInfo + +// ToolScopeInfo contains scope information for a single tool. +type ToolScopeInfo struct { + // RequiredScopes contains the scopes that are directly required by this tool. + RequiredScopes []string + + // AcceptedScopes contains all scopes that satisfy the requirements (including parent scopes). + AcceptedScopes []string +} + +// globalToolScopeMap is populated from inventory when SetToolScopeMapFromInventory is called +var globalToolScopeMap ToolScopeMap + +// SetToolScopeMapFromInventory builds and stores a tool scope map from an inventory. +// This should be called after building the inventory to make scopes available for middleware. +func SetToolScopeMapFromInventory(inv *inventory.Inventory) { + globalToolScopeMap = GetToolScopeMapFromInventory(inv) +} + +// SetGlobalToolScopeMap sets the global tool scope map directly. +// This is useful for testing when you don't have a full inventory. +func SetGlobalToolScopeMap(m ToolScopeMap) { + globalToolScopeMap = m +} + +// GetToolScopeMap returns the global tool scope map. +// Returns an empty map if SetToolScopeMapFromInventory hasn't been called yet. +func GetToolScopeMap() (ToolScopeMap, error) { + if globalToolScopeMap == nil { + return make(ToolScopeMap), nil + } + return globalToolScopeMap, nil +} + +// GetToolScopeInfo returns scope information for a specific tool from the global scope map. +func GetToolScopeInfo(toolName string) (*ToolScopeInfo, error) { + m, err := GetToolScopeMap() + if err != nil { + return nil, err + } + return m[toolName], nil +} + +// GetToolScopeMapFromInventory builds a tool scope map from an inventory. +// This extracts scope information from ServerTool.RequiredScopes and ServerTool.AcceptedScopes. +func GetToolScopeMapFromInventory(inv *inventory.Inventory) ToolScopeMap { + result := make(ToolScopeMap) + + // Get all tools from the inventory (both enabled and disabled) + // We need all tools for scope checking purposes + allTools := inv.AllTools() + for i := range allTools { + tool := &allTools[i] + if len(tool.RequiredScopes) > 0 || len(tool.AcceptedScopes) > 0 { + result[tool.Tool.Name] = &ToolScopeInfo{ + RequiredScopes: tool.RequiredScopes, + AcceptedScopes: tool.AcceptedScopes, + } + } + } + + return result +} + +// HasAcceptedScope checks if any of the provided user scopes satisfy the tool's requirements. +func (t *ToolScopeInfo) HasAcceptedScope(userScopes ...string) bool { + if t == nil || len(t.AcceptedScopes) == 0 { + return true // No scopes required + } + + userScopeSet := make(map[string]bool) + for _, scope := range userScopes { + userScopeSet[scope] = true + } + + for _, scope := range t.AcceptedScopes { + if userScopeSet[scope] { + return true + } + } + return false +} + +// MissingScopes returns the required scopes that are not present in the user's scopes. +func (t *ToolScopeInfo) MissingScopes(userScopes ...string) []string { + if t == nil || len(t.RequiredScopes) == 0 { + return nil + } + + // Create a set of user scopes for O(1) lookup + userScopeSet := make(map[string]bool, len(userScopes)) + for _, s := range userScopes { + userScopeSet[s] = true + } + + // Check if any accepted scope is present + hasAccepted := false + for _, scope := range t.AcceptedScopes { + if userScopeSet[scope] { + hasAccepted = true + break + } + } + + if hasAccepted { + return nil // User has sufficient scopes + } + + // Return required scopes as the minimum needed + missing := make([]string, len(t.RequiredScopes)) + copy(missing, t.RequiredScopes) + return missing +} + +// GetRequiredScopesSlice returns the required scopes as a slice of strings. +func (t *ToolScopeInfo) GetRequiredScopesSlice() []string { + if t == nil { + return nil + } + scopes := make([]string, len(t.RequiredScopes)) + copy(scopes, t.RequiredScopes) + return scopes +} diff --git a/pkg/scopes/map_test.go b/pkg/scopes/map_test.go new file mode 100644 index 000000000..5f33cdda2 --- /dev/null +++ b/pkg/scopes/map_test.go @@ -0,0 +1,194 @@ +package scopes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetToolScopeMap(t *testing.T) { + // Reset and set up a test map + SetGlobalToolScopeMap(ToolScopeMap{ + "test_tool": &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + }) + + m, err := GetToolScopeMap() + require.NoError(t, err) + require.NotNil(t, m) + require.Greater(t, len(m), 0, "expected at least one tool in the scope map") + + testTool, ok := m["test_tool"] + require.True(t, ok, "expected test_tool to be in the scope map") + assert.Contains(t, testTool.RequiredScopes, "read:org") + assert.Contains(t, testTool.AcceptedScopes, "read:org") + assert.Contains(t, testTool.AcceptedScopes, "admin:org") +} + +func TestGetToolScopeInfo(t *testing.T) { + // Set up test scope map + SetGlobalToolScopeMap(ToolScopeMap{ + "search_orgs": &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + }) + + info, err := GetToolScopeInfo("search_orgs") + require.NoError(t, err) + require.NotNil(t, info) + + // Non-existent tool should return nil + info, err = GetToolScopeInfo("nonexistent_tool") + require.NoError(t, err) + assert.Nil(t, info) +} + +func TestToolScopeInfo_HasAcceptedScope(t *testing.T) { + testCases := []struct { + name string + scopeInfo *ToolScopeInfo + userScopes []string + expected bool + }{ + { + name: "has exact required scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"read:org"}, + expected: true, + }, + { + name: "has parent scope (admin:org grants read:org)", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"admin:org"}, + expected: true, + }, + { + name: "has parent scope (write:org grants read:org)", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"write:org"}, + expected: true, + }, + { + name: "missing required scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"repo"}, + expected: false, + }, + { + name: "no scope required", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{}, + AcceptedScopes: []string{}, + }, + userScopes: []string{}, + expected: true, + }, + { + name: "nil scope info", + scopeInfo: nil, + userScopes: []string{}, + expected: true, + }, + { + name: "repo scope for tool requiring repo", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"repo"}, + AcceptedScopes: []string{"repo"}, + }, + userScopes: []string{"repo"}, + expected: true, + }, + { + name: "missing repo scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"repo"}, + AcceptedScopes: []string{"repo"}, + }, + userScopes: []string{"public_repo"}, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.scopeInfo.HasAcceptedScope(tc.userScopes...) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestToolScopeInfo_MissingScopes(t *testing.T) { + testCases := []struct { + name string + scopeInfo *ToolScopeInfo + userScopes []string + expectedLen int + expectedScopes []string + }{ + { + name: "has required scope - no missing", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"read:org"}, + expectedLen: 0, + expectedScopes: nil, + }, + { + name: "missing scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"repo"}, + expectedLen: 1, + expectedScopes: []string{"read:org"}, + }, + { + name: "no scope required - no missing", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{}, + AcceptedScopes: []string{}, + }, + userScopes: []string{}, + expectedLen: 0, + expectedScopes: nil, + }, + { + name: "nil scope info - no missing", + scopeInfo: nil, + userScopes: []string{}, + expectedLen: 0, + expectedScopes: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + missing := tc.scopeInfo.MissingScopes(tc.userScopes...) + assert.Len(t, missing, tc.expectedLen) + if tc.expectedScopes != nil { + for _, expected := range tc.expectedScopes { + assert.Contains(t, missing, expected) + } + } + }) + } +} diff --git a/pkg/utils/api.go b/pkg/utils/api.go new file mode 100644 index 000000000..a523917de --- /dev/null +++ b/pkg/utils/api.go @@ -0,0 +1,222 @@ +package utils //nolint:revive //TODO: figure out a better name for this package + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + "time" +) + +type APIHostResolver interface { + BaseRESTURL(ctx context.Context) (*url.URL, error) + GraphqlURL(ctx context.Context) (*url.URL, error) + UploadURL(ctx context.Context) (*url.URL, error) + RawURL(ctx context.Context) (*url.URL, error) +} + +type APIHost struct { + restURL *url.URL + gqlURL *url.URL + uploadURL *url.URL + rawURL *url.URL +} + +var _ APIHostResolver = APIHost{} + +func NewAPIHost(s string) (APIHostResolver, error) { + a, err := parseAPIHost(s) + + if err != nil { + return nil, err + } + + return a, nil +} + +// APIHostResolver implementation +func (a APIHost) BaseRESTURL(_ context.Context) (*url.URL, error) { + return a.restURL, nil +} + +func (a APIHost) GraphqlURL(_ context.Context) (*url.URL, error) { + return a.gqlURL, nil +} + +func (a APIHost) UploadURL(_ context.Context) (*url.URL, error) { + return a.uploadURL, nil +} + +func (a APIHost) RawURL(_ context.Context) (*url.URL, error) { + return a.rawURL, nil +} + +func newDotcomHost() (APIHost, error) { + baseRestURL, err := url.Parse("https://api.github.com/") + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse dotcom REST URL: %w", err) + } + + gqlURL, err := url.Parse("https://api.github.com/graphql") + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse dotcom GraphQL URL: %w", err) + } + + uploadURL, err := url.Parse("https://uploads.github.com") + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse dotcom Upload URL: %w", err) + } + + rawURL, err := url.Parse("https://raw.githubusercontent.com/") + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err) + } + + return APIHost{ + restURL: baseRestURL, + gqlURL: gqlURL, + uploadURL: uploadURL, + rawURL: rawURL, + }, nil +} + +func newGHECHost(hostname string) (APIHost, error) { + u, err := url.Parse(hostname) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHEC URL: %w", err) + } + + // Unsecured GHEC would be an error + if u.Scheme == "http" { + return APIHost{}, fmt.Errorf("GHEC URL must be HTTPS") + } + + restURL, err := url.Parse(fmt.Sprintf("https://api.%s/", u.Hostname())) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHEC REST URL: %w", err) + } + + gqlURL, err := url.Parse(fmt.Sprintf("https://api.%s/graphql", u.Hostname())) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHEC GraphQL URL: %w", err) + } + + uploadURL, err := url.Parse(fmt.Sprintf("https://uploads.%s/", u.Hostname())) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHEC Upload URL: %w", err) + } + + rawURL, err := url.Parse(fmt.Sprintf("https://raw.%s/", u.Hostname())) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err) + } + + return APIHost{ + restURL: restURL, + gqlURL: gqlURL, + uploadURL: uploadURL, + rawURL: rawURL, + }, nil +} + +func newGHESHost(hostname string) (APIHost, error) { + u, err := url.Parse(hostname) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHES URL: %w", err) + } + + restURL, err := url.Parse(fmt.Sprintf("%s://%s/api/v3/", u.Scheme, u.Hostname())) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHES REST URL: %w", err) + } + + gqlURL, err := url.Parse(fmt.Sprintf("%s://%s/api/graphql", u.Scheme, u.Hostname())) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHES GraphQL URL: %w", err) + } + + // Check if subdomain isolation is enabled + // See https://docs.github.com/en/enterprise-server@3.17/admin/configuring-settings/hardening-security-for-your-enterprise/enabling-subdomain-isolation#about-subdomain-isolation + hasSubdomainIsolation := checkSubdomainIsolation(u.Scheme, u.Hostname()) + + var uploadURL *url.URL + if hasSubdomainIsolation { + // With subdomain isolation: https://uploads.hostname/ + uploadURL, err = url.Parse(fmt.Sprintf("%s://uploads.%s/", u.Scheme, u.Hostname())) + } else { + // Without subdomain isolation: https://hostname/api/uploads/ + uploadURL, err = url.Parse(fmt.Sprintf("%s://%s/api/uploads/", u.Scheme, u.Hostname())) + } + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHES Upload URL: %w", err) + } + + var rawURL *url.URL + if hasSubdomainIsolation { + // With subdomain isolation: https://raw.hostname/ + rawURL, err = url.Parse(fmt.Sprintf("%s://raw.%s/", u.Scheme, u.Hostname())) + } else { + // Without subdomain isolation: https://hostname/raw/ + rawURL, err = url.Parse(fmt.Sprintf("%s://%s/raw/", u.Scheme, u.Hostname())) + } + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err) + } + + return APIHost{ + restURL: restURL, + gqlURL: gqlURL, + uploadURL: uploadURL, + rawURL: rawURL, + }, nil +} + +// checkSubdomainIsolation detects if GitHub Enterprise Server has subdomain isolation enabled +// by attempting to ping the raw./_ping endpoint on the subdomain. The raw subdomain must always exist for subdomain isolation. +func checkSubdomainIsolation(scheme, hostname string) bool { + subdomainURL := fmt.Sprintf("%s://raw.%s/_ping", scheme, hostname) + + client := &http.Client{ + Timeout: 5 * time.Second, + // Don't follow redirects - we just want to check if the endpoint exists + //nolint:revive // parameters are required by http.Client.CheckRedirect signature + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := client.Get(subdomainURL) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK +} + +// Note that this does not handle ports yet, so development environments are out. +func parseAPIHost(s string) (APIHost, error) { + if s == "" { + return newDotcomHost() + } + + u, err := url.Parse(s) + if err != nil { + return APIHost{}, fmt.Errorf("could not parse host as URL: %s", s) + } + + if u.Scheme == "" { + return APIHost{}, fmt.Errorf("host must have a scheme (http or https): %s", s) + } + + if strings.HasSuffix(u.Hostname(), "github.com") { + return newDotcomHost() + } + + if strings.HasSuffix(u.Hostname(), "ghe.com") { + return newGHECHost(s) + } + + return newGHESHost(s) +} diff --git a/pkg/utils/token.go b/pkg/utils/token.go new file mode 100644 index 000000000..8933fb0bd --- /dev/null +++ b/pkg/utils/token.go @@ -0,0 +1,75 @@ +package utils //nolint:revive //TODO: figure out a better name for this package + +import ( + "fmt" + "net/http" + "regexp" + "strings" + + httpheaders "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/http/mark" +) + +type TokenType int + +const ( + TokenTypeUnknown TokenType = iota + TokenTypePersonalAccessToken + TokenTypeFineGrainedPersonalAccessToken + TokenTypeOAuthAccessToken + TokenTypeUserToServerGitHubAppToken + TokenTypeServerToServerGitHubAppToken +) + +var supportedGitHubPrefixes = map[string]TokenType{ + "ghp_": TokenTypePersonalAccessToken, // Personal access token (classic) + "github_pat_": TokenTypeFineGrainedPersonalAccessToken, // Fine-grained personal access token + "gho_": TokenTypeOAuthAccessToken, // OAuth access token + "ghu_": TokenTypeUserToServerGitHubAppToken, // User access token for a GitHub App + "ghs_": TokenTypeServerToServerGitHubAppToken, // Installation access token for a GitHub App (a.k.a. server-to-server token) +} + +var ( + ErrMissingAuthorizationHeader = fmt.Errorf("%w: missing required Authorization header", mark.ErrBadRequest) + ErrBadAuthorizationHeader = fmt.Errorf("%w: Authorization header is badly formatted", mark.ErrBadRequest) + ErrUnsupportedAuthorizationHeader = fmt.Errorf("%w: unsupported Authorization header", mark.ErrBadRequest) +) + +// oldPatternRegexp is the regular expression for the old pattern of the token. +// Until 2021, GitHub API tokens did not have an identifiable prefix. They +// were 40 characters long and only contained the characters a-f and 0-9. +var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) + +// ParseAuthorizationHeader parses the Authorization header from the HTTP request +func ParseAuthorizationHeader(req *http.Request) (tokenType TokenType, token string, _ error) { + authHeader := req.Header.Get(httpheaders.AuthorizationHeader) + if authHeader == "" { + return 0, "", ErrMissingAuthorizationHeader + } + + switch { + // decrypt dotcom token and set it as token + case strings.HasPrefix(authHeader, "GitHub-Bearer "): + return 0, "", ErrUnsupportedAuthorizationHeader + default: + // support both "Bearer" and "bearer" to conform to api.github.com + if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { + token = authHeader[7:] + } else { + token = authHeader + } + } + + for prefix, tokenType := range supportedGitHubPrefixes { + if strings.HasPrefix(token, prefix) { + return tokenType, token, nil + } + } + + matchesOldTokenPattern := oldPatternRegexp.MatchString(token) + if matchesOldTokenPattern { + return TokenTypePersonalAccessToken, token, nil + } + + return 0, "", ErrBadAuthorizationHeader +} diff --git a/third-party-licenses.darwin.md b/third-party-licenses.darwin.md index 8217c7707..6028ecfda 100644 --- a/third-party-licenses.darwin.md +++ b/third-party-licenses.darwin.md @@ -15,6 +15,7 @@ The following packages are included for the amd64, arm64 architectures. - [github.com/aymerick/douceur](https://pkg.go.dev/github.com/aymerick/douceur) ([MIT](https://github.com/aymerick/douceur/blob/v0.2.0/LICENSE)) - [github.com/fsnotify/fsnotify](https://pkg.go.dev/github.com/fsnotify/fsnotify) ([BSD-3-Clause](https://github.com/fsnotify/fsnotify/blob/v1.9.0/LICENSE)) - [github.com/github/github-mcp-server](https://pkg.go.dev/github.com/github/github-mcp-server) ([MIT](https://github.com/github/github-mcp-server/blob/HEAD/LICENSE)) + - [github.com/go-chi/chi/v5](https://pkg.go.dev/github.com/go-chi/chi/v5) ([MIT](https://github.com/go-chi/chi/blob/v5.2.3/LICENSE)) - [github.com/go-openapi/jsonpointer](https://pkg.go.dev/github.com/go-openapi/jsonpointer) ([Apache-2.0](https://github.com/go-openapi/jsonpointer/blob/v0.19.5/LICENSE)) - [github.com/go-openapi/swag](https://pkg.go.dev/github.com/go-openapi/swag) ([Apache-2.0](https://github.com/go-openapi/swag/blob/v0.21.1/LICENSE)) - [github.com/go-viper/mapstructure/v2](https://pkg.go.dev/github.com/go-viper/mapstructure/v2) ([MIT](https://github.com/go-viper/mapstructure/blob/v2.5.0/LICENSE)) diff --git a/third-party-licenses.linux.md b/third-party-licenses.linux.md index 981e388e5..3d7b8b3fe 100644 --- a/third-party-licenses.linux.md +++ b/third-party-licenses.linux.md @@ -15,6 +15,7 @@ The following packages are included for the 386, amd64, arm64 architectures. - [github.com/aymerick/douceur](https://pkg.go.dev/github.com/aymerick/douceur) ([MIT](https://github.com/aymerick/douceur/blob/v0.2.0/LICENSE)) - [github.com/fsnotify/fsnotify](https://pkg.go.dev/github.com/fsnotify/fsnotify) ([BSD-3-Clause](https://github.com/fsnotify/fsnotify/blob/v1.9.0/LICENSE)) - [github.com/github/github-mcp-server](https://pkg.go.dev/github.com/github/github-mcp-server) ([MIT](https://github.com/github/github-mcp-server/blob/HEAD/LICENSE)) + - [github.com/go-chi/chi/v5](https://pkg.go.dev/github.com/go-chi/chi/v5) ([MIT](https://github.com/go-chi/chi/blob/v5.2.3/LICENSE)) - [github.com/go-openapi/jsonpointer](https://pkg.go.dev/github.com/go-openapi/jsonpointer) ([Apache-2.0](https://github.com/go-openapi/jsonpointer/blob/v0.19.5/LICENSE)) - [github.com/go-openapi/swag](https://pkg.go.dev/github.com/go-openapi/swag) ([Apache-2.0](https://github.com/go-openapi/swag/blob/v0.21.1/LICENSE)) - [github.com/go-viper/mapstructure/v2](https://pkg.go.dev/github.com/go-viper/mapstructure/v2) ([MIT](https://github.com/go-viper/mapstructure/blob/v2.5.0/LICENSE)) diff --git a/third-party-licenses.windows.md b/third-party-licenses.windows.md index ae0e2389e..48bad011e 100644 --- a/third-party-licenses.windows.md +++ b/third-party-licenses.windows.md @@ -15,6 +15,7 @@ The following packages are included for the 386, amd64, arm64 architectures. - [github.com/aymerick/douceur](https://pkg.go.dev/github.com/aymerick/douceur) ([MIT](https://github.com/aymerick/douceur/blob/v0.2.0/LICENSE)) - [github.com/fsnotify/fsnotify](https://pkg.go.dev/github.com/fsnotify/fsnotify) ([BSD-3-Clause](https://github.com/fsnotify/fsnotify/blob/v1.9.0/LICENSE)) - [github.com/github/github-mcp-server](https://pkg.go.dev/github.com/github/github-mcp-server) ([MIT](https://github.com/github/github-mcp-server/blob/HEAD/LICENSE)) + - [github.com/go-chi/chi/v5](https://pkg.go.dev/github.com/go-chi/chi/v5) ([MIT](https://github.com/go-chi/chi/blob/v5.2.3/LICENSE)) - [github.com/go-openapi/jsonpointer](https://pkg.go.dev/github.com/go-openapi/jsonpointer) ([Apache-2.0](https://github.com/go-openapi/jsonpointer/blob/v0.19.5/LICENSE)) - [github.com/go-openapi/swag](https://pkg.go.dev/github.com/go-openapi/swag) ([Apache-2.0](https://github.com/go-openapi/swag/blob/v0.21.1/LICENSE)) - [github.com/go-viper/mapstructure/v2](https://pkg.go.dev/github.com/go-viper/mapstructure/v2) ([MIT](https://github.com/go-viper/mapstructure/blob/v2.5.0/LICENSE)) diff --git a/third-party/github.com/go-chi/chi/v5/LICENSE b/third-party/github.com/go-chi/chi/v5/LICENSE new file mode 100644 index 000000000..d99f02ffa --- /dev/null +++ b/third-party/github.com/go-chi/chi/v5/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2015-present Peter Kieltyka (https://github.com/pkieltyka), Google Inc. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.