diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index 6ab009b..194f9f0 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -37,6 +37,17 @@ func parseSSEResponse(body []byte) ([]byte, error) { return nil, fmt.Errorf("no data field found in SSE response") } +// isConnectionError checks if an error indicates a connection failure +func isConnectionError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + return strings.Contains(errStr, "connection refused") || + strings.Contains(errStr, "no such host") || + strings.Contains(errStr, "network is unreachable") +} + // ContextKey for session ID type ContextKey string @@ -584,9 +595,7 @@ func (c *Connection) initializeHTTPSession() (string, error) { httpResp, err := c.httpClient.Do(httpReq) if err != nil { // Check if it's a connection error (cannot connect at all) - if strings.Contains(err.Error(), "connection refused") || - strings.Contains(err.Error(), "no such host") || - strings.Contains(err.Error(), "network is unreachable") { + if isConnectionError(err) { return "", fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err) } return "", fmt.Errorf("failed to send initialize request to %s: %w", c.httpURL, err) @@ -698,9 +707,7 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params httpResp, err := c.httpClient.Do(httpReq) if err != nil { // Check if it's a connection error (cannot connect at all) - if strings.Contains(err.Error(), "connection refused") || - strings.Contains(err.Error(), "no such host") || - strings.Contains(err.Error(), "network is unreachable") { + if isConnectionError(err) { return nil, fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err) } return nil, fmt.Errorf("failed to send HTTP request to %s: %w", c.httpURL, err) diff --git a/internal/mcp/connection_test.go b/internal/mcp/connection_test.go index 9ee392c..5e61618 100644 --- a/internal/mcp/connection_test.go +++ b/internal/mcp/connection_test.go @@ -703,3 +703,60 @@ func TestNewHTTPConnection(t *testing.T) { assert.Equal(t, httpClient, conn.httpClient, "HTTP client should match") assert.Equal(t, HTTPTransportStreamable, conn.httpTransportType, "Transport type should match") } + +// TestIsConnectionError tests the isConnectionError helper function +func TestIsConnectionError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "connection refused error", + err: fmt.Errorf("dial tcp: connection refused"), + expected: true, + }, + { + name: "no such host error", + err: fmt.Errorf("dial tcp: lookup example.com: no such host"), + expected: true, + }, + { + name: "network is unreachable error", + err: fmt.Errorf("dial tcp: connect: network is unreachable"), + expected: true, + }, + { + name: "connection refused in wrapped error", + err: fmt.Errorf("failed to connect: %w", fmt.Errorf("connection refused")), + expected: true, + }, + { + name: "other error", + err: fmt.Errorf("some other error"), + expected: false, + }, + { + name: "timeout error", + err: fmt.Errorf("context deadline exceeded"), + expected: false, + }, + { + name: "EOF error", + err: fmt.Errorf("unexpected EOF"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isConnectionError(tt.err) + assert.Equal(t, tt.expected, result, "isConnectionError(%v) should return %v", tt.err, tt.expected) + }) + } +}