diff --git a/pkg/mcp/middleware.go b/pkg/mcp/middleware.go new file mode 100644 index 0000000..ee3594a --- /dev/null +++ b/pkg/mcp/middleware.go @@ -0,0 +1,24 @@ +package mcp + +import ( + "context" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// WithTimeoutContext adds a timeout context to all tool handlers +// This helps prevent context cancellation errors from the k8s client +func WithTimeoutContext(timeout time.Duration) server.ServerOption { + return server.WithToolHandlerMiddleware(func(next server.ToolHandlerFunc) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { + // Create a new context with the specified timeout + timeoutCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Call the next handler with the timeout context + return next(timeoutCtx, request) + } + }) +} diff --git a/pkg/mcp/server.go b/pkg/mcp/server.go index 5e80739..e47d989 100644 --- a/pkg/mcp/server.go +++ b/pkg/mcp/server.go @@ -3,12 +3,16 @@ package mcp import ( "context" "log" + "time" "github.com/mark3labs/mcp-go/server" "github.com/StacklokLabs/mkp/pkg/k8s" ) +// defaultCtxTimeout is the default timeout for tool calls +const defaultCtxTimeout = 30 * time.Second + // Config holds configuration options for the MCP server type Config struct { // ServeResources determines whether to serve cluster resources @@ -43,6 +47,9 @@ func CreateServer(k8sClient *k8s.Client, config *Config) *server.MCPServer { "0.1.0", server.WithResourceCapabilities(true, true), server.WithToolCapabilities(true), + // Add timeout middleware to prevent context cancellation errors + WithTimeoutContext(defaultCtxTimeout), + server.WithRecovery(), ) // Add tools @@ -67,8 +74,12 @@ func CreateServer(k8sClient *k8s.Client, config *Config) *server.MCPServer { // Add resources if enabled if config.ServeResources { go func() { + // Create a timeout context for listing resources + timeoutCtx, cancel := context.WithTimeout(context.Background(), defaultCtxTimeout) + defer cancel() + // List resources in a goroutine to avoid blocking server startup - resources, err := impl.HandleListAllResources(context.Background()) + resources, err := impl.HandleListAllResources(timeoutCtx) if err != nil { log.Printf("Failed to list resources: %v", err) return