diff --git a/pkg/tools/mcp/oauth.go b/pkg/tools/mcp/oauth.go index ce3b5976b..0c0041c72 100644 --- a/pkg/tools/mcp/oauth.go +++ b/pkg/tools/mcp/oauth.go @@ -336,6 +336,16 @@ type oauthTransport struct { } func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + var bodyBytes []byte + if req.Body != nil && req.Body != http.NoBody { + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + req.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) + } + reqClone := req.Clone(req.Context()) if token, err := t.tokenStore.GetToken(t.baseURL); err == nil && !token.IsExpired() { @@ -357,6 +367,10 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("OAuth flow failed: %w", err) } + if len(bodyBytes) > 0 { + req.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) + } + return t.RoundTrip(req) } } diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go index f7780639c..ce03f00c5 100644 --- a/pkg/tools/mcp/remote.go +++ b/pkg/tools/mcp/remote.go @@ -6,9 +6,7 @@ import ( "iter" "log/slog" "net/http" - "strings" "sync" - "time" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -86,76 +84,47 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReque // Create HTTP client with OAuth support httpClient := c.createHTTPClient() - // Attempt MCP initialization with retry logic for OAuth-related failures. - // When a server requires OAuth, the first connection attempt may fail with a "broken session" - // error because OAuth flow (even successful flow) interrupts the MCP handshake. Let's retry once after OAuth completes. - // Example of such MCP Server that broke the session with OAuth flow: https://mcp.prisma.io/mcp - const maxAttempts = 2 - var lastErr error + var transport mcp.Transport - for attempt := 1; attempt <= maxAttempts; attempt++ { - if attempt > 1 { - slog.Debug("Retrying MCP initialization after OAuth flow", "attempt", attempt) + switch c.transportType { + case "sse": + transport = &mcp.SSEClientTransport{ + Endpoint: c.url, + HTTPClient: httpClient, } - - var transport mcp.Transport - - switch c.transportType { - case "sse": - transport = &mcp.SSEClientTransport{ - Endpoint: c.url, - HTTPClient: httpClient, - } - case "streamable", "streamable-http": - transport = &mcp.StreamableClientTransport{ - Endpoint: c.url, - HTTPClient: httpClient, - } - default: - return nil, fmt.Errorf("unsupported transport type: %s", c.transportType) - } - - // Create an MCP client with elicitation support - impl := &mcp.Implementation{ - Name: "cagent", - Version: "1.0.0", + case "streamable", "streamable-http": + transport = &mcp.StreamableClientTransport{ + Endpoint: c.url, + HTTPClient: httpClient, } + default: + return nil, fmt.Errorf("unsupported transport type: %s", c.transportType) + } - opts := &mcp.ClientOptions{ - ElicitationHandler: c.handleElicitationRequest, - } + // Create an MCP client with elicitation support + impl := &mcp.Implementation{ + Name: "cagent", + Version: "1.0.0", + } - client := mcp.NewClient(impl, opts) - - // Connect to the MCP server - session, err := client.Connect(ctx, transport, nil) - if err != nil { - lastErr = err - - // Check if this is a "broken session" error that might be OAuth-related - if attempt < maxAttempts && isBrokenSessionError(err) { - slog.Debug("MCP connection failed with broken session error, retrying after OAuth", "error", err) - // Brief pause before retry to allow OAuth state to settle - select { - case <-ctx.Done(): - return nil, fmt.Errorf("failed to connect to MCP server: %w", ctx.Err()) - case <-time.After(100 * time.Millisecond): - } - continue - } - - return nil, fmt.Errorf("failed to connect to MCP server: %w", err) - } + opts := &mcp.ClientOptions{ + ElicitationHandler: c.handleElicitationRequest, + } - c.mu.Lock() - c.session = session - c.mu.Unlock() + client := mcp.NewClient(impl, opts) - slog.Debug("Remote MCP client connected successfully", "attempt", attempt) - return session.InitializeResult(), nil + // Connect to the MCP server + session, err := client.Connect(ctx, transport, nil) + if err != nil { + return nil, fmt.Errorf("failed to connect to MCP server: %w", err) } - return nil, fmt.Errorf("failed to connect to MCP server after %d attempts: %w", maxAttempts, lastErr) + c.mu.Lock() + c.session = session + c.mu.Unlock() + + slog.Debug("Remote MCP client connected successfully") + return session.InitializeResult(), nil } // createHTTPClient creates an HTTP client with OAuth support @@ -225,15 +194,3 @@ func (c *remoteMCPClient) requestUserConsent(ctx context.Context) (bool, error) return result.Action == "accept", nil } - -// isBrokenSessionError checks if an error is a "broken session" error from the MCP SDK -// This error typically occurs when OAuth interrupts the MCP session handshake -func isBrokenSessionError(err error) bool { - if err == nil { - return false - } - errMsg := strings.ToLower(err.Error()) - // The error message comes from mcp-go/mcp/streamable.go:1211 - // "broken session: 400 Bad Request" - return strings.Contains(errMsg, "broken session") -}