diff --git a/internal/launcher/launcher_test.go b/internal/launcher/launcher_test.go index fd92aeea..c40ea8b2 100644 --- a/internal/launcher/launcher_test.go +++ b/internal/launcher/launcher_test.go @@ -2,6 +2,9 @@ package launcher import ( "context" + "encoding/json" + "net/http" + "net/http/httptest" "os" "testing" @@ -9,12 +12,31 @@ import ( ) func TestHTTPConnection(t *testing.T) { + // Create a mock HTTP server that handles initialize + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-server", + "version": "1.0.0", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer mockServer.Close() + // Create test config with HTTP server jsonConfig := `{ "mcpServers": { "safeinputs": { "type": "http", - "url": "http://host.docker.internal:3000/", + "url": "` + mockServer.URL + `", "headers": { "Authorization": "test-auth-secret" } @@ -53,8 +75,8 @@ func TestHTTPConnection(t *testing.T) { t.Errorf("Expected type 'http', got '%s'", httpServer.Type) } - if httpServer.URL != "http://host.docker.internal:3000/" { - t.Errorf("Expected URL 'http://host.docker.internal:3000/', got '%s'", httpServer.URL) + if httpServer.URL != mockServer.URL { + t.Errorf("Expected URL '%s', got '%s'", mockServer.URL, httpServer.URL) } if httpServer.Headers["Authorization"] != "test-auth-secret" { @@ -75,8 +97,8 @@ func TestHTTPConnection(t *testing.T) { t.Error("Connection should be HTTP") } - if conn.GetHTTPURL() != "http://host.docker.internal:3000/" { - t.Errorf("Expected URL 'http://host.docker.internal:3000/', got '%s'", conn.GetHTTPURL()) + if conn.GetHTTPURL() != mockServer.URL { + t.Errorf("Expected URL '%s', got '%s'", mockServer.URL, conn.GetHTTPURL()) } if conn.GetHTTPHeaders()["Authorization"] != "test-auth-secret" { @@ -89,12 +111,31 @@ func TestHTTPConnectionWithVariableExpansion(t *testing.T) { os.Setenv("TEST_AUTH_TOKEN", "secret-token-value") defer os.Unsetenv("TEST_AUTH_TOKEN") + // Create a mock HTTP server that handles initialize + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-server", + "version": "1.0.0", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer mockServer.Close() + // Create test config with variable expansion jsonConfig := `{ "mcpServers": { "safeinputs": { "type": "http", - "url": "http://host.docker.internal:3000/", + "url": "` + mockServer.URL + `", "headers": { "Authorization": "${TEST_AUTH_TOKEN}" } diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index 0758857a..e0b4c979 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -37,10 +37,11 @@ type Connection struct { ctx context.Context cancel context.CancelFunc // HTTP-specific fields - isHTTP bool - httpURL string - headers map[string]string - httpClient *http.Client + isHTTP bool + httpURL string + headers map[string]string + httpClient *http.Client + httpSessionID string // Session ID returned by the HTTP backend } // NewConnection creates a new MCP connection using the official SDK @@ -126,7 +127,17 @@ func NewConnection(ctx context.Context, command string, args []string, env map[s } // NewHTTPConnection creates a new HTTP-based MCP connection -// For HTTP servers that are already running, we just store the connection info +// For HTTP servers that are already running, we connect and initialize a session +// +// NOTE: This currently implements the MCP HTTP protocol manually instead of using +// the SDK's SSEClientTransport. This is because: +// 1. SSEClientTransport requires Server-Sent Events (SSE) format +// 2. Some HTTP MCP servers (like safeinputs) use plain JSON-RPC over HTTP POST +// 3. The MCP spec allows both SSE and plain HTTP transports +// +// TODO: Migrate to sdk.SSEClientTransport once we confirm all target HTTP backends +// support the SSE format, or extend the SDK to support both transport formats. +// This would eliminate manual JSON-RPC handling and improve maintainability. func NewHTTPConnection(ctx context.Context, url string, headers map[string]string) (*Connection, error) { logger.LogInfo("backend", "Creating HTTP MCP connection, url=%s", url) logConn.Printf("Creating HTTP MCP connection: url=%s", url) @@ -151,9 +162,20 @@ func NewHTTPConnection(ctx context.Context, url string, headers map[string]strin httpClient: httpClient, } - logger.LogInfoMd("backend", "Successfully created HTTP MCP connection, url=%s", url) - logConn.Printf("HTTP connection created: url=%s", url) - log.Printf("Configured HTTP MCP server: %s", url) + // Send initialize request to establish a session with the HTTP backend + // This is critical for backends that require session management + logConn.Printf("Sending initialize request to HTTP backend: url=%s", url) + sessionID, err := conn.initializeHTTPSession() + if err != nil { + cancel() + logger.LogError("backend", "Failed to initialize HTTP session, url=%s, error=%v", url, err) + return nil, fmt.Errorf("failed to initialize HTTP session: %w", err) + } + + conn.httpSessionID = sessionID + logger.LogInfo("backend", "Successfully created HTTP MCP connection with session, url=%s, session=%s", url, sessionID) + logConn.Printf("HTTP connection created with session: url=%s, session=%s", url, sessionID) + log.Printf("Configured HTTP MCP server with session: %s (session: %s)", url, sessionID) return conn, nil } @@ -232,6 +254,101 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string, return result, err } +// initializeHTTPSession sends an initialize request to the HTTP backend and captures the session ID +func (c *Connection) initializeHTTPSession() (string, error) { + // Generate unique request ID + requestID := atomic.AddUint64(&requestIDCounter, 1) + + // Create initialize request with MCP protocol parameters + initParams := map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "clientInfo": map[string]interface{}{ + "name": "awmg", + "version": "1.0.0", + }, + } + + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": requestID, + "method": "initialize", + "params": initParams, + } + + requestBody, err := json.Marshal(request) + if err != nil { + return "", fmt.Errorf("failed to marshal initialize request: %w", err) + } + + logConn.Printf("Sending initialize request: %s", string(requestBody)) + + // Create HTTP request + httpReq, err := http.NewRequest("POST", c.httpURL, bytes.NewReader(requestBody)) + if err != nil { + return "", fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Set headers + httpReq.Header.Set("Content-Type", "application/json") + + // Generate a temporary session ID for the initialize request + // Some backends may require this header even during initialization + tempSessionID := fmt.Sprintf("awmg-init-%d", requestID) + httpReq.Header.Set("Mcp-Session-Id", tempSessionID) + logConn.Printf("Sending initialize with temporary session ID: %s", tempSessionID) + + // Add configured headers (e.g., Authorization) + for key, value := range c.headers { + httpReq.Header.Set(key, value) + } + + logConn.Printf("Sending initialize to %s", c.httpURL) + + // Send request + httpResp, err := c.httpClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("failed to send initialize request: %w", err) + } + defer httpResp.Body.Close() + + // Capture the Mcp-Session-Id from response headers + sessionID := httpResp.Header.Get("Mcp-Session-Id") + if sessionID != "" { + logConn.Printf("Captured Mcp-Session-Id from response: %s", sessionID) + } else { + // If no session ID in response, use the temporary one + // This handles backends that don't return a session ID + sessionID = tempSessionID + logConn.Printf("No Mcp-Session-Id in response, using temporary session ID: %s", sessionID) + } + + // Read response body + responseBody, err := io.ReadAll(httpResp.Body) + if err != nil { + return "", fmt.Errorf("failed to read initialize response: %w", err) + } + + logConn.Printf("Initialize response: status=%d, body_len=%d, session=%s", httpResp.StatusCode, len(responseBody), sessionID) + + // Check for HTTP errors + if httpResp.StatusCode != http.StatusOK { + return "", fmt.Errorf("initialize failed: status=%d, body=%s", httpResp.StatusCode, string(responseBody)) + } + + // Parse JSON-RPC response to check for errors + var rpcResponse Response + if err := json.Unmarshal(responseBody, &rpcResponse); err != nil { + return "", fmt.Errorf("failed to parse initialize response: %w", err) + } + + if rpcResponse.Error != nil { + return "", fmt.Errorf("initialize error: code=%d, message=%s", rpcResponse.Error.Code, rpcResponse.Error.Message) + } + + return sessionID, nil +} + // sendHTTPRequest sends a JSON-RPC request to an HTTP MCP server // The ctx parameter is used to extract session ID for the Mcp-Session-Id header func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params interface{}) (*Response, error) { @@ -260,10 +377,22 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params // Set headers httpReq.Header.Set("Content-Type", "application/json") - // Extract session ID from context and add Mcp-Session-Id header - if sessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && sessionID != "" { + // Add Mcp-Session-Id header with priority: + // 1) Context session ID (if explicitly provided for this request) + // 2) Stored httpSessionID from initialization + var sessionID string + if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" { + sessionID = ctxSessionID + logConn.Printf("Using session ID from context: %s", sessionID) + } else if c.httpSessionID != "" { + sessionID = c.httpSessionID + logConn.Printf("Using stored session ID from initialization: %s", sessionID) + } + + if sessionID != "" { httpReq.Header.Set("Mcp-Session-Id", sessionID) - logConn.Printf("Added Mcp-Session-Id header: %s", sessionID) + } else { + logConn.Printf("No session ID available (backend may not require session management)") } // Add configured headers diff --git a/internal/server/unified_http_backend_test.go b/internal/server/unified_http_backend_test.go index 419e8c29..ea57310b 100644 --- a/internal/server/unified_http_backend_test.go +++ b/internal/server/unified_http_backend_test.go @@ -172,6 +172,7 @@ func TestHTTPBackendInitializationWithSessionIDRequirement(t *testing.T) { // TestHTTPBackend_SessionIDPropagation tests that session ID is propagated through tool calls func TestHTTPBackend_SessionIDPropagation(t *testing.T) { // Track session IDs received at different stages + initializeSessionID := "" initSessionID := "" toolCallSessionID := "" @@ -186,6 +187,23 @@ func TestHTTPBackend_SessionIDPropagation(t *testing.T) { json.NewDecoder(r.Body).Decode(&req) switch req.Method { + case "initialize": + initializeSessionID = sessionID + // Return initialize response + response := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-http-server", + "version": "1.0.0", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) case "tools/list": initSessionID = sessionID // Return tools list @@ -260,8 +278,14 @@ func TestHTTPBackend_SessionIDPropagation(t *testing.T) { } // Verify session IDs were received + if initializeSessionID == "" { + t.Errorf("No session ID received during initialize") + } else { + t.Logf("Initialize session ID: %s", initializeSessionID) + } + if initSessionID == "" { - t.Errorf("No session ID received during initialization") + t.Errorf("No session ID received during tools/list (initialization)") } else { t.Logf("Init session ID: %s", initSessionID) } diff --git a/test/integration/safeinputs_http_test.go b/test/integration/safeinputs_http_test.go index aa92efa9..0369568e 100644 --- a/test/integration/safeinputs_http_test.go +++ b/test/integration/safeinputs_http_test.go @@ -67,6 +67,19 @@ func TestSafeinputsHTTPBackend(t *testing.T) { // Return appropriate response based on method var response map[string]interface{} switch rpcReq.Method { + case "initialize": + response = map[string]interface{}{ + "jsonrpc": "2.0", + "id": rpcReq.ID, + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "safeinputs-server", + "version": "1.0.0", + }, + }, + } case "tools/list": response = map[string]interface{}{ "jsonrpc": "2.0", @@ -231,7 +244,8 @@ func TestSafeinputsHTTPBackend(t *testing.T) { t.Logf("Request #%d session ID: %s", i+1, headers["Mcp-Session-Id"]) // Verify the session ID follows the expected pattern for initialization - if strings.HasPrefix(headers["Mcp-Session-Id"], "gateway-init-") { + if strings.HasPrefix(headers["Mcp-Session-Id"], "awmg-init-") || + strings.HasPrefix(headers["Mcp-Session-Id"], "gateway-init-") { t.Logf("✓ Request #%d has correct gateway initialization session ID pattern", i+1) } }