Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ func parseSSEResponse(body []byte) ([]byte, error) {
return nil, fmt.Errorf("no data field found in SSE response")
}

// isConnectionError checks if an error indicates a connection failure
func isConnectionError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
return strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "no such host") ||
strings.Contains(errStr, "network is unreachable")
}

// ContextKey for session ID
type ContextKey string

Expand Down Expand Up @@ -584,9 +595,7 @@ func (c *Connection) initializeHTTPSession() (string, error) {
httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
// Check if it's a connection error (cannot connect at all)
if strings.Contains(err.Error(), "connection refused") ||
strings.Contains(err.Error(), "no such host") ||
strings.Contains(err.Error(), "network is unreachable") {
if isConnectionError(err) {
return "", fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err)
}
return "", fmt.Errorf("failed to send initialize request to %s: %w", c.httpURL, err)
Expand Down Expand Up @@ -698,9 +707,7 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params
httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
// Check if it's a connection error (cannot connect at all)
if strings.Contains(err.Error(), "connection refused") ||
strings.Contains(err.Error(), "no such host") ||
strings.Contains(err.Error(), "network is unreachable") {
if isConnectionError(err) {
return nil, fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err)
}
return nil, fmt.Errorf("failed to send HTTP request to %s: %w", c.httpURL, err)
Expand Down
57 changes: 57 additions & 0 deletions internal/mcp/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,3 +703,60 @@ func TestNewHTTPConnection(t *testing.T) {
assert.Equal(t, httpClient, conn.httpClient, "HTTP client should match")
assert.Equal(t, HTTPTransportStreamable, conn.httpTransportType, "Transport type should match")
}

// TestIsConnectionError tests the isConnectionError helper function
func TestIsConnectionError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "connection refused error",
err: fmt.Errorf("dial tcp: connection refused"),
expected: true,
},
{
name: "no such host error",
err: fmt.Errorf("dial tcp: lookup example.com: no such host"),
expected: true,
},
{
name: "network is unreachable error",
err: fmt.Errorf("dial tcp: connect: network is unreachable"),
expected: true,
},
{
name: "connection refused in wrapped error",
err: fmt.Errorf("failed to connect: %w", fmt.Errorf("connection refused")),
expected: true,
},
{
name: "other error",
err: fmt.Errorf("some other error"),
expected: false,
},
{
name: "timeout error",
err: fmt.Errorf("context deadline exceeded"),
expected: false,
},
{
name: "EOF error",
err: fmt.Errorf("unexpected EOF"),
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isConnectionError(tt.err)
assert.Equal(t, tt.expected, result, "isConnectionError(%v) should return %v", tt.err, tt.expected)
})
}
}
Loading