Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 225 additions & 46 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,31 @@ const SessionIDContextKey ContextKey = "awmg-session-id"
// requestIDCounter is used to generate unique request IDs for HTTP requests
var requestIDCounter uint64

// HTTPTransportType represents the type of HTTP transport being used
type HTTPTransportType string

const (
// HTTPTransportStreamable uses the streamable HTTP transport (2025-03-26 spec)
HTTPTransportStreamable HTTPTransportType = "streamable"
// HTTPTransportSSE uses the SSE transport (2024-11-05 spec)
HTTPTransportSSE HTTPTransportType = "sse"
// HTTPTransportPlainJSON uses plain JSON-RPC 2.0 over HTTP POST (non-standard)
HTTPTransportPlainJSON HTTPTransportType = "plain-json"
)

// Connection represents a connection to an MCP server using the official SDK
type Connection struct {
client *sdk.Client
session *sdk.ClientSession
ctx context.Context
cancel context.CancelFunc
// HTTP-specific fields
isHTTP bool
httpURL string
headers map[string]string
httpClient *http.Client
httpSessionID string // Session ID returned by the HTTP backend
isHTTP bool
httpURL string
headers map[string]string
httpClient *http.Client
httpSessionID string // Session ID returned by the HTTP backend
httpTransportType HTTPTransportType // Type of HTTP transport in use
}

// NewConnection creates a new MCP connection using the official SDK
Expand Down Expand Up @@ -126,20 +139,20 @@ func NewConnection(ctx context.Context, command string, args []string, env map[s
return conn, nil
}

// NewHTTPConnection creates a new HTTP-based MCP connection
// NewHTTPConnection creates a new HTTP-based MCP connection with transport fallback
// For HTTP servers that are already running, we connect and initialize a session
//
// NOTE: This currently implements the MCP HTTP protocol manually instead of using
// the SDK's SSEClientTransport. This is because:
// 1. SSEClientTransport requires Server-Sent Events (SSE) format
// 2. Some HTTP MCP servers (like safeinputs) use plain JSON-RPC over HTTP POST
// 3. The MCP spec allows both SSE and plain HTTP transports
// This function implements a fallback strategy for HTTP transports:
// 1. If custom headers are provided, skip SDK transports (they don't support custom headers)
// and use plain JSON-RPC 2.0 over HTTP POST (for safeinputs compatibility)
// 2. Otherwise, try standard transports:
// a. Streamable HTTP (2025-03-26 spec) using SDK's StreamableClientTransport
// b. SSE (2024-11-05 spec) using SDK's SSEClientTransport
// c. Plain JSON-RPC 2.0 over HTTP POST as final fallback
//
// TODO: Migrate to sdk.SSEClientTransport once we confirm all target HTTP backends
// support the SSE format, or extend the SDK to support both transport formats.
// This would eliminate manual JSON-RPC handling and improve maintainability.
// This ensures compatibility with all types of HTTP MCP servers.
func NewHTTPConnection(ctx context.Context, url string, headers map[string]string) (*Connection, error) {
logger.LogInfo("backend", "Creating HTTP MCP connection, url=%s", url)
logger.LogInfo("backend", "Creating HTTP MCP connection with transport fallback, url=%s", url)
logConn.Printf("Creating HTTP MCP connection: url=%s", url)
ctx, cancel := context.WithCancel(ctx)

Expand All @@ -153,29 +166,158 @@ func NewHTTPConnection(ctx context.Context, url string, headers map[string]strin
},
}

// If custom headers are provided, skip SDK transports as they don't support headers
// This is typical for backends like safeinputs that require authentication
if len(headers) > 0 {
logConn.Printf("Custom headers detected, using plain JSON-RPC transport for %s", url)
conn, err := tryPlainJSONTransport(ctx, cancel, url, headers, httpClient)
if err == nil {
logger.LogInfo("backend", "Successfully connected using plain JSON-RPC transport, url=%s", url)
log.Printf("Configured HTTP MCP server with plain JSON-RPC transport: %s", url)
return conn, nil
}
cancel()
logger.LogError("backend", "Plain JSON-RPC transport failed for url=%s, error=%v", url, err)
return nil, fmt.Errorf("failed to connect with plain JSON-RPC transport: %w", err)
}

// Try standard transports in order: streamable HTTP → SSE → plain JSON-RPC

// Try 1: Streamable HTTP (2025-03-26 spec)
logConn.Printf("Attempting streamable HTTP transport for %s", url)
conn, err := tryStreamableHTTPTransport(ctx, cancel, url, headers, httpClient)
if err == nil {
logger.LogInfo("backend", "Successfully connected using streamable HTTP transport, url=%s", url)
log.Printf("Configured HTTP MCP server with streamable transport: %s", url)
return conn, nil
}
logConn.Printf("Streamable HTTP failed: %v", err)

// Try 2: SSE (2024-11-05 spec)
logConn.Printf("Attempting SSE transport for %s", url)
conn, err = trySSETransport(ctx, cancel, url, headers, httpClient)
if err == nil {
logger.LogInfo("backend", "Successfully connected using SSE transport, url=%s", url)
log.Printf("Configured HTTP MCP server with SSE transport: %s", url)
return conn, nil
}
logConn.Printf("SSE transport failed: %v", err)

// Try 3: Plain JSON-RPC over HTTP (non-standard, for fallback)
logConn.Printf("Attempting plain JSON-RPC transport for %s", url)
conn, err = tryPlainJSONTransport(ctx, cancel, url, headers, httpClient)
if err == nil {
logger.LogInfo("backend", "Successfully connected using plain JSON-RPC transport, url=%s", url)
log.Printf("Configured HTTP MCP server with plain JSON-RPC transport: %s", url)
return conn, nil
}
logConn.Printf("Plain JSON-RPC transport failed: %v", err)

// All transports failed
cancel()
logger.LogError("backend", "All HTTP transports failed for url=%s", url)
return nil, fmt.Errorf("failed to connect using any HTTP transport (tried streamable, SSE, and plain JSON-RPC): last error: %w", err)
}

// tryStreamableHTTPTransport attempts to connect using the streamable HTTP transport (2025-03-26 spec)
func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) {
// Create MCP client
client := sdk.NewClient(&sdk.Implementation{
Name: "awmg",
Version: "1.0.0",
}, nil)

// Create streamable HTTP transport
transport := &sdk.StreamableClientTransport{
Endpoint: url,
HTTPClient: httpClient,
MaxRetries: 0, // Don't retry on failure - we'll try other transports
}

// Try to connect - this will fail if the server doesn't support streamable HTTP
session, err := client.Connect(ctx, transport, nil)
if err != nil {
return nil, fmt.Errorf("streamable HTTP transport connect failed: %w", err)
}

conn := &Connection{
client: client,
session: session,
ctx: ctx,
cancel: cancel,
isHTTP: true,
httpURL: url,
headers: headers,
httpClient: httpClient,
httpTransportType: HTTPTransportStreamable,
}

logger.LogInfo("backend", "Streamable HTTP transport connected successfully")
logConn.Printf("Connected with streamable HTTP transport")
return conn, nil
}

// trySSETransport attempts to connect using the SSE transport (2024-11-05 spec)
func trySSETransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) {
// Create MCP client
client := sdk.NewClient(&sdk.Implementation{
Name: "awmg",
Version: "1.0.0",
}, nil)

// Create SSE transport
transport := &sdk.SSEClientTransport{
Endpoint: url,
HTTPClient: httpClient,
}

// Try to connect - this will fail if the server doesn't support SSE
session, err := client.Connect(ctx, transport, nil)
if err != nil {
return nil, fmt.Errorf("SSE transport connect failed: %w", err)
}

conn := &Connection{
ctx: ctx,
cancel: cancel,
isHTTP: true,
httpURL: url,
headers: headers,
httpClient: httpClient,
client: client,
session: session,
ctx: ctx,
cancel: cancel,
isHTTP: true,
httpURL: url,
headers: headers,
httpClient: httpClient,
httpTransportType: HTTPTransportSSE,
}

logger.LogInfo("backend", "SSE transport connected successfully")
logConn.Printf("Connected with SSE transport")
return conn, nil
}

// tryPlainJSONTransport attempts to connect using plain JSON-RPC 2.0 over HTTP POST (non-standard)
// This is used for compatibility with servers like safeinputs that don't implement standard MCP HTTP transports
func tryPlainJSONTransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) {
conn := &Connection{
ctx: ctx,
cancel: cancel,
isHTTP: true,
httpURL: url,
headers: headers,
httpClient: httpClient,
httpTransportType: HTTPTransportPlainJSON,
}

// Send initialize request to establish a session with the HTTP backend
// This is critical for backends that require session management
logConn.Printf("Sending initialize request to HTTP backend: url=%s", url)
logConn.Printf("Sending initialize request via plain JSON-RPC to: %s", url)
sessionID, err := conn.initializeHTTPSession()
if err != nil {
cancel()
logger.LogError("backend", "Failed to initialize HTTP session, url=%s, error=%v", url, err)
return nil, fmt.Errorf("failed to initialize HTTP session: %w", err)
return nil, fmt.Errorf("plain JSON-RPC initialize failed: %w", err)
}

conn.httpSessionID = sessionID
logger.LogInfo("backend", "Successfully created HTTP MCP connection with session, url=%s, session=%s", url, sessionID)
logConn.Printf("HTTP connection created with session: url=%s, session=%s", url, sessionID)
log.Printf("Configured HTTP MCP server with session: %s (session: %s)", url, sessionID)
logger.LogInfo("backend", "Plain JSON-RPC transport connected successfully with session=%s", sessionID)
logConn.Printf("Connected with plain JSON-RPC transport, session=%s", sessionID)
return conn, nil
}

Expand Down Expand Up @@ -214,9 +356,22 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string,
var result *Response
var err error

// Handle HTTP connections by proxying the request
// Handle HTTP connections
if c.isHTTP {
result, err = c.sendHTTPRequest(ctx, method, params)
// For plain JSON-RPC transport, use manual HTTP requests
if c.httpTransportType == HTTPTransportPlainJSON {
result, err = c.sendHTTPRequest(ctx, method, params)
// Log the response from backend server
var responsePayload []byte
if result != nil {
responsePayload, _ = json.Marshal(result)
}
logger.LogRPCResponse(logger.RPCDirectionInbound, serverID, responsePayload, err)
return result, err
}

// For streamable and SSE transports, use SDK session methods
result, err = c.callSDKMethod(method, params)
// Log the response from backend server
var responsePayload []byte
if result != nil {
Expand All @@ -227,22 +382,7 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string,
}

// Handle stdio connections using SDK client
switch method {
case "tools/list":
result, err = c.listTools()
case "tools/call":
result, err = c.callTool(params)
case "resources/list":
result, err = c.listResources()
case "resources/read":
result, err = c.readResource(params)
case "prompts/list":
result, err = c.listPrompts()
case "prompts/get":
result, err = c.getPrompt(params)
default:
err = fmt.Errorf("unsupported method: %s", method)
}
result, err = c.callSDKMethod(method, params)

// Log the response from backend server
var responsePayload []byte
Expand All @@ -254,6 +394,27 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string,
return result, err
}

// callSDKMethod calls the appropriate SDK method based on the method name
// This centralizes the method dispatch logic used by both HTTP SDK transports and stdio
func (c *Connection) callSDKMethod(method string, params interface{}) (*Response, error) {
switch method {
case "tools/list":
return c.listTools()
case "tools/call":
return c.callTool(params)
case "resources/list":
return c.listResources()
case "resources/read":
return c.readResource(params)
case "prompts/list":
return c.listPrompts()
case "prompts/get":
return c.getPrompt(params)
default:
return nil, fmt.Errorf("unsupported method: %s", method)
}
}

// initializeHTTPSession sends an initialize request to the HTTP backend and captures the session ID
func (c *Connection) initializeHTTPSession() (string, error) {
// Generate unique request ID
Expand Down Expand Up @@ -432,6 +593,9 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params
}

func (c *Connection) listTools() (*Response, error) {
if c.session == nil {
return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport")
}
result, err := c.session.ListTools(c.ctx, &sdk.ListToolsParams{})
if err != nil {
return nil, err
Expand All @@ -450,6 +614,9 @@ func (c *Connection) listTools() (*Response, error) {
}

func (c *Connection) callTool(params interface{}) (*Response, error) {
if c.session == nil {
return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport")
}
var callParams CallToolParams
paramsJSON, _ := json.Marshal(params)
if err := json.Unmarshal(paramsJSON, &callParams); err != nil {
Expand Down Expand Up @@ -477,6 +644,9 @@ func (c *Connection) callTool(params interface{}) (*Response, error) {
}

func (c *Connection) listResources() (*Response, error) {
if c.session == nil {
return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport")
}
result, err := c.session.ListResources(c.ctx, &sdk.ListResourcesParams{})
if err != nil {
return nil, err
Expand All @@ -495,6 +665,9 @@ func (c *Connection) listResources() (*Response, error) {
}

func (c *Connection) readResource(params interface{}) (*Response, error) {
if c.session == nil {
return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport")
}
var readParams struct {
URI string `json:"uri"`
}
Expand Down Expand Up @@ -523,6 +696,9 @@ func (c *Connection) readResource(params interface{}) (*Response, error) {
}

func (c *Connection) listPrompts() (*Response, error) {
if c.session == nil {
return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport")
}
result, err := c.session.ListPrompts(c.ctx, &sdk.ListPromptsParams{})
if err != nil {
return nil, err
Expand All @@ -541,6 +717,9 @@ func (c *Connection) listPrompts() (*Response, error) {
}

func (c *Connection) getPrompt(params interface{}) (*Response, error) {
if c.session == nil {
return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport")
}
var getParams struct {
Name string `json:"name"`
Arguments map[string]string `json:"arguments"`
Expand Down