diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index 7f28ca3b..89aa1c38 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -636,11 +636,6 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params logConn.Printf("Received HTTP response: status=%d, body_len=%d", httpResp.StatusCode, len(responseBody)) - // Check for HTTP errors - if httpResp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("HTTP error: status=%d, body=%s", httpResp.StatusCode, string(responseBody)) - } - // Parse JSON-RPC response // The response might be in SSE format (event: message\ndata: {...}) // Try to parse as JSON first, if that fails, try SSE format @@ -651,6 +646,19 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params logConn.Printf("Initial JSON parse failed, attempting SSE format parsing") sseData, sseErr := parseSSEResponse(responseBody) if sseErr != nil { + // If we have a non-OK HTTP status and can't parse the response, + // construct a JSON-RPC error response with HTTP error details + if httpResp.StatusCode != http.StatusOK { + logConn.Printf("HTTP error status=%d, body cannot be parsed as JSON-RPC", httpResp.StatusCode) + return &Response{ + JSONRPC: "2.0", + Error: &ResponseError{ + Code: -32603, // Internal error + Message: fmt.Sprintf("HTTP %d: %s", httpResp.StatusCode, http.StatusText(httpResp.StatusCode)), + Data: json.RawMessage(responseBody), + }, + }, nil + } // Include the response body to help debug what the server actually returned bodyPreview := string(responseBody) if len(bodyPreview) > 500 { @@ -661,11 +669,39 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params // Successfully extracted JSON from SSE, now parse it if err := json.Unmarshal(sseData, &rpcResponse); err != nil { + // If we have a non-OK HTTP status and can't parse the SSE data, + // construct a JSON-RPC error response with HTTP error details + if httpResp.StatusCode != http.StatusOK { + logConn.Printf("HTTP error status=%d, SSE data cannot be parsed as JSON-RPC", httpResp.StatusCode) + return &Response{ + JSONRPC: "2.0", + Error: &ResponseError{ + Code: -32603, // Internal error + Message: fmt.Sprintf("HTTP %d: %s", httpResp.StatusCode, http.StatusText(httpResp.StatusCode)), + Data: json.RawMessage(responseBody), + }, + }, nil + } return nil, fmt.Errorf("failed to parse JSON data extracted from SSE response: %w\nJSON data: %s", err, string(sseData)) } logConn.Printf("Successfully parsed SSE-formatted response") } + // Check for HTTP errors after parsing + // If we have a non-OK status but successfully parsed a JSON-RPC response, + // pass it through (it may already contain an error field) + if httpResp.StatusCode != http.StatusOK { + logConn.Printf("HTTP error status=%d with valid JSON-RPC response, passing through", httpResp.StatusCode) + // If the response doesn't already have an error, construct one + if rpcResponse.Error == nil { + rpcResponse.Error = &ResponseError{ + Code: -32603, // Internal error + Message: fmt.Sprintf("HTTP %d: %s", httpResp.StatusCode, http.StatusText(httpResp.StatusCode)), + Data: responseBody, + } + } + } + return &rpcResponse, nil } diff --git a/internal/mcp/http_error_propagation_test.go b/internal/mcp/http_error_propagation_test.go new file mode 100644 index 00000000..4d9c6c9a --- /dev/null +++ b/internal/mcp/http_error_propagation_test.go @@ -0,0 +1,375 @@ +package mcp + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestHTTPErrorPropagation_Non200Status tests that non-200 HTTP status codes +// are properly converted to JSON-RPC error responses +func TestHTTPErrorPropagation_Non200Status(t *testing.T) { + tests := []struct { + name string + statusCode int + responseBody map[string]interface{} + expectErrorMsg string // Expected substring in error message + }{ + { + name: "HTTP 400 Bad Request", + statusCode: http.StatusBadRequest, + responseBody: map[string]interface{}{ + "error": "Invalid parameters", + }, + expectErrorMsg: "400", + }, + { + name: "HTTP 401 Unauthorized", + statusCode: http.StatusUnauthorized, + responseBody: map[string]interface{}{ + "error": "Authentication required", + }, + expectErrorMsg: "401", + }, + { + name: "HTTP 403 Forbidden", + statusCode: http.StatusForbidden, + responseBody: map[string]interface{}{ + "error": "Access denied", + }, + expectErrorMsg: "403", + }, + { + name: "HTTP 404 Not Found", + statusCode: http.StatusNotFound, + responseBody: map[string]interface{}{ + "error": "Resource not found", + }, + expectErrorMsg: "404", + }, + { + name: "HTTP 500 Internal Server Error", + statusCode: http.StatusInternalServerError, + responseBody: map[string]interface{}{ + "error": "Database connection failed", + }, + expectErrorMsg: "500", + }, + { + name: "HTTP 502 Bad Gateway", + statusCode: http.StatusBadGateway, + responseBody: map[string]interface{}{ + "error": "Upstream service unavailable", + }, + expectErrorMsg: "502", + }, + { + name: "HTTP 503 Service Unavailable", + statusCode: http.StatusServiceUnavailable, + responseBody: map[string]interface{}{ + "error": "Service temporarily down", + }, + expectErrorMsg: "503", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + requestCount := 0 + + // Create test server that succeeds on initialize but fails on subsequent requests + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + method, _ := reqBody["method"].(string) + + if method == "initialize" { + // Initialize succeeds + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-server", + "version": "1.0.0", + }, + }, + }) + return + } + + // Subsequent requests return the configured error status + w.WriteHeader(tt.statusCode) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tt.responseBody) + })) + defer testServer.Close() + + // Create connection with custom headers to use plain JSON transport + conn, err := NewHTTPConnection(context.Background(), testServer.URL, map[string]string{ + "Authorization": "test-token", + }) + require.NoError(t, err, "Failed to create connection") + defer conn.Close() + + // Send request that will trigger error response + resp, err := conn.SendRequestWithServerID(context.Background(), "tools/list", nil, "test-server") + require.NoError(t, err, "SendRequestWithServerID should not return Go error") + require.NotNil(t, resp, "Response should not be nil") + + // Verify the response contains an error field + require.NotNil(t, resp.Error, "Response should contain error field") + assert.Equal(t, -32603, resp.Error.Code, "Error code should be -32603 (Internal error)") + assert.Contains(t, resp.Error.Message, tt.expectErrorMsg, + "Error message should contain HTTP status code") + + // Verify error data contains original response body + if resp.Error.Data != nil { + var errorData interface{} + err := json.Unmarshal(resp.Error.Data, &errorData) + require.NoError(t, err, "Error data should be valid JSON") + } + }) + } +} + +// TestHTTPErrorPropagation_JSONRPCError tests that HTTP 200 responses with +// JSON-RPC error field are properly returned +func TestHTTPErrorPropagation_JSONRPCError(t *testing.T) { + requestCount := 0 + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + method, _ := reqBody["method"].(string) + + if method == "initialize" { + // Initialize succeeds + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-server", + "version": "1.0.0", + }, + }, + }) + return + } + + // Return JSON-RPC error with HTTP 200 + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "error": map[string]interface{}{ + "code": -32600, + "message": "Invalid Request", + "data": "Tool not found", + }, + }) + })) + defer testServer.Close() + + conn, err := NewHTTPConnection(context.Background(), testServer.URL, map[string]string{ + "Authorization": "test-token", + }) + require.NoError(t, err, "Failed to create connection") + defer conn.Close() + + // Send request + resp, err := conn.SendRequestWithServerID(context.Background(), "tools/call", + map[string]interface{}{"name": "unknown"}, "test-server") + require.NoError(t, err, "SendRequestWithServerID should not return Go error") + require.NotNil(t, resp, "Response should not be nil") + + // Verify the response contains the JSON-RPC error + require.NotNil(t, resp.Error, "Response should contain error field") + assert.Equal(t, -32600, resp.Error.Code, "Error code should match backend error") + assert.Equal(t, "Invalid Request", resp.Error.Message, "Error message should match backend error") +} + +// TestHTTPErrorPropagation_MixedContent tests error responses with mixed content types +func TestHTTPErrorPropagation_MixedContent(t *testing.T) { + tests := []struct { + name string + statusCode int + responseBody string // Raw response body + contentType string + }{ + { + name: "Plain text error", + statusCode: http.StatusInternalServerError, + responseBody: "Internal Server Error", + contentType: "text/plain", + }, + { + name: "HTML error page", + statusCode: http.StatusNotFound, + responseBody: "404 Not Found", + contentType: "text/html", + }, + { + name: "JSON error without JSON-RPC structure", + statusCode: http.StatusBadRequest, + responseBody: `{"message": "Bad request", "code": "INVALID_PARAMS"}`, + contentType: "application/json", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + requestCount := 0 + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + method, _ := reqBody["method"].(string) + + if method == "initialize" { + // Initialize succeeds + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-server", + "version": "1.0.0", + }, + }, + }) + return + } + + // Return error with specified content type + w.WriteHeader(tt.statusCode) + w.Header().Set("Content-Type", tt.contentType) + w.Write([]byte(tt.responseBody)) + })) + defer testServer.Close() + + conn, err := NewHTTPConnection(context.Background(), testServer.URL, map[string]string{ + "Authorization": "test-token", + }) + require.NoError(t, err, "Failed to create connection") + defer conn.Close() + + // Send request + resp, err := conn.SendRequestWithServerID(context.Background(), "tools/list", nil, "test-server") + require.NoError(t, err, "SendRequestWithServerID should not return Go error") + require.NotNil(t, resp, "Response should not be nil") + + // Verify the response contains an error field + require.NotNil(t, resp.Error, "Response should contain error field") + assert.Equal(t, -32603, resp.Error.Code, "Error code should be -32603") + + // Verify error data contains original response body + if resp.Error.Data != nil { + data := string(resp.Error.Data) + assert.Contains(t, data, tt.responseBody, "Error data should contain original response") + } + }) + } +} + +// TestHTTPErrorPropagation_PreservesDetails tests that error details are preserved +func TestHTTPErrorPropagation_PreservesDetails(t *testing.T) { + requestCount := 0 + originalError := map[string]interface{}{ + "type": "authentication_error", + "message": "API key is invalid or expired", + "details": map[string]interface{}{ + "apiKeyPrefix": "sk-test-****", + "expiresAt": "2026-01-01T00:00:00Z", + }, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + method, _ := reqBody["method"].(string) + + if method == "initialize" { + // Initialize succeeds + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-server", + "version": "1.0.0", + }, + }, + }) + return + } + + // Return detailed error + w.WriteHeader(http.StatusUnauthorized) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(originalError) + })) + defer testServer.Close() + + conn, err := NewHTTPConnection(context.Background(), testServer.URL, map[string]string{ + "Authorization": "test-token", + }) + require.NoError(t, err, "Failed to create connection") + defer conn.Close() + + // Send request + resp, err := conn.SendRequestWithServerID(context.Background(), "tools/list", nil, "test-server") + require.NoError(t, err, "SendRequestWithServerID should not return Go error") + require.NotNil(t, resp, "Response should not be nil") + + // Verify error is present + require.NotNil(t, resp.Error, "Response should contain error field") + assert.Contains(t, resp.Error.Message, "401", "Error message should contain status code") + + // Verify error data contains original error details + require.NotNil(t, resp.Error.Data, "Error data should not be nil") + + var errorData map[string]interface{} + err = json.Unmarshal(resp.Error.Data, &errorData) + require.NoError(t, err, "Error data should be valid JSON") + + // Verify original error details are preserved + assert.Equal(t, originalError["type"], errorData["type"], "Error type should be preserved") + assert.Equal(t, originalError["message"], errorData["message"], "Error message should be preserved") + + details, ok := errorData["details"].(map[string]interface{}) + require.True(t, ok, "Error details should be preserved") + assert.NotNil(t, details["apiKeyPrefix"], "API key prefix should be preserved") + assert.NotNil(t, details["expiresAt"], "Expiration time should be preserved") +} diff --git a/internal/server/http_error_propagation_test.go b/internal/server/http_error_propagation_test.go new file mode 100644 index 00000000..f849806c --- /dev/null +++ b/internal/server/http_error_propagation_test.go @@ -0,0 +1,447 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/githubnext/gh-aw-mcpg/internal/config" + "github.com/githubnext/gh-aw-mcpg/internal/launcher" +) + +// TestUnifiedServer_HTTPErrorPropagation tests that HTTP backend errors are +// properly propagated through the unified server to clients +func TestUnifiedServer_HTTPErrorPropagation(t *testing.T) { + tests := []struct { + name string + backendStatus int + backendBody map[string]interface{} + expectHasError bool + expectErrorMsg string // Expected substring in error message + }{ + { + name: "HTTP 500 from backend", + backendStatus: http.StatusInternalServerError, + backendBody: map[string]interface{}{ + "error": "Database connection failed", + }, + expectHasError: true, + expectErrorMsg: "500", + }, + { + name: "HTTP 503 from backend", + backendStatus: http.StatusServiceUnavailable, + backendBody: map[string]interface{}{ + "error": "Service temporarily unavailable", + }, + expectHasError: true, + expectErrorMsg: "503", + }, + { + name: "HTTP 401 from backend", + backendStatus: http.StatusUnauthorized, + backendBody: map[string]interface{}{ + "error": "Invalid credentials", + }, + expectHasError: true, + expectErrorMsg: "401", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + requestCount := 0 + + // Create mock HTTP backend + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + method, _ := reqBody["method"].(string) + + if method == "initialize" { + // Initialize succeeds + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-backend", + "version": "1.0.0", + }, + }, + }) + return + } else if method == "tools/list" && requestCount == 2 { + // First tools/list (during initialization) succeeds + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + }, + }, + }) + return + } + + // Subsequent requests return error + w.WriteHeader(tt.backendStatus) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tt.backendBody) + })) + defer backendServer.Close() + + // Create gateway config + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{ + "test-backend": { + Type: "http", + URL: backendServer.URL, + }, + }, + } + + // Create launcher and unified server + ctx := context.Background() + l := launcher.New(ctx, cfg) + us, err := NewUnified(ctx, cfg) + require.NoError(t, err, "Failed to create unified server") + defer us.Close() + + // Get connection + conn, err := launcher.GetOrLaunch(l, "test-backend") + require.NoError(t, err, "Failed to get connection") + + // Make request that triggers error + resp, err := conn.SendRequestWithServerID(ctx, "tools/call", + map[string]interface{}{ + "name": "test_tool", + "arguments": map[string]interface{}{}, + }, "test-backend") + + // With the fix, we should get a response with an error field, not a Go error + require.NoError(t, err, "Should not return Go error") + require.NotNil(t, resp, "Response should not be nil") + + if tt.expectHasError { + require.NotNil(t, resp.Error, "Response should contain error field") + assert.Contains(t, resp.Error.Message, tt.expectErrorMsg, + "Error message should contain expected substring") + } + }) + } +} + +// TestUnifiedServer_ChecksErrorBeforeUnmarshal tests that the unified server +// checks for error field before attempting to unmarshal result +func TestUnifiedServer_ChecksErrorBeforeUnmarshal(t *testing.T) { + requestCount := 0 + + // Create mock HTTP backend that returns error-only responses + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + method, _ := reqBody["method"].(string) + + if method == "initialize" { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-backend", + "version": "1.0.0", + }, + }, + }) + return + } else if method == "tools/list" && requestCount == 2 { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + }, + }, + }) + return + } + + // Return error response without result field + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "error": map[string]interface{}{ + "code": -32600, + "message": "Invalid request", + }, + }) + })) + defer backendServer.Close() + + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{ + "test-backend": { + Type: "http", + URL: backendServer.URL, + }, + }, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + require.NoError(t, err, "Failed to create unified server") + defer us.Close() + + conn, err := launcher.GetOrLaunch(us.launcher, "test-backend") + require.NoError(t, err, "Failed to get connection") + + // Make request that returns error-only response + resp, err := conn.SendRequestWithServerID(ctx, "tools/call", + map[string]interface{}{ + "name": "test_tool", + "arguments": map[string]interface{}{}, + }, "test-backend") + + // Should get response without crashing + require.NoError(t, err, "Should not return Go error") + require.NotNil(t, resp, "Response should not be nil") + require.NotNil(t, resp.Error, "Response should contain error field") +} + +// TestProxyToServer_HTTPErrorPropagation tests that the legacy proxy function +// properly handles HTTP backend errors +func TestProxyToServer_HTTPErrorPropagation(t *testing.T) { + requestCount := 0 + + // Create mock HTTP backend + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + method, _ := reqBody["method"].(string) + + if method == "initialize" { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-backend", + "version": "1.0.0", + }, + }, + }) + return + } + + // Return HTTP 503 error + w.WriteHeader(http.StatusServiceUnavailable) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": "Service unavailable", + }) + })) + defer backendServer.Close() + + // Create gateway config + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{ + "test-backend": { + Type: "http", + URL: backendServer.URL, + }, + }, + } + + // Create launcher + ctx := context.Background() + l := launcher.New(ctx, cfg) + + // Create legacy server + s := New(ctx, l, "routed") + + // Create test request + jsonReq := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + } + reqBody, _ := json.Marshal(jsonReq) + + // Create HTTP request + req := httptest.NewRequest("POST", "/mcp/test-backend", bytes.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + // Handle request + s.mux.ServeHTTP(w, req) + + // Check response + resp := w.Result() + defer resp.Body.Close() + + var jsonResp map[string]interface{} + err := json.NewDecoder(resp.Body).Decode(&jsonResp) + require.NoError(t, err, "Failed to decode response") + + // With the fix, the response should contain error details, not generic "Internal error" + if errorField, ok := jsonResp["error"].(map[string]interface{}); ok { + message, _ := errorField["message"].(string) + // The error message should contain HTTP status information + assert.Contains(t, message, "503", "Error message should mention HTTP 503 status") + } else { + t.Error("Response should contain error field") + } +} + +// TestHTTPBackendError_DataPreservation tests that error data is preserved +// when propagating errors from HTTP backends +func TestHTTPBackendError_DataPreservation(t *testing.T) { + originalErrorData := map[string]interface{}{ + "type": "rate_limit_error", + "message": "Rate limit exceeded", + "retry_after": 60, + "limit": 100, + "remaining": 0, + } + + requestCount := 0 + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + method, _ := reqBody["method"].(string) + + if method == "initialize" { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "test-backend", + "version": "1.0.0", + }, + }, + }) + return + } else if method == "tools/list" && requestCount == 2 { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": reqBody["id"], + "result": map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + }, + }, + }) + return + } + + // Return detailed error + w.WriteHeader(http.StatusTooManyRequests) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(originalErrorData) + })) + defer backendServer.Close() + + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{ + "test-backend": { + Type: "http", + URL: backendServer.URL, + }, + }, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + require.NoError(t, err, "Failed to create unified server") + defer us.Close() + + conn, err := launcher.GetOrLaunch(us.launcher, "test-backend") + require.NoError(t, err, "Failed to get connection") + + resp, err := conn.SendRequestWithServerID(ctx, "tools/call", + map[string]interface{}{ + "name": "test_tool", + "arguments": map[string]interface{}{}, + }, "test-backend") + + require.NoError(t, err, "Should not return Go error") + require.NotNil(t, resp, "Response should not be nil") + require.NotNil(t, resp.Error, "Response should contain error field") + + // Verify error data is preserved + require.NotNil(t, resp.Error.Data, "Error data should not be nil") + + var errorData map[string]interface{} + err = json.Unmarshal(resp.Error.Data, &errorData) + require.NoError(t, err, "Error data should be valid JSON") + + // Verify all original error fields are preserved + assert.Equal(t, originalErrorData["type"], errorData["type"], "Error type should be preserved") + assert.Equal(t, originalErrorData["message"], errorData["message"], "Error message should be preserved") + assert.Equal(t, float64(60), errorData["retry_after"], "Retry-after should be preserved") + assert.Equal(t, float64(100), errorData["limit"], "Limit should be preserved") + assert.Equal(t, float64(0), errorData["remaining"], "Remaining should be preserved") +} diff --git a/internal/server/unified.go b/internal/server/unified.go index a018a8f0..1cc8eea8 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -189,6 +189,11 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { return fmt.Errorf("failed to list tools: %w", err) } + // Check if the backend returned an error + if result.Error != nil { + return fmt.Errorf("backend error listing tools: code=%d, message=%s", result.Error.Code, result.Error.Message) + } + // Parse the result var listResult struct { Tools []struct { @@ -441,6 +446,11 @@ func (g *guardBackendCaller) CallTool(ctx context.Context, toolName string, args return nil, err } + // Check if the backend returned an error + if response.Error != nil { + return nil, fmt.Errorf("backend error: code=%d, message=%s", response.Error.Code, response.Error.Message) + } + // Parse the result var result interface{} if err := json.Unmarshal(response.Result, &result); err != nil { @@ -516,6 +526,11 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName return &sdk.CallToolResult{IsError: true}, nil, err } + // Check if the backend returned an error + if response.Error != nil { + return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("backend error: code=%d, message=%s", response.Error.Code, response.Error.Message) + } + // Parse the backend result var backendResult interface{} if err := json.Unmarshal(response.Result, &backendResult); err != nil { diff --git a/test/integration/http_error_test.go b/test/integration/http_error_test.go index dfee8229..0b751b3b 100644 --- a/test/integration/http_error_test.go +++ b/test/integration/http_error_test.go @@ -377,12 +377,17 @@ func TestHTTPError_RequestFailure(t *testing.T) { t.Log("✓ Connection established successfully") // Try to make a request that should fail - _, err = conn.SendRequest("tools/list", nil) - require.NotNil(t, err, "Expected request to fail, but it succeeded") + resp, err := conn.SendRequest("tools/list", nil) + require.NoError(t, err, "Unexpected error making request") + require.NotNil(t, resp, "Expected response") - // Verify error is properly propagated - if err != nil { - t.Logf("✓ Request failure error properly propagated: %v", err) + // Verify the response contains an error + require.NotNil(t, resp.Error, "Expected response to contain an error field") + require.Contains(t, resp.Error.Message, "503", "Expected error to mention HTTP 503 status") + + // Verify error is properly propagated in the response + if resp.Error != nil { + t.Logf("✓ Request failure error properly propagated in response: code=%d, message=%s", resp.Error.Code, resp.Error.Message) } }