diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index e0b4c979..f2254dd1 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -30,6 +30,18 @@ const SessionIDContextKey ContextKey = "awmg-session-id" // requestIDCounter is used to generate unique request IDs for HTTP requests var requestIDCounter uint64 +// HTTPTransportType represents the type of HTTP transport being used +type HTTPTransportType string + +const ( + // HTTPTransportStreamable uses the streamable HTTP transport (2025-03-26 spec) + HTTPTransportStreamable HTTPTransportType = "streamable" + // HTTPTransportSSE uses the SSE transport (2024-11-05 spec) + HTTPTransportSSE HTTPTransportType = "sse" + // HTTPTransportPlainJSON uses plain JSON-RPC 2.0 over HTTP POST (non-standard) + HTTPTransportPlainJSON HTTPTransportType = "plain-json" +) + // Connection represents a connection to an MCP server using the official SDK type Connection struct { client *sdk.Client @@ -37,11 +49,12 @@ type Connection struct { ctx context.Context cancel context.CancelFunc // HTTP-specific fields - isHTTP bool - httpURL string - headers map[string]string - httpClient *http.Client - httpSessionID string // Session ID returned by the HTTP backend + isHTTP bool + httpURL string + headers map[string]string + httpClient *http.Client + httpSessionID string // Session ID returned by the HTTP backend + httpTransportType HTTPTransportType // Type of HTTP transport in use } // NewConnection creates a new MCP connection using the official SDK @@ -126,20 +139,20 @@ func NewConnection(ctx context.Context, command string, args []string, env map[s return conn, nil } -// NewHTTPConnection creates a new HTTP-based MCP connection +// NewHTTPConnection creates a new HTTP-based MCP connection with transport fallback // For HTTP servers that are already running, we connect and initialize a session // -// NOTE: This currently implements the MCP HTTP protocol manually instead of using -// the SDK's SSEClientTransport. This is because: -// 1. SSEClientTransport requires Server-Sent Events (SSE) format -// 2. Some HTTP MCP servers (like safeinputs) use plain JSON-RPC over HTTP POST -// 3. The MCP spec allows both SSE and plain HTTP transports +// This function implements a fallback strategy for HTTP transports: +// 1. If custom headers are provided, skip SDK transports (they don't support custom headers) +// and use plain JSON-RPC 2.0 over HTTP POST (for safeinputs compatibility) +// 2. Otherwise, try standard transports: +// a. Streamable HTTP (2025-03-26 spec) using SDK's StreamableClientTransport +// b. SSE (2024-11-05 spec) using SDK's SSEClientTransport +// c. Plain JSON-RPC 2.0 over HTTP POST as final fallback // -// TODO: Migrate to sdk.SSEClientTransport once we confirm all target HTTP backends -// support the SSE format, or extend the SDK to support both transport formats. -// This would eliminate manual JSON-RPC handling and improve maintainability. +// This ensures compatibility with all types of HTTP MCP servers. func NewHTTPConnection(ctx context.Context, url string, headers map[string]string) (*Connection, error) { - logger.LogInfo("backend", "Creating HTTP MCP connection, url=%s", url) + logger.LogInfo("backend", "Creating HTTP MCP connection with transport fallback, url=%s", url) logConn.Printf("Creating HTTP MCP connection: url=%s", url) ctx, cancel := context.WithCancel(ctx) @@ -153,29 +166,158 @@ func NewHTTPConnection(ctx context.Context, url string, headers map[string]strin }, } + // If custom headers are provided, skip SDK transports as they don't support headers + // This is typical for backends like safeinputs that require authentication + if len(headers) > 0 { + logConn.Printf("Custom headers detected, using plain JSON-RPC transport for %s", url) + conn, err := tryPlainJSONTransport(ctx, cancel, url, headers, httpClient) + if err == nil { + logger.LogInfo("backend", "Successfully connected using plain JSON-RPC transport, url=%s", url) + log.Printf("Configured HTTP MCP server with plain JSON-RPC transport: %s", url) + return conn, nil + } + cancel() + logger.LogError("backend", "Plain JSON-RPC transport failed for url=%s, error=%v", url, err) + return nil, fmt.Errorf("failed to connect with plain JSON-RPC transport: %w", err) + } + + // Try standard transports in order: streamable HTTP → SSE → plain JSON-RPC + + // Try 1: Streamable HTTP (2025-03-26 spec) + logConn.Printf("Attempting streamable HTTP transport for %s", url) + conn, err := tryStreamableHTTPTransport(ctx, cancel, url, headers, httpClient) + if err == nil { + logger.LogInfo("backend", "Successfully connected using streamable HTTP transport, url=%s", url) + log.Printf("Configured HTTP MCP server with streamable transport: %s", url) + return conn, nil + } + logConn.Printf("Streamable HTTP failed: %v", err) + + // Try 2: SSE (2024-11-05 spec) + logConn.Printf("Attempting SSE transport for %s", url) + conn, err = trySSETransport(ctx, cancel, url, headers, httpClient) + if err == nil { + logger.LogInfo("backend", "Successfully connected using SSE transport, url=%s", url) + log.Printf("Configured HTTP MCP server with SSE transport: %s", url) + return conn, nil + } + logConn.Printf("SSE transport failed: %v", err) + + // Try 3: Plain JSON-RPC over HTTP (non-standard, for fallback) + logConn.Printf("Attempting plain JSON-RPC transport for %s", url) + conn, err = tryPlainJSONTransport(ctx, cancel, url, headers, httpClient) + if err == nil { + logger.LogInfo("backend", "Successfully connected using plain JSON-RPC transport, url=%s", url) + log.Printf("Configured HTTP MCP server with plain JSON-RPC transport: %s", url) + return conn, nil + } + logConn.Printf("Plain JSON-RPC transport failed: %v", err) + + // All transports failed + cancel() + logger.LogError("backend", "All HTTP transports failed for url=%s", url) + return nil, fmt.Errorf("failed to connect using any HTTP transport (tried streamable, SSE, and plain JSON-RPC): last error: %w", err) +} + +// tryStreamableHTTPTransport attempts to connect using the streamable HTTP transport (2025-03-26 spec) +func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { + // Create MCP client + client := sdk.NewClient(&sdk.Implementation{ + Name: "awmg", + Version: "1.0.0", + }, nil) + + // Create streamable HTTP transport + transport := &sdk.StreamableClientTransport{ + Endpoint: url, + HTTPClient: httpClient, + MaxRetries: 0, // Don't retry on failure - we'll try other transports + } + + // Try to connect - this will fail if the server doesn't support streamable HTTP + session, err := client.Connect(ctx, transport, nil) + if err != nil { + return nil, fmt.Errorf("streamable HTTP transport connect failed: %w", err) + } + + conn := &Connection{ + client: client, + session: session, + ctx: ctx, + cancel: cancel, + isHTTP: true, + httpURL: url, + headers: headers, + httpClient: httpClient, + httpTransportType: HTTPTransportStreamable, + } + + logger.LogInfo("backend", "Streamable HTTP transport connected successfully") + logConn.Printf("Connected with streamable HTTP transport") + return conn, nil +} + +// trySSETransport attempts to connect using the SSE transport (2024-11-05 spec) +func trySSETransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { + // Create MCP client + client := sdk.NewClient(&sdk.Implementation{ + Name: "awmg", + Version: "1.0.0", + }, nil) + + // Create SSE transport + transport := &sdk.SSEClientTransport{ + Endpoint: url, + HTTPClient: httpClient, + } + + // Try to connect - this will fail if the server doesn't support SSE + session, err := client.Connect(ctx, transport, nil) + if err != nil { + return nil, fmt.Errorf("SSE transport connect failed: %w", err) + } + conn := &Connection{ - ctx: ctx, - cancel: cancel, - isHTTP: true, - httpURL: url, - headers: headers, - httpClient: httpClient, + client: client, + session: session, + ctx: ctx, + cancel: cancel, + isHTTP: true, + httpURL: url, + headers: headers, + httpClient: httpClient, + httpTransportType: HTTPTransportSSE, + } + + logger.LogInfo("backend", "SSE transport connected successfully") + logConn.Printf("Connected with SSE transport") + return conn, nil +} + +// tryPlainJSONTransport attempts to connect using plain JSON-RPC 2.0 over HTTP POST (non-standard) +// This is used for compatibility with servers like safeinputs that don't implement standard MCP HTTP transports +func tryPlainJSONTransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { + conn := &Connection{ + ctx: ctx, + cancel: cancel, + isHTTP: true, + httpURL: url, + headers: headers, + httpClient: httpClient, + httpTransportType: HTTPTransportPlainJSON, } // Send initialize request to establish a session with the HTTP backend // This is critical for backends that require session management - logConn.Printf("Sending initialize request to HTTP backend: url=%s", url) + logConn.Printf("Sending initialize request via plain JSON-RPC to: %s", url) sessionID, err := conn.initializeHTTPSession() if err != nil { - cancel() - logger.LogError("backend", "Failed to initialize HTTP session, url=%s, error=%v", url, err) - return nil, fmt.Errorf("failed to initialize HTTP session: %w", err) + return nil, fmt.Errorf("plain JSON-RPC initialize failed: %w", err) } conn.httpSessionID = sessionID - logger.LogInfo("backend", "Successfully created HTTP MCP connection with session, url=%s, session=%s", url, sessionID) - logConn.Printf("HTTP connection created with session: url=%s, session=%s", url, sessionID) - log.Printf("Configured HTTP MCP server with session: %s (session: %s)", url, sessionID) + logger.LogInfo("backend", "Plain JSON-RPC transport connected successfully with session=%s", sessionID) + logConn.Printf("Connected with plain JSON-RPC transport, session=%s", sessionID) return conn, nil } @@ -214,9 +356,22 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string, var result *Response var err error - // Handle HTTP connections by proxying the request + // Handle HTTP connections if c.isHTTP { - result, err = c.sendHTTPRequest(ctx, method, params) + // For plain JSON-RPC transport, use manual HTTP requests + if c.httpTransportType == HTTPTransportPlainJSON { + result, err = c.sendHTTPRequest(ctx, method, params) + // Log the response from backend server + var responsePayload []byte + if result != nil { + responsePayload, _ = json.Marshal(result) + } + logger.LogRPCResponse(logger.RPCDirectionInbound, serverID, responsePayload, err) + return result, err + } + + // For streamable and SSE transports, use SDK session methods + result, err = c.callSDKMethod(method, params) // Log the response from backend server var responsePayload []byte if result != nil { @@ -227,22 +382,7 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string, } // Handle stdio connections using SDK client - switch method { - case "tools/list": - result, err = c.listTools() - case "tools/call": - result, err = c.callTool(params) - case "resources/list": - result, err = c.listResources() - case "resources/read": - result, err = c.readResource(params) - case "prompts/list": - result, err = c.listPrompts() - case "prompts/get": - result, err = c.getPrompt(params) - default: - err = fmt.Errorf("unsupported method: %s", method) - } + result, err = c.callSDKMethod(method, params) // Log the response from backend server var responsePayload []byte @@ -254,6 +394,27 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string, return result, err } +// callSDKMethod calls the appropriate SDK method based on the method name +// This centralizes the method dispatch logic used by both HTTP SDK transports and stdio +func (c *Connection) callSDKMethod(method string, params interface{}) (*Response, error) { + switch method { + case "tools/list": + return c.listTools() + case "tools/call": + return c.callTool(params) + case "resources/list": + return c.listResources() + case "resources/read": + return c.readResource(params) + case "prompts/list": + return c.listPrompts() + case "prompts/get": + return c.getPrompt(params) + default: + return nil, fmt.Errorf("unsupported method: %s", method) + } +} + // initializeHTTPSession sends an initialize request to the HTTP backend and captures the session ID func (c *Connection) initializeHTTPSession() (string, error) { // Generate unique request ID @@ -432,6 +593,9 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params } func (c *Connection) listTools() (*Response, error) { + if c.session == nil { + return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + } result, err := c.session.ListTools(c.ctx, &sdk.ListToolsParams{}) if err != nil { return nil, err @@ -450,6 +614,9 @@ func (c *Connection) listTools() (*Response, error) { } func (c *Connection) callTool(params interface{}) (*Response, error) { + if c.session == nil { + return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + } var callParams CallToolParams paramsJSON, _ := json.Marshal(params) if err := json.Unmarshal(paramsJSON, &callParams); err != nil { @@ -477,6 +644,9 @@ func (c *Connection) callTool(params interface{}) (*Response, error) { } func (c *Connection) listResources() (*Response, error) { + if c.session == nil { + return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + } result, err := c.session.ListResources(c.ctx, &sdk.ListResourcesParams{}) if err != nil { return nil, err @@ -495,6 +665,9 @@ func (c *Connection) listResources() (*Response, error) { } func (c *Connection) readResource(params interface{}) (*Response, error) { + if c.session == nil { + return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + } var readParams struct { URI string `json:"uri"` } @@ -523,6 +696,9 @@ func (c *Connection) readResource(params interface{}) (*Response, error) { } func (c *Connection) listPrompts() (*Response, error) { + if c.session == nil { + return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + } result, err := c.session.ListPrompts(c.ctx, &sdk.ListPromptsParams{}) if err != nil { return nil, err @@ -541,6 +717,9 @@ func (c *Connection) listPrompts() (*Response, error) { } func (c *Connection) getPrompt(params interface{}) (*Response, error) { + if c.session == nil { + return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + } var getParams struct { Name string `json:"name"` Arguments map[string]string `json:"arguments"`