diff --git a/README.md b/README.md index 755eb16..6cb2b70 100644 --- a/README.md +++ b/README.md @@ -367,6 +367,28 @@ The resource URIs follow these formats: ### Configuration +#### Transport Protocol + +MKP supports two transport protocols for the MCP server: + +- **SSE (Server-Sent Events)**: The default transport protocol, suitable for most use cases +- **Streamable HTTP**: A streaming HTTP transport that supports both direct HTTP responses and SSE streams, useful for environments like ToolHive that require HTTP-based communication + +You can configure the transport protocol using either a CLI flag or an environment variable: + +```bash +# Using CLI flag +./build/mkp-server --transport=streamable-http + +# Using environment variable +MCP_TRANSPORT=streamable-http ./build/mkp-server + +# Default (SSE) +./build/mkp-server +``` + +The `MCP_TRANSPORT` environment variable is automatically set by ToolHive when running MKP in that environment. + #### Controlling Resource Discovery By default, MKP serves all Kubernetes resources as MCP resources, which provides diff --git a/cmd/server/main.go b/cmd/server/main.go index db73dbd..3b044f0 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -9,6 +9,7 @@ import ( "os" "os/signal" "strconv" + "strings" "syscall" "time" @@ -16,6 +17,12 @@ import ( "github.com/StacklokLabs/mkp/pkg/mcp" ) +const ( + // Transport types + transportSSE = "sse" + transportStreamableHTTP = "streamable-http" +) + func main() { // Parse command line flags kubeconfig := flag.String("kubeconfig", "", "Path to kubeconfig file. If not provided, in-cluster config will be used") @@ -28,6 +35,8 @@ func main() { "Interval to periodically re-read the kubeconfig (e.g., 5m for 5 minutes). If 0, no refresh will be performed") enableRateLimiting := flag.Bool("enable-rate-limiting", true, "Whether to enable rate limiting for tool calls. When false, no rate limiting will be applied") + transport := flag.String("transport", getDefaultTransport(), + "Transport protocol to use: 'sse' or 'streamable-http'. Can also be set via MCP_TRANSPORT environment variable") flag.Parse() @@ -74,16 +83,30 @@ func main() { // Create MCP server using the helper function mcpServer := mcp.CreateServer(k8sClient, config) - // Create SSE server - sseServer := mcp.CreateSSEServer(mcpServer) + // Create and start the appropriate transport server + var transportServer interface { + Start(string) error + Shutdown(context.Context) error + } + + switch strings.ToLower(*transport) { + case transportStreamableHTTP: + log.Println("Using streamable-http transport") + transportServer = mcp.CreateStreamableHTTPServer(mcpServer) + case transportSSE: + log.Println("Using SSE transport") + transportServer = mcp.CreateSSEServer(mcpServer) + default: + log.Fatalf("Invalid transport: %s. Must be 'sse' or 'streamable-http'", *transport) + } // Channel to receive server errors serverErrCh := make(chan error, 1) // Start the server in a goroutine go func() { - log.Printf("Starting MCP server on %s", *addr) - if err := sseServer.Start(*addr); err != nil { + log.Printf("Starting MCP server on %s with %s transport", *addr, *transport) + if err := transportServer.Start(*addr); err != nil { log.Printf("Server error: %v", err) serverErrCh <- err } @@ -106,8 +129,8 @@ func main() { go func() { log.Println("Initiating server shutdown...") - // Stop the SSE server - err := sseServer.Shutdown(shutdownCtx) + // Stop the transport server + err := transportServer.Shutdown(shutdownCtx) if err != nil { log.Printf("Error during shutdown: %v", err) } @@ -166,3 +189,27 @@ func getDefaultAddress() string { return fmt.Sprintf(":%d", port) } + +// getDefaultTransport returns the transport to use based on MCP_TRANSPORT environment variable. +// If the environment variable is not set, returns "sse". +// Valid values are "sse" and "streamable-http". +func getDefaultTransport() string { + defaultTransport := transportSSE + + transportEnv := os.Getenv("MCP_TRANSPORT") + if transportEnv == "" { + return defaultTransport + } + + // Normalize the transport value + transport := strings.ToLower(strings.TrimSpace(transportEnv)) + + // Validate the transport value + if transport != transportSSE && transport != transportStreamableHTTP { + log.Printf("Invalid MCP_TRANSPORT: %s, using default: %s", + transportEnv, defaultTransport) + return defaultTransport + } + + return transport +} diff --git a/pkg/mcp/server.go b/pkg/mcp/server.go index cb5be94..e0adf13 100644 --- a/pkg/mcp/server.go +++ b/pkg/mcp/server.go @@ -145,3 +145,8 @@ func StopServer() { func CreateSSEServer(mcpServer *server.MCPServer) *server.SSEServer { return server.NewSSEServer(mcpServer) } + +// CreateStreamableHTTPServer creates a new StreamableHTTP server for the MCP server +func CreateStreamableHTTPServer(mcpServer *server.MCPServer) *server.StreamableHTTPServer { + return server.NewStreamableHTTPServer(mcpServer) +}