Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pkg/tools/mcp/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,16 @@ type oauthTransport struct {
}

func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
var bodyBytes []byte
if req.Body != nil && req.Body != http.NoBody {
var err error
bodyBytes, err = io.ReadAll(req.Body)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @jba referred me here as I'm fixing this in the MCP client implementation.

I think you need to close the body after reading it, as RoundTripper says that RoundTrip must always close the body.

if err != nil {
return nil, err
}
req.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
}

reqClone := req.Clone(req.Context())

if token, err := t.tokenStore.GetToken(t.baseURL); err == nil && !token.IsExpired() {
Expand All @@ -357,6 +367,10 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("OAuth flow failed: %w", err)
}

if len(bodyBytes) > 0 {
req.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
}

return t.RoundTrip(req)
}
}
Expand Down
107 changes: 32 additions & 75 deletions pkg/tools/mcp/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import (
"iter"
"log/slog"
"net/http"
"strings"
"sync"
"time"

"github.com/modelcontextprotocol/go-sdk/mcp"

Expand Down Expand Up @@ -86,76 +84,47 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReque
// Create HTTP client with OAuth support
httpClient := c.createHTTPClient()

// Attempt MCP initialization with retry logic for OAuth-related failures.
// When a server requires OAuth, the first connection attempt may fail with a "broken session"
// error because OAuth flow (even successful flow) interrupts the MCP handshake. Let's retry once after OAuth completes.
// Example of such MCP Server that broke the session with OAuth flow: https://mcp.prisma.io/mcp
const maxAttempts = 2
var lastErr error
var transport mcp.Transport

for attempt := 1; attempt <= maxAttempts; attempt++ {
if attempt > 1 {
slog.Debug("Retrying MCP initialization after OAuth flow", "attempt", attempt)
switch c.transportType {
case "sse":
transport = &mcp.SSEClientTransport{
Endpoint: c.url,
HTTPClient: httpClient,
}

var transport mcp.Transport

switch c.transportType {
case "sse":
transport = &mcp.SSEClientTransport{
Endpoint: c.url,
HTTPClient: httpClient,
}
case "streamable", "streamable-http":
transport = &mcp.StreamableClientTransport{
Endpoint: c.url,
HTTPClient: httpClient,
}
default:
return nil, fmt.Errorf("unsupported transport type: %s", c.transportType)
}

// Create an MCP client with elicitation support
impl := &mcp.Implementation{
Name: "cagent",
Version: "1.0.0",
case "streamable", "streamable-http":
transport = &mcp.StreamableClientTransport{
Endpoint: c.url,
HTTPClient: httpClient,
}
default:
return nil, fmt.Errorf("unsupported transport type: %s", c.transportType)
}

opts := &mcp.ClientOptions{
ElicitationHandler: c.handleElicitationRequest,
}
// Create an MCP client with elicitation support
impl := &mcp.Implementation{
Name: "cagent",
Version: "1.0.0",
}

client := mcp.NewClient(impl, opts)

// Connect to the MCP server
session, err := client.Connect(ctx, transport, nil)
if err != nil {
lastErr = err

// Check if this is a "broken session" error that might be OAuth-related
if attempt < maxAttempts && isBrokenSessionError(err) {
slog.Debug("MCP connection failed with broken session error, retrying after OAuth", "error", err)
// Brief pause before retry to allow OAuth state to settle
select {
case <-ctx.Done():
return nil, fmt.Errorf("failed to connect to MCP server: %w", ctx.Err())
case <-time.After(100 * time.Millisecond):
}
continue
}

return nil, fmt.Errorf("failed to connect to MCP server: %w", err)
}
opts := &mcp.ClientOptions{
ElicitationHandler: c.handleElicitationRequest,
}

c.mu.Lock()
c.session = session
c.mu.Unlock()
client := mcp.NewClient(impl, opts)

slog.Debug("Remote MCP client connected successfully", "attempt", attempt)
return session.InitializeResult(), nil
// Connect to the MCP server
session, err := client.Connect(ctx, transport, nil)
if err != nil {
return nil, fmt.Errorf("failed to connect to MCP server: %w", err)
}

return nil, fmt.Errorf("failed to connect to MCP server after %d attempts: %w", maxAttempts, lastErr)
c.mu.Lock()
c.session = session
c.mu.Unlock()

slog.Debug("Remote MCP client connected successfully")
return session.InitializeResult(), nil
}

// createHTTPClient creates an HTTP client with OAuth support
Expand Down Expand Up @@ -225,15 +194,3 @@ func (c *remoteMCPClient) requestUserConsent(ctx context.Context) (bool, error)

return result.Action == "accept", nil
}

// isBrokenSessionError checks if an error is a "broken session" error from the MCP SDK
// This error typically occurs when OAuth interrupts the MCP session handshake
func isBrokenSessionError(err error) bool {
if err == nil {
return false
}
errMsg := strings.ToLower(err.Error())
// The error message comes from mcp-go/mcp/streamable.go:1211
// "broken session: 400 Bad Request"
return strings.Contains(errMsg, "broken session")
}
Loading