diff --git a/internal/server/http_helpers.go b/internal/server/http_helpers.go new file mode 100644 index 00000000..74b5b7a6 --- /dev/null +++ b/internal/server/http_helpers.go @@ -0,0 +1,66 @@ +package server + +import ( + "bytes" + "context" + "io" + "log" + "net/http" + + "github.com/githubnext/gh-aw-mcpg/internal/auth" + "github.com/githubnext/gh-aw-mcpg/internal/logger" + "github.com/githubnext/gh-aw-mcpg/internal/mcp" +) + +// extractAndValidateSession extracts the session ID from the Authorization header +// and logs connection details. Returns empty string if validation fails. +func extractAndValidateSession(r *http.Request) string { + authHeader := r.Header.Get("Authorization") + sessionID := auth.ExtractSessionID(authHeader) + + if sessionID == "" { + logger.LogError("client", "Rejected MCP client connection: no Authorization header, remote=%s, path=%s", r.RemoteAddr, r.URL.Path) + log.Printf("[%s] %s %s - REJECTED: No Authorization header", r.RemoteAddr, r.Method, r.URL.Path) + return "" + } + + return sessionID +} + +// logHTTPRequestBody logs the request body for debugging purposes. +// It reads the body, logs it, and restores it so it can be read again. +// The backendID parameter is optional and can be empty for unified mode. +func logHTTPRequestBody(r *http.Request, sessionID, backendID string) { + if r.Method != "POST" || r.Body == nil { + return + } + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil || len(bodyBytes) == 0 { + return + } + + // Log with backend context if provided (routed mode) + if backendID != "" { + logger.LogDebug("client", "MCP client request body, backend=%s, body=%s", backendID, string(bodyBytes)) + } else { + logger.LogDebug("client", "MCP request body, session=%s, body=%s", sessionID, string(bodyBytes)) + } + log.Printf("Request body: %s", string(bodyBytes)) + + // Restore body for subsequent reads + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) +} + +// injectSessionContext stores the session ID and optional backend ID into the request context. +// If backendID is empty, only session ID is injected (unified mode). +// Returns the modified request with updated context. +func injectSessionContext(r *http.Request, sessionID, backendID string) *http.Request { + ctx := context.WithValue(r.Context(), SessionIDContextKey, sessionID) + + if backendID != "" { + ctx = context.WithValue(ctx, mcp.ContextKey("backend-id"), backendID) + } + + return r.WithContext(ctx) +} diff --git a/internal/server/http_helpers_test.go b/internal/server/http_helpers_test.go new file mode 100644 index 00000000..c424cb06 --- /dev/null +++ b/internal/server/http_helpers_test.go @@ -0,0 +1,225 @@ +package server + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/githubnext/gh-aw-mcpg/internal/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractAndValidateSession(t *testing.T) { + tests := []struct { + name string + authHeader string + expectedID string + shouldBeEmpty bool + }{ + { + name: "Valid plain API key", + authHeader: "test-session-123", + expectedID: "test-session-123", + shouldBeEmpty: false, + }, + { + name: "Valid Bearer token", + authHeader: "Bearer my-token-456", + expectedID: "my-token-456", + shouldBeEmpty: false, + }, + { + name: "Empty Authorization header", + authHeader: "", + expectedID: "", + shouldBeEmpty: true, + }, + { + name: "Whitespace only header", + authHeader: " ", + expectedID: " ", + shouldBeEmpty: false, + }, + { + name: "Long session ID", + authHeader: "very-long-session-id-with-many-characters-1234567890", + expectedID: "very-long-session-id-with-many-characters-1234567890", + shouldBeEmpty: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/mcp", nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + + sessionID := extractAndValidateSession(req) + + if tt.shouldBeEmpty { + assert.Empty(t, sessionID, "Expected empty session ID") + } else { + assert.Equal(t, tt.expectedID, sessionID, "Session ID mismatch") + } + }) + } +} + +func TestLogHTTPRequestBody(t *testing.T) { + tests := []struct { + name string + method string + body string + sessionID string + backendID string + shouldLog bool + }{ + { + name: "POST request with body and backend", + method: "POST", + body: `{"method":"initialize"}`, + sessionID: "session-123", + backendID: "backend-1", + shouldLog: true, + }, + { + name: "POST request with body without backend", + method: "POST", + body: `{"method":"tools/call"}`, + sessionID: "session-456", + backendID: "", + shouldLog: true, + }, + { + name: "GET request (no body logging)", + method: "GET", + body: "", + sessionID: "session-789", + backendID: "backend-2", + shouldLog: false, + }, + { + name: "POST request with empty body", + method: "POST", + body: "", + sessionID: "session-abc", + backendID: "backend-3", + shouldLog: false, + }, + { + name: "POST request with nil body", + method: "POST", + body: "", + sessionID: "session-def", + backendID: "", + shouldLog: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req *http.Request + if tt.body != "" { + req = httptest.NewRequest(tt.method, "/mcp", bytes.NewBufferString(tt.body)) + } else if tt.method == "POST" { + req = httptest.NewRequest(tt.method, "/mcp", nil) + } else { + req = httptest.NewRequest(tt.method, "/mcp", nil) + } + + // Call the function + logHTTPRequestBody(req, tt.sessionID, tt.backendID) + + // Verify body can still be read after logging + if tt.body != "" { + bodyBytes, err := io.ReadAll(req.Body) + require.NoError(t, err, "Should be able to read body after logging") + assert.Equal(t, tt.body, string(bodyBytes), "Body content should be preserved") + } + }) + } +} + +func TestInjectSessionContext(t *testing.T) { + tests := []struct { + name string + sessionID string + backendID string + expectBackendID bool + }{ + { + name: "Inject session and backend ID (routed mode)", + sessionID: "session-123", + backendID: "github", + expectBackendID: true, + }, + { + name: "Inject session ID only (unified mode)", + sessionID: "session-456", + backendID: "", + expectBackendID: false, + }, + { + name: "Long session ID with backend", + sessionID: "very-long-session-id-1234567890", + backendID: "slack", + expectBackendID: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/mcp", nil) + + // Inject context + modifiedReq := injectSessionContext(req, tt.sessionID, tt.backendID) + + // Verify session ID is in context + sessionIDFromCtx := modifiedReq.Context().Value(SessionIDContextKey) + require.NotNil(t, sessionIDFromCtx, "Session ID should be in context") + assert.Equal(t, tt.sessionID, sessionIDFromCtx, "Session ID mismatch") + + // Verify backend ID if expected + if tt.expectBackendID { + backendIDFromCtx := modifiedReq.Context().Value(mcp.ContextKey("backend-id")) + require.NotNil(t, backendIDFromCtx, "Backend ID should be in context") + assert.Equal(t, tt.backendID, backendIDFromCtx, "Backend ID mismatch") + } else { + backendIDFromCtx := modifiedReq.Context().Value(mcp.ContextKey("backend-id")) + assert.Nil(t, backendIDFromCtx, "Backend ID should not be in context for unified mode") + } + + // Verify original request is not modified + originalSessionID := req.Context().Value(SessionIDContextKey) + assert.Nil(t, originalSessionID, "Original request context should not be modified") + }) + } +} + +// testContextKey is a custom type for context keys to avoid collisions +type testContextKey string + +func TestInjectSessionContext_PreservesExistingContext(t *testing.T) { + // Create a request with existing context values + req := httptest.NewRequest("POST", "/mcp", nil) + ctx := context.WithValue(req.Context(), testContextKey("existing-key"), "existing-value") + req = req.WithContext(ctx) + + // Inject session context + modifiedReq := injectSessionContext(req, "session-123", "backend-1") + + // Verify both values are present + sessionID := modifiedReq.Context().Value(SessionIDContextKey) + assert.Equal(t, "session-123", sessionID, "Session ID should be present") + + backendID := modifiedReq.Context().Value(mcp.ContextKey("backend-id")) + assert.Equal(t, "backend-1", backendID, "Backend ID should be present") + + existingValue := modifiedReq.Context().Value(testContextKey("existing-key")) + assert.Equal(t, "existing-value", existingValue, "Existing context value should be preserved") +} diff --git a/internal/server/routed.go b/internal/server/routed.go index 39b4e682..010fd7c9 100644 --- a/internal/server/routed.go +++ b/internal/server/routed.go @@ -1,17 +1,13 @@ package server import ( - "bytes" "context" "fmt" - "io" "log" "net/http" "sync" - "github.com/githubnext/gh-aw-mcpg/internal/auth" "github.com/githubnext/gh-aw-mcpg/internal/logger" - "github.com/githubnext/gh-aw-mcpg/internal/mcp" sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -110,14 +106,9 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap // Create StreamableHTTP handler for this route routeHandler := sdk.NewStreamableHTTPHandler(func(r *http.Request) *sdk.Server { - // Extract session ID from Authorization header - authHeader := r.Header.Get("Authorization") - sessionID := auth.ExtractSessionID(authHeader) - - // Reject requests without Authorization header + // Extract and validate session ID from Authorization header + sessionID := extractAndValidateSession(r) if sessionID == "" { - logger.LogError("client", "Rejected MCP client connection: no Authorization header, remote=%s, path=%s", r.RemoteAddr, r.URL.Path) - log.Printf("[%s] %s %s - REJECTED: No Authorization header", r.RemoteAddr, r.Method, r.URL.Path) return nil } @@ -129,19 +120,10 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap log.Printf("Authorization (Session ID): %s", sessionID) // Log request body for debugging - if r.Method == "POST" && r.Body != nil { - bodyBytes, err := io.ReadAll(r.Body) - if err == nil && len(bodyBytes) > 0 { - logger.LogDebug("client", "MCP client request body, backend=%s, body=%s", backendID, string(bodyBytes)) - log.Printf("Request body: %s", string(bodyBytes)) - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - } - } + logHTTPRequestBody(r, sessionID, backendID) // Store session ID and backend ID in request context - ctx := context.WithValue(r.Context(), SessionIDContextKey, sessionID) - ctx = context.WithValue(ctx, mcp.ContextKey("backend-id"), backendID) - *r = *r.WithContext(ctx) + *r = *injectSessionContext(r, sessionID, backendID) log.Printf("✓ Injected session ID and backend ID into context") log.Printf("===================================\n") diff --git a/internal/server/transport.go b/internal/server/transport.go index 41f5734a..97c2b547 100644 --- a/internal/server/transport.go +++ b/internal/server/transport.go @@ -1,13 +1,10 @@ package server import ( - "bytes" "context" - "io" "log" "net/http" - "github.com/githubnext/gh-aw-mcpg/internal/auth" "github.com/githubnext/gh-aw-mcpg/internal/logger" sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -74,15 +71,9 @@ func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer, apiKey st // We use the Authorization header value as the session ID // This groups all requests from the same agent (same auth value) into one session - // Extract session ID from Authorization header - authHeader := r.Header.Get("Authorization") - sessionID := auth.ExtractSessionID(authHeader) - - // Reject requests without Authorization header + // Extract and validate session ID from Authorization header + sessionID := extractAndValidateSession(r) if sessionID == "" { - logTransport.Printf("Rejecting connection: no Authorization header, remote=%s, path=%s", r.RemoteAddr, r.URL.Path) - logger.LogErrorMd("client", "MCP connection rejected: no Authorization header, remote=%s, path=%s", r.RemoteAddr, r.URL.Path) - log.Printf("[%s] %s %s - REJECTED: No Authorization header", r.RemoteAddr, r.Method, r.URL.Path) // Return nil to reject the connection // The SDK will handle sending an appropriate error response return nil @@ -96,20 +87,11 @@ func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer, apiKey st log.Printf("DEBUG: About to check request body, Method=%s, Body!=nil: %v", r.Method, r.Body != nil) // Log request body for debugging (typically the 'initialize' request) - if r.Method == "POST" && r.Body != nil { - bodyBytes, err := io.ReadAll(r.Body) - if err == nil && len(bodyBytes) > 0 { - logger.LogDebug("client", "MCP initialize request body, session=%s, body=%s", sessionID, string(bodyBytes)) - log.Printf("Request body: %s", string(bodyBytes)) - // Restore body - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - } - } + logHTTPRequestBody(r, sessionID, "") // Store session ID in request context // This context will be passed to all tool handlers for this connection - ctx := context.WithValue(r.Context(), SessionIDContextKey, sessionID) - *r = *r.WithContext(ctx) + *r = *injectSessionContext(r, sessionID, "") log.Printf("✓ Injected session ID into context") log.Printf("==========================\n")