Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions internal/launcher/launcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,41 @@ package launcher

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/githubnext/gh-aw-mcpg/internal/config"
)

func TestHTTPConnection(t *testing.T) {
// Create a mock HTTP server that handles initialize
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"jsonrpc": "2.0",
"id": 1,
"result": map[string]interface{}{
"protocolVersion": "2024-11-05",
"capabilities": map[string]interface{}{},
"serverInfo": map[string]interface{}{
"name": "test-server",
"version": "1.0.0",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer mockServer.Close()

// Create test config with HTTP server
jsonConfig := `{
"mcpServers": {
"safeinputs": {
"type": "http",
"url": "http://host.docker.internal:3000/",
"url": "` + mockServer.URL + `",
"headers": {
"Authorization": "test-auth-secret"
}
Expand Down Expand Up @@ -53,8 +75,8 @@ func TestHTTPConnection(t *testing.T) {
t.Errorf("Expected type 'http', got '%s'", httpServer.Type)
}

if httpServer.URL != "http://host.docker.internal:3000/" {
t.Errorf("Expected URL 'http://host.docker.internal:3000/', got '%s'", httpServer.URL)
if httpServer.URL != mockServer.URL {
t.Errorf("Expected URL '%s', got '%s'", mockServer.URL, httpServer.URL)
}

if httpServer.Headers["Authorization"] != "test-auth-secret" {
Expand All @@ -75,8 +97,8 @@ func TestHTTPConnection(t *testing.T) {
t.Error("Connection should be HTTP")
}

if conn.GetHTTPURL() != "http://host.docker.internal:3000/" {
t.Errorf("Expected URL 'http://host.docker.internal:3000/', got '%s'", conn.GetHTTPURL())
if conn.GetHTTPURL() != mockServer.URL {
t.Errorf("Expected URL '%s', got '%s'", mockServer.URL, conn.GetHTTPURL())
}

if conn.GetHTTPHeaders()["Authorization"] != "test-auth-secret" {
Expand All @@ -89,12 +111,31 @@ func TestHTTPConnectionWithVariableExpansion(t *testing.T) {
os.Setenv("TEST_AUTH_TOKEN", "secret-token-value")
defer os.Unsetenv("TEST_AUTH_TOKEN")

// Create a mock HTTP server that handles initialize
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"jsonrpc": "2.0",
"id": 1,
"result": map[string]interface{}{
"protocolVersion": "2024-11-05",
"capabilities": map[string]interface{}{},
"serverInfo": map[string]interface{}{
"name": "test-server",
"version": "1.0.0",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer mockServer.Close()

// Create test config with variable expansion
jsonConfig := `{
"mcpServers": {
"safeinputs": {
"type": "http",
"url": "http://host.docker.internal:3000/",
"url": "` + mockServer.URL + `",
"headers": {
"Authorization": "${TEST_AUTH_TOKEN}"
}
Expand Down
151 changes: 140 additions & 11 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ type Connection struct {
ctx context.Context
cancel context.CancelFunc
// HTTP-specific fields
isHTTP bool
httpURL string
headers map[string]string
httpClient *http.Client
isHTTP bool
httpURL string
headers map[string]string
httpClient *http.Client
httpSessionID string // Session ID returned by the HTTP backend
}

// NewConnection creates a new MCP connection using the official SDK
Expand Down Expand Up @@ -126,7 +127,17 @@ func NewConnection(ctx context.Context, command string, args []string, env map[s
}

// NewHTTPConnection creates a new HTTP-based MCP connection
// For HTTP servers that are already running, we just store the connection info
// For HTTP servers that are already running, we connect and initialize a session
//
// NOTE: This currently implements the MCP HTTP protocol manually instead of using
// the SDK's SSEClientTransport. This is because:
// 1. SSEClientTransport requires Server-Sent Events (SSE) format
// 2. Some HTTP MCP servers (like safeinputs) use plain JSON-RPC over HTTP POST
// 3. The MCP spec allows both SSE and plain HTTP transports
//
// TODO: Migrate to sdk.SSEClientTransport once we confirm all target HTTP backends
// support the SSE format, or extend the SDK to support both transport formats.
// This would eliminate manual JSON-RPC handling and improve maintainability.
func NewHTTPConnection(ctx context.Context, url string, headers map[string]string) (*Connection, error) {
logger.LogInfo("backend", "Creating HTTP MCP connection, url=%s", url)
logConn.Printf("Creating HTTP MCP connection: url=%s", url)
Expand All @@ -151,9 +162,20 @@ func NewHTTPConnection(ctx context.Context, url string, headers map[string]strin
httpClient: httpClient,
}

logger.LogInfoMd("backend", "Successfully created HTTP MCP connection, url=%s", url)
logConn.Printf("HTTP connection created: url=%s", url)
log.Printf("Configured HTTP MCP server: %s", url)
// Send initialize request to establish a session with the HTTP backend
// This is critical for backends that require session management
logConn.Printf("Sending initialize request to HTTP backend: url=%s", url)
sessionID, err := conn.initializeHTTPSession()
if err != nil {
cancel()
logger.LogError("backend", "Failed to initialize HTTP session, url=%s, error=%v", url, err)
return nil, fmt.Errorf("failed to initialize HTTP session: %w", err)
}

conn.httpSessionID = sessionID
logger.LogInfo("backend", "Successfully created HTTP MCP connection with session, url=%s, session=%s", url, sessionID)
logConn.Printf("HTTP connection created with session: url=%s, session=%s", url, sessionID)
log.Printf("Configured HTTP MCP server with session: %s (session: %s)", url, sessionID)
return conn, nil
}

Expand Down Expand Up @@ -232,6 +254,101 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string,
return result, err
}

// initializeHTTPSession sends an initialize request to the HTTP backend and captures the session ID
func (c *Connection) initializeHTTPSession() (string, error) {
// Generate unique request ID
requestID := atomic.AddUint64(&requestIDCounter, 1)

// Create initialize request with MCP protocol parameters
initParams := map[string]interface{}{
"protocolVersion": "2024-11-05",
"capabilities": map[string]interface{}{},
"clientInfo": map[string]interface{}{
"name": "awmg",
"version": "1.0.0",
},
}

request := map[string]interface{}{
"jsonrpc": "2.0",
"id": requestID,
"method": "initialize",
"params": initParams,
}

requestBody, err := json.Marshal(request)
if err != nil {
return "", fmt.Errorf("failed to marshal initialize request: %w", err)
}

logConn.Printf("Sending initialize request: %s", string(requestBody))

// Create HTTP request
httpReq, err := http.NewRequest("POST", c.httpURL, bytes.NewReader(requestBody))
if err != nil {
return "", fmt.Errorf("failed to create HTTP request: %w", err)
}

// Set headers
httpReq.Header.Set("Content-Type", "application/json")

// Generate a temporary session ID for the initialize request
// Some backends may require this header even during initialization
tempSessionID := fmt.Sprintf("awmg-init-%d", requestID)
httpReq.Header.Set("Mcp-Session-Id", tempSessionID)
logConn.Printf("Sending initialize with temporary session ID: %s", tempSessionID)

// Add configured headers (e.g., Authorization)
for key, value := range c.headers {
httpReq.Header.Set(key, value)
}

logConn.Printf("Sending initialize to %s", c.httpURL)

// Send request
httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
return "", fmt.Errorf("failed to send initialize request: %w", err)
}
defer httpResp.Body.Close()

// Capture the Mcp-Session-Id from response headers
sessionID := httpResp.Header.Get("Mcp-Session-Id")
if sessionID != "" {
logConn.Printf("Captured Mcp-Session-Id from response: %s", sessionID)
} else {
// If no session ID in response, use the temporary one
// This handles backends that don't return a session ID
sessionID = tempSessionID
logConn.Printf("No Mcp-Session-Id in response, using temporary session ID: %s", sessionID)
}

// Read response body
responseBody, err := io.ReadAll(httpResp.Body)
if err != nil {
return "", fmt.Errorf("failed to read initialize response: %w", err)
}

logConn.Printf("Initialize response: status=%d, body_len=%d, session=%s", httpResp.StatusCode, len(responseBody), sessionID)

// Check for HTTP errors
if httpResp.StatusCode != http.StatusOK {
return "", fmt.Errorf("initialize failed: status=%d, body=%s", httpResp.StatusCode, string(responseBody))
}

// Parse JSON-RPC response to check for errors
var rpcResponse Response
if err := json.Unmarshal(responseBody, &rpcResponse); err != nil {
return "", fmt.Errorf("failed to parse initialize response: %w", err)
}

if rpcResponse.Error != nil {
return "", fmt.Errorf("initialize error: code=%d, message=%s", rpcResponse.Error.Code, rpcResponse.Error.Message)
}

return sessionID, nil
}

// sendHTTPRequest sends a JSON-RPC request to an HTTP MCP server
// The ctx parameter is used to extract session ID for the Mcp-Session-Id header
func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params interface{}) (*Response, error) {
Expand Down Expand Up @@ -260,10 +377,22 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params
// Set headers
httpReq.Header.Set("Content-Type", "application/json")

// Extract session ID from context and add Mcp-Session-Id header
if sessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && sessionID != "" {
// Add Mcp-Session-Id header with priority:
// 1) Context session ID (if explicitly provided for this request)
// 2) Stored httpSessionID from initialization
var sessionID string
if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" {
sessionID = ctxSessionID
logConn.Printf("Using session ID from context: %s", sessionID)
} else if c.httpSessionID != "" {
sessionID = c.httpSessionID
logConn.Printf("Using stored session ID from initialization: %s", sessionID)
}

if sessionID != "" {
httpReq.Header.Set("Mcp-Session-Id", sessionID)
logConn.Printf("Added Mcp-Session-Id header: %s", sessionID)
} else {
logConn.Printf("No session ID available (backend may not require session management)")
}

// Add configured headers
Expand Down
26 changes: 25 additions & 1 deletion internal/server/unified_http_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ func TestHTTPBackendInitializationWithSessionIDRequirement(t *testing.T) {
// TestHTTPBackend_SessionIDPropagation tests that session ID is propagated through tool calls
func TestHTTPBackend_SessionIDPropagation(t *testing.T) {
// Track session IDs received at different stages
initializeSessionID := ""
initSessionID := ""
toolCallSessionID := ""

Expand All @@ -186,6 +187,23 @@ func TestHTTPBackend_SessionIDPropagation(t *testing.T) {
json.NewDecoder(r.Body).Decode(&req)

switch req.Method {
case "initialize":
initializeSessionID = sessionID
// Return initialize response
response := map[string]interface{}{
"jsonrpc": "2.0",
"id": 1,
"result": map[string]interface{}{
"protocolVersion": "2024-11-05",
"capabilities": map[string]interface{}{},
"serverInfo": map[string]interface{}{
"name": "test-http-server",
"version": "1.0.0",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
case "tools/list":
initSessionID = sessionID
// Return tools list
Expand Down Expand Up @@ -260,8 +278,14 @@ func TestHTTPBackend_SessionIDPropagation(t *testing.T) {
}

// Verify session IDs were received
if initializeSessionID == "" {
t.Errorf("No session ID received during initialize")
} else {
t.Logf("Initialize session ID: %s", initializeSessionID)
}

if initSessionID == "" {
t.Errorf("No session ID received during initialization")
t.Errorf("No session ID received during tools/list (initialization)")
} else {
t.Logf("Init session ID: %s", initSessionID)
}
Expand Down
16 changes: 15 additions & 1 deletion test/integration/safeinputs_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ func TestSafeinputsHTTPBackend(t *testing.T) {
// Return appropriate response based on method
var response map[string]interface{}
switch rpcReq.Method {
case "initialize":
response = map[string]interface{}{
"jsonrpc": "2.0",
"id": rpcReq.ID,
"result": map[string]interface{}{
"protocolVersion": "2024-11-05",
"capabilities": map[string]interface{}{},
"serverInfo": map[string]interface{}{
"name": "safeinputs-server",
"version": "1.0.0",
},
},
}
case "tools/list":
response = map[string]interface{}{
"jsonrpc": "2.0",
Expand Down Expand Up @@ -231,7 +244,8 @@ func TestSafeinputsHTTPBackend(t *testing.T) {
t.Logf("Request #%d session ID: %s", i+1, headers["Mcp-Session-Id"])

// Verify the session ID follows the expected pattern for initialization
if strings.HasPrefix(headers["Mcp-Session-Id"], "gateway-init-") {
if strings.HasPrefix(headers["Mcp-Session-Id"], "awmg-init-") ||
strings.HasPrefix(headers["Mcp-Session-Id"], "gateway-init-") {
t.Logf("✓ Request #%d has correct gateway initialization session ID pattern", i+1)
}
}
Expand Down