diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index 89aa1c38..477143f0 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -71,6 +71,58 @@ type Connection struct { httpTransportType HTTPTransportType // Type of HTTP transport in use } +// newMCPClient creates a new MCP SDK client with standard implementation details +func newMCPClient() *sdk.Client { + return sdk.NewClient(&sdk.Implementation{ + Name: "awmg", + Version: "1.0.0", + }, nil) +} + +// newHTTPConnection creates a new HTTP Connection struct with common fields +func newHTTPConnection(ctx context.Context, cancel context.CancelFunc, client *sdk.Client, session *sdk.ClientSession, url string, headers map[string]string, httpClient *http.Client, transportType HTTPTransportType) *Connection { + return &Connection{ + client: client, + session: session, + ctx: ctx, + cancel: cancel, + isHTTP: true, + httpURL: url, + headers: headers, + httpClient: httpClient, + httpTransportType: transportType, + } +} + +// createJSONRPCRequest creates a JSON-RPC 2.0 request map +func createJSONRPCRequest(requestID uint64, method string, params interface{}) map[string]interface{} { + return map[string]interface{}{ + "jsonrpc": "2.0", + "id": requestID, + "method": method, + "params": params, + } +} + +// setupHTTPRequest creates and configures an HTTP request with standard headers +func setupHTTPRequest(ctx context.Context, url string, requestBody []byte, headers map[string]string) (*http.Request, error) { + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Set standard headers + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json, text/event-stream") + + // Add configured headers (e.g., Authorization) + for key, value := range headers { + httpReq.Header.Set(key, value) + } + + return httpReq, nil +} + // NewConnection creates a new MCP connection using the official SDK func NewConnection(ctx context.Context, command string, args []string, env map[string]string) (*Connection, error) { logger.LogInfo("backend", "Creating new MCP backend connection, command=%s, args=%v", command, args) @@ -78,10 +130,7 @@ func NewConnection(ctx context.Context, command string, args []string, env map[s ctx, cancel := context.WithCancel(ctx) // Create MCP client - client := sdk.NewClient(&sdk.Implementation{ - Name: "awmg", - Version: "1.0.0", - }, nil) + client := newMCPClient() // Expand Docker -e flags that reference environment variables // Docker's `-e VAR_NAME` expects VAR_NAME to be in the environment @@ -239,10 +288,7 @@ func NewHTTPConnection(ctx context.Context, url string, headers map[string]strin // tryStreamableHTTPTransport attempts to connect using the streamable HTTP transport (2025-03-26 spec) func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { // Create MCP client - client := sdk.NewClient(&sdk.Implementation{ - Name: "awmg", - Version: "1.0.0", - }, nil) + client := newMCPClient() // Create streamable HTTP transport transport := &sdk.StreamableClientTransport{ @@ -261,17 +307,7 @@ func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, return nil, fmt.Errorf("streamable HTTP transport connect failed: %w", err) } - conn := &Connection{ - client: client, - session: session, - ctx: ctx, - cancel: cancel, - isHTTP: true, - httpURL: url, - headers: headers, - httpClient: httpClient, - httpTransportType: HTTPTransportStreamable, - } + conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, HTTPTransportStreamable) logger.LogInfo("backend", "Streamable HTTP transport connected successfully") logConn.Printf("Connected with streamable HTTP transport") @@ -281,10 +317,7 @@ func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, // trySSETransport attempts to connect using the SSE transport (2024-11-05 spec) func trySSETransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { // Create MCP client - client := sdk.NewClient(&sdk.Implementation{ - Name: "awmg", - Version: "1.0.0", - }, nil) + client := newMCPClient() // Create SSE transport transport := &sdk.SSEClientTransport{ @@ -302,17 +335,7 @@ func trySSETransport(ctx context.Context, cancel context.CancelFunc, url string, return nil, fmt.Errorf("SSE transport connect failed: %w", err) } - conn := &Connection{ - client: client, - session: session, - ctx: ctx, - cancel: cancel, - isHTTP: true, - httpURL: url, - headers: headers, - httpClient: httpClient, - httpTransportType: HTTPTransportSSE, - } + conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, HTTPTransportSSE) logger.LogInfo("backend", "SSE transport connected successfully") logConn.Printf("Connected with SSE transport") @@ -455,12 +478,7 @@ func (c *Connection) initializeHTTPSession() (string, error) { }, } - request := map[string]interface{}{ - "jsonrpc": "2.0", - "id": requestID, - "method": "initialize", - "params": initParams, - } + request := createJSONRPCRequest(requestID, "initialize", initParams) requestBody, err := json.Marshal(request) if err != nil { @@ -469,27 +487,18 @@ func (c *Connection) initializeHTTPSession() (string, error) { logConn.Printf("Sending initialize request: %s", string(requestBody)) - // Create HTTP request - httpReq, err := http.NewRequest("POST", c.httpURL, bytes.NewReader(requestBody)) + // Create HTTP request with standard headers + httpReq, err := setupHTTPRequest(context.Background(), c.httpURL, requestBody, c.headers) if err != nil { - return "", fmt.Errorf("failed to create HTTP request: %w", err) + return "", err } - // Set headers - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Accept", "application/json, text/event-stream") - // Generate a temporary session ID for the initialize request // Some backends may require this header even during initialization tempSessionID := fmt.Sprintf("awmg-init-%d", requestID) httpReq.Header.Set("Mcp-Session-Id", tempSessionID) logConn.Printf("Sending initialize with temporary session ID: %s", tempSessionID) - // Add configured headers (e.g., Authorization) - for key, value := range c.headers { - httpReq.Header.Set(key, value) - } - logConn.Printf("Sending initialize to %s", c.httpURL) // Send request @@ -568,28 +577,19 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params requestID := atomic.AddUint64(&requestIDCounter, 1) // Create JSON-RPC request - request := map[string]interface{}{ - "jsonrpc": "2.0", - "id": requestID, - "method": method, - "params": params, - } + request := createJSONRPCRequest(requestID, method, params) requestBody, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } - // Create HTTP request - httpReq, err := http.NewRequestWithContext(ctx, "POST", c.httpURL, bytes.NewReader(requestBody)) + // Create HTTP request with standard headers + httpReq, err := setupHTTPRequest(ctx, c.httpURL, requestBody, c.headers) if err != nil { - return nil, fmt.Errorf("failed to create HTTP request: %w", err) + return nil, err } - // Set headers - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Accept", "application/json, text/event-stream") - // Add Mcp-Session-Id header with priority: // 1) Context session ID (if explicitly provided for this request) // 2) Stored httpSessionID from initialization @@ -608,11 +608,6 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params logConn.Printf("No session ID available (backend may not require session management)") } - // Add configured headers - for key, value := range c.headers { - httpReq.Header.Set(key, value) - } - logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID) // Send request using the reusable HTTP client diff --git a/internal/mcp/connection_test.go b/internal/mcp/connection_test.go index c5b6d14f..bb4d4cca 100644 --- a/internal/mcp/connection_test.go +++ b/internal/mcp/connection_test.go @@ -564,3 +564,142 @@ data: {"jsonrpc":"2.0","id":` + idStr + `,"result":{"tools":[]}} t.Logf("Successfully parsed SSE-formatted response from server") } + +// TestNewMCPClient tests the newMCPClient helper function +func TestNewMCPClient(t *testing.T) { + client := newMCPClient() + require.NotNil(t, client, "newMCPClient should return a non-nil client") +} + +// TestCreateJSONRPCRequest tests the createJSONRPCRequest helper function +func TestCreateJSONRPCRequest(t *testing.T) { + tests := []struct { + name string + requestID uint64 + method string + params interface{} + }{ + { + name: "simple request with nil params", + requestID: 1, + method: "initialize", + params: nil, + }, + { + name: "request with map params", + requestID: 42, + method: "tools/list", + params: map[string]interface{}{"filter": "test"}, + }, + { + name: "request with struct params", + requestID: 100, + method: "tools/call", + params: struct{ Name string }{Name: "test-tool"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := createJSONRPCRequest(tt.requestID, tt.method, tt.params) + + assert.Equal(t, "2.0", request["jsonrpc"], "jsonrpc version should be 2.0") + assert.Equal(t, tt.requestID, request["id"], "id should match requestID") + assert.Equal(t, tt.method, request["method"], "method should match") + assert.Equal(t, tt.params, request["params"], "params should match") + }) + } +} + +// TestSetupHTTPRequest tests the setupHTTPRequest helper function +func TestSetupHTTPRequest(t *testing.T) { + tests := []struct { + name string + url string + requestBody []byte + headers map[string]string + expectError bool + expectedMethod string + }{ + { + name: "basic request with no custom headers", + url: "http://example.com/mcp", + requestBody: []byte(`{"test": "data"}`), + headers: map[string]string{}, + expectError: false, + expectedMethod: "POST", + }, + { + name: "request with custom headers", + url: "http://example.com/mcp", + requestBody: []byte(`{"test": "data"}`), + headers: map[string]string{ + "Authorization": "Bearer token123", + "X-Custom": "value", + }, + expectError: false, + expectedMethod: "POST", + }, + { + name: "request with empty body", + url: "http://example.com/mcp", + requestBody: []byte{}, + headers: map[string]string{}, + expectError: false, + expectedMethod: "POST", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + req, err := setupHTTPRequest(ctx, tt.url, tt.requestBody, tt.headers) + + if tt.expectError { + assert.Error(t, err, "Expected error") + return + } + + require.NoError(t, err, "setupHTTPRequest should not return error") + require.NotNil(t, req, "Request should not be nil") + + // Verify method + assert.Equal(t, tt.expectedMethod, req.Method, "Method should be POST") + + // Verify URL + assert.Equal(t, tt.url, req.URL.String(), "URL should match") + + // Verify standard headers + assert.Equal(t, "application/json", req.Header.Get("Content-Type"), "Content-Type should be application/json") + assert.Equal(t, "application/json, text/event-stream", req.Header.Get("Accept"), "Accept header should be set") + + // Verify custom headers + for key, value := range tt.headers { + assert.Equal(t, value, req.Header.Get(key), "Custom header %s should match", key) + } + }) + } +} + +// TestNewHTTPConnection tests the newHTTPConnection helper function +func TestNewHTTPConnection(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := newMCPClient() + url := "http://example.com/mcp" + headers := map[string]string{"Authorization": "test"} + httpClient := &http.Client{} + + conn := newHTTPConnection(ctx, cancel, client, nil, url, headers, httpClient, HTTPTransportStreamable) + + require.NotNil(t, conn, "Connection should not be nil") + assert.Equal(t, client, conn.client, "Client should match") + assert.Equal(t, ctx, conn.ctx, "Context should match") + assert.NotNil(t, conn.cancel, "Cancel function should not be nil") + assert.True(t, conn.isHTTP, "isHTTP should be true") + assert.Equal(t, url, conn.httpURL, "URL should match") + assert.Equal(t, headers, conn.headers, "Headers should match") + assert.Equal(t, httpClient, conn.httpClient, "HTTP client should match") + assert.Equal(t, HTTPTransportStreamable, conn.httpTransportType, "Transport type should match") +}