Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 65 additions & 70 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,66 @@ type Connection struct {
httpTransportType HTTPTransportType // Type of HTTP transport in use
}

// newMCPClient creates a new MCP SDK client with standard implementation details
func newMCPClient() *sdk.Client {
return sdk.NewClient(&sdk.Implementation{
Name: "awmg",
Version: "1.0.0",
}, nil)
}

// newHTTPConnection creates a new HTTP Connection struct with common fields
func newHTTPConnection(ctx context.Context, cancel context.CancelFunc, client *sdk.Client, session *sdk.ClientSession, url string, headers map[string]string, httpClient *http.Client, transportType HTTPTransportType) *Connection {
return &Connection{
client: client,
session: session,
ctx: ctx,
cancel: cancel,
isHTTP: true,
httpURL: url,
headers: headers,
httpClient: httpClient,
httpTransportType: transportType,
}
}

// createJSONRPCRequest creates a JSON-RPC 2.0 request map
func createJSONRPCRequest(requestID uint64, method string, params interface{}) map[string]interface{} {
return map[string]interface{}{
"jsonrpc": "2.0",
"id": requestID,
"method": method,
"params": params,
}
}

// setupHTTPRequest creates and configures an HTTP request with standard headers
func setupHTTPRequest(ctx context.Context, url string, requestBody []byte, headers map[string]string) (*http.Request, error) {
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(requestBody))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}

// Set standard headers
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "application/json, text/event-stream")

// Add configured headers (e.g., Authorization)
for key, value := range headers {
httpReq.Header.Set(key, value)
}

return httpReq, nil
}

// NewConnection creates a new MCP connection using the official SDK
func NewConnection(ctx context.Context, command string, args []string, env map[string]string) (*Connection, error) {
logger.LogInfo("backend", "Creating new MCP backend connection, command=%s, args=%v", command, args)
logConn.Printf("Creating new MCP connection: command=%s, args=%v", command, args)
ctx, cancel := context.WithCancel(ctx)

// Create MCP client
client := sdk.NewClient(&sdk.Implementation{
Name: "awmg",
Version: "1.0.0",
}, nil)
client := newMCPClient()

// Expand Docker -e flags that reference environment variables
// Docker's `-e VAR_NAME` expects VAR_NAME to be in the environment
Expand Down Expand Up @@ -239,10 +288,7 @@ func NewHTTPConnection(ctx context.Context, url string, headers map[string]strin
// tryStreamableHTTPTransport attempts to connect using the streamable HTTP transport (2025-03-26 spec)
func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) {
// Create MCP client
client := sdk.NewClient(&sdk.Implementation{
Name: "awmg",
Version: "1.0.0",
}, nil)
client := newMCPClient()

// Create streamable HTTP transport
transport := &sdk.StreamableClientTransport{
Expand All @@ -261,17 +307,7 @@ func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc,
return nil, fmt.Errorf("streamable HTTP transport connect failed: %w", err)
}

conn := &Connection{
client: client,
session: session,
ctx: ctx,
cancel: cancel,
isHTTP: true,
httpURL: url,
headers: headers,
httpClient: httpClient,
httpTransportType: HTTPTransportStreamable,
}
conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, HTTPTransportStreamable)

logger.LogInfo("backend", "Streamable HTTP transport connected successfully")
logConn.Printf("Connected with streamable HTTP transport")
Expand All @@ -281,10 +317,7 @@ func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc,
// trySSETransport attempts to connect using the SSE transport (2024-11-05 spec)
func trySSETransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) {
// Create MCP client
client := sdk.NewClient(&sdk.Implementation{
Name: "awmg",
Version: "1.0.0",
}, nil)
client := newMCPClient()

// Create SSE transport
transport := &sdk.SSEClientTransport{
Expand All @@ -302,17 +335,7 @@ func trySSETransport(ctx context.Context, cancel context.CancelFunc, url string,
return nil, fmt.Errorf("SSE transport connect failed: %w", err)
}

conn := &Connection{
client: client,
session: session,
ctx: ctx,
cancel: cancel,
isHTTP: true,
httpURL: url,
headers: headers,
httpClient: httpClient,
httpTransportType: HTTPTransportSSE,
}
conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, HTTPTransportSSE)

logger.LogInfo("backend", "SSE transport connected successfully")
logConn.Printf("Connected with SSE transport")
Expand Down Expand Up @@ -455,12 +478,7 @@ func (c *Connection) initializeHTTPSession() (string, error) {
},
}

request := map[string]interface{}{
"jsonrpc": "2.0",
"id": requestID,
"method": "initialize",
"params": initParams,
}
request := createJSONRPCRequest(requestID, "initialize", initParams)

requestBody, err := json.Marshal(request)
if err != nil {
Expand All @@ -469,27 +487,18 @@ func (c *Connection) initializeHTTPSession() (string, error) {

logConn.Printf("Sending initialize request: %s", string(requestBody))

// Create HTTP request
httpReq, err := http.NewRequest("POST", c.httpURL, bytes.NewReader(requestBody))
// Create HTTP request with standard headers
httpReq, err := setupHTTPRequest(context.Background(), c.httpURL, requestBody, c.headers)
if err != nil {
return "", fmt.Errorf("failed to create HTTP request: %w", err)
return "", err
}

// Set headers
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "application/json, text/event-stream")

// Generate a temporary session ID for the initialize request
// Some backends may require this header even during initialization
tempSessionID := fmt.Sprintf("awmg-init-%d", requestID)
httpReq.Header.Set("Mcp-Session-Id", tempSessionID)
logConn.Printf("Sending initialize with temporary session ID: %s", tempSessionID)

// Add configured headers (e.g., Authorization)
for key, value := range c.headers {
httpReq.Header.Set(key, value)
}

logConn.Printf("Sending initialize to %s", c.httpURL)

// Send request
Expand Down Expand Up @@ -568,28 +577,19 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params
requestID := atomic.AddUint64(&requestIDCounter, 1)

// Create JSON-RPC request
request := map[string]interface{}{
"jsonrpc": "2.0",
"id": requestID,
"method": method,
"params": params,
}
request := createJSONRPCRequest(requestID, method, params)

requestBody, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}

// Create HTTP request
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.httpURL, bytes.NewReader(requestBody))
// Create HTTP request with standard headers
httpReq, err := setupHTTPRequest(ctx, c.httpURL, requestBody, c.headers)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
return nil, err
}

// Set headers
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "application/json, text/event-stream")

// Add Mcp-Session-Id header with priority:
// 1) Context session ID (if explicitly provided for this request)
// 2) Stored httpSessionID from initialization
Expand All @@ -608,11 +608,6 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params
logConn.Printf("No session ID available (backend may not require session management)")
}

// Add configured headers
for key, value := range c.headers {
httpReq.Header.Set(key, value)
}

logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID)

// Send request using the reusable HTTP client
Expand Down
139 changes: 139 additions & 0 deletions internal/mcp/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,142 @@ data: {"jsonrpc":"2.0","id":` + idStr + `,"result":{"tools":[]}}

t.Logf("Successfully parsed SSE-formatted response from server")
}

// TestNewMCPClient tests the newMCPClient helper function
func TestNewMCPClient(t *testing.T) {
client := newMCPClient()
require.NotNil(t, client, "newMCPClient should return a non-nil client")
}

// TestCreateJSONRPCRequest tests the createJSONRPCRequest helper function
func TestCreateJSONRPCRequest(t *testing.T) {
tests := []struct {
name string
requestID uint64
method string
params interface{}
}{
{
name: "simple request with nil params",
requestID: 1,
method: "initialize",
params: nil,
},
{
name: "request with map params",
requestID: 42,
method: "tools/list",
params: map[string]interface{}{"filter": "test"},
},
{
name: "request with struct params",
requestID: 100,
method: "tools/call",
params: struct{ Name string }{Name: "test-tool"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
request := createJSONRPCRequest(tt.requestID, tt.method, tt.params)

assert.Equal(t, "2.0", request["jsonrpc"], "jsonrpc version should be 2.0")
assert.Equal(t, tt.requestID, request["id"], "id should match requestID")
assert.Equal(t, tt.method, request["method"], "method should match")
assert.Equal(t, tt.params, request["params"], "params should match")
})
}
}

// TestSetupHTTPRequest tests the setupHTTPRequest helper function
func TestSetupHTTPRequest(t *testing.T) {
tests := []struct {
name string
url string
requestBody []byte
headers map[string]string
expectError bool
expectedMethod string
}{
{
name: "basic request with no custom headers",
url: "http://example.com/mcp",
requestBody: []byte(`{"test": "data"}`),
headers: map[string]string{},
expectError: false,
expectedMethod: "POST",
},
{
name: "request with custom headers",
url: "http://example.com/mcp",
requestBody: []byte(`{"test": "data"}`),
headers: map[string]string{
"Authorization": "Bearer token123",
"X-Custom": "value",
},
expectError: false,
expectedMethod: "POST",
},
{
name: "request with empty body",
url: "http://example.com/mcp",
requestBody: []byte{},
headers: map[string]string{},
expectError: false,
expectedMethod: "POST",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
req, err := setupHTTPRequest(ctx, tt.url, tt.requestBody, tt.headers)

if tt.expectError {
assert.Error(t, err, "Expected error")
return
}

require.NoError(t, err, "setupHTTPRequest should not return error")
require.NotNil(t, req, "Request should not be nil")

// Verify method
assert.Equal(t, tt.expectedMethod, req.Method, "Method should be POST")

// Verify URL
assert.Equal(t, tt.url, req.URL.String(), "URL should match")

// Verify standard headers
assert.Equal(t, "application/json", req.Header.Get("Content-Type"), "Content-Type should be application/json")
assert.Equal(t, "application/json, text/event-stream", req.Header.Get("Accept"), "Accept header should be set")

// Verify custom headers
for key, value := range tt.headers {
assert.Equal(t, value, req.Header.Get(key), "Custom header %s should match", key)
}
})
}
}

// TestNewHTTPConnection tests the newHTTPConnection helper function
func TestNewHTTPConnection(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

client := newMCPClient()
url := "http://example.com/mcp"
headers := map[string]string{"Authorization": "test"}
httpClient := &http.Client{}

conn := newHTTPConnection(ctx, cancel, client, nil, url, headers, httpClient, HTTPTransportStreamable)

require.NotNil(t, conn, "Connection should not be nil")
assert.Equal(t, client, conn.client, "Client should match")
assert.Equal(t, ctx, conn.ctx, "Context should match")
assert.NotNil(t, conn.cancel, "Cancel function should not be nil")
assert.True(t, conn.isHTTP, "isHTTP should be true")
assert.Equal(t, url, conn.httpURL, "URL should match")
assert.Equal(t, headers, conn.headers, "Headers should match")
assert.Equal(t, httpClient, conn.httpClient, "HTTP client should match")
assert.Equal(t, HTTPTransportStreamable, conn.httpTransportType, "Transport type should match")
}