diff --git a/internal/auth/header.go b/internal/auth/header.go index ee90657e..eb4f8e52 100644 --- a/internal/auth/header.go +++ b/internal/auth/header.go @@ -126,3 +126,42 @@ func ExtractAgentID(authHeader string) string { return agentID } + +// ExtractSessionID extracts session ID from Authorization header. +// Per spec 7.1: When API key is configured, Authorization contains plain API key. +// When API key is not configured, supports Bearer token for backward compatibility. +// +// This function is specifically designed for server connection handling where: +// - Empty auth headers should return "" (to allow rejection of unauthenticated requests) +// - Bearer tokens should have whitespace trimmed (for backward compatibility) +// +// Returns: +// - Empty string if authHeader is empty +// - Trimmed token value if Bearer format +// - Plain authHeader value otherwise +func ExtractSessionID(authHeader string) string { + log.Printf("Extracting session ID from auth header: sanitized=%s", sanitize.TruncateSecret(authHeader)) + + if authHeader == "" { + log.Print("Auth header empty, returning empty session ID") + return "" + } + + // Handle "Bearer " format (backward compatibility) + // Trim spaces for backward compatibility with older clients + if strings.HasPrefix(authHeader, "Bearer ") { + log.Print("Detected Bearer format, trimming spaces for backward compatibility") + sessionID := strings.TrimPrefix(authHeader, "Bearer ") + return strings.TrimSpace(sessionID) + } + + // Handle "Agent " format + if strings.HasPrefix(authHeader, "Agent ") { + log.Print("Detected Agent format") + return strings.TrimPrefix(authHeader, "Agent ") + } + + // Plain format (per spec 7.1 - API key is session ID) + log.Print("Using plain API key as session ID") + return authHeader +} diff --git a/internal/auth/header_test.go b/internal/auth/header_test.go index 7eb4b619..863a26a2 100644 --- a/internal/auth/header_test.go +++ b/internal/auth/header_test.go @@ -283,3 +283,64 @@ func TestExtractAgentID(t *testing.T) { }) } } + +func TestExtractSessionID(t *testing.T) { + tests := []struct { + name string + authHeader string + want string + }{ + { + name: "Empty header returns empty string", + authHeader: "", + want: "", + }, + { + name: "Plain API key", + authHeader: "my-api-key", + want: "my-api-key", + }, + { + name: "Bearer token", + authHeader: "Bearer my-token-123", + want: "my-token-123", + }, + { + name: "Bearer token with trailing space (trimmed)", + authHeader: "Bearer my-token-123 ", + want: "my-token-123", + }, + { + name: "Bearer token with leading and trailing spaces (trimmed)", + authHeader: "Bearer my-token-123 ", + want: "my-token-123", + }, + { + name: "Agent format", + authHeader: "Agent agent-abc", + want: "agent-abc", + }, + { + name: "Long API key", + authHeader: "my-super-long-api-key-with-many-characters", + want: "my-super-long-api-key-with-many-characters", + }, + { + name: "API key with special characters", + authHeader: "key!@#$%^&*()", + want: "key!@#$%^&*()", + }, + { + name: "Whitespace only header", + authHeader: " ", + want: " ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractSessionID(tt.authHeader) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index a8c85875..69b51d42 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -313,61 +313,73 @@ func NewHTTPConnection(ctx context.Context, url string, headers map[string]strin return nil, fmt.Errorf("failed to connect using any HTTP transport (tried streamable, SSE, and plain JSON-RPC): last error: %w", err) } -// tryStreamableHTTPTransport attempts to connect using the streamable HTTP transport (2025-03-26 spec) -func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { +// transportConnector is a function that creates an SDK transport for a given URL and HTTP client +type transportConnector func(url string, httpClient *http.Client) sdk.Transport + +// trySDKTransport is a generic function to attempt connection with any SDK-based transport +// It handles the common logic of creating a client, connecting with timeout, and returning a connection +func trySDKTransport( + ctx context.Context, + cancel context.CancelFunc, + url string, + headers map[string]string, + httpClient *http.Client, + transportType HTTPTransportType, + transportName string, + createTransport transportConnector, +) (*Connection, error) { // Create MCP client client := newMCPClient() - // Create streamable HTTP transport - transport := &sdk.StreamableClientTransport{ - Endpoint: url, - HTTPClient: httpClient, - MaxRetries: 0, // Don't retry on failure - we'll try other transports - } + // Create transport using the provided connector + transport := createTransport(url, httpClient) - // Try to connect with a timeout - this will fail if the server doesn't support streamable HTTP + // Try to connect with a timeout - this will fail if the server doesn't support this transport // Use a short timeout to fail fast and try other transports connectCtx, connectCancel := context.WithTimeout(ctx, 5*time.Second) defer connectCancel() session, err := client.Connect(connectCtx, transport, nil) if err != nil { - return nil, fmt.Errorf("streamable HTTP transport connect failed: %w", err) + return nil, fmt.Errorf("%s transport connect failed: %w", transportName, err) } - conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, HTTPTransportStreamable) + conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, transportType) - logger.LogInfo("backend", "Streamable HTTP transport connected successfully") - logConn.Printf("Connected with streamable HTTP transport") + logger.LogInfo("backend", "%s transport connected successfully", transportName) + logConn.Printf("Connected with %s transport", transportName) return conn, nil } +// tryStreamableHTTPTransport attempts to connect using the streamable HTTP transport (2025-03-26 spec) +func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { + return trySDKTransport( + ctx, cancel, url, headers, httpClient, + HTTPTransportStreamable, + "streamable HTTP", + func(url string, httpClient *http.Client) sdk.Transport { + return &sdk.StreamableClientTransport{ + Endpoint: url, + HTTPClient: httpClient, + MaxRetries: 0, // Don't retry on failure - we'll try other transports + } + }, + ) +} + // trySSETransport attempts to connect using the SSE transport (2024-11-05 spec) func trySSETransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { - // Create MCP client - client := newMCPClient() - - // Create SSE transport - transport := &sdk.SSEClientTransport{ - Endpoint: url, - HTTPClient: httpClient, - } - - // Try to connect with a timeout - this will fail if the server doesn't support SSE - // Use a short timeout to fail fast and try other transports - connectCtx, connectCancel := context.WithTimeout(ctx, 5*time.Second) - defer connectCancel() - - session, err := client.Connect(connectCtx, transport, nil) - if err != nil { - return nil, fmt.Errorf("SSE transport connect failed: %w", err) - } - - conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, HTTPTransportSSE) - - logger.LogInfo("backend", "SSE transport connected successfully") - logConn.Printf("Connected with SSE transport") - return conn, nil + return trySDKTransport( + ctx, cancel, url, headers, httpClient, + HTTPTransportSSE, + "SSE", + func(url string, httpClient *http.Client) sdk.Transport { + return &sdk.SSEClientTransport{ + Endpoint: url, + HTTPClient: httpClient, + } + }, + ) } // tryPlainJSONTransport attempts to connect using plain JSON-RPC 2.0 over HTTP POST (non-standard) diff --git a/internal/server/auth.go b/internal/server/auth.go index 975a98b9..6e3fb7a5 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -3,9 +3,9 @@ package server import ( "log" "net/http" - "strings" "time" + "github.com/githubnext/gh-aw-mcpg/internal/auth" "github.com/githubnext/gh-aw-mcpg/internal/logger" ) @@ -42,19 +42,14 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc { } } -// extractSessionFromAuth extracts session ID from Authorization header -// Per spec 7.1: When API key is configured, Authorization contains plain API key -// When API key is not configured, supports Bearer token for backward compatibility +// extractSessionFromAuth extracts session ID from Authorization header. +// This function delegates to auth.ExtractSessionID for consistent session ID extraction. +// Per spec 7.1: When API key is configured, Authorization contains plain API key. +// When API key is not configured, supports Bearer token for backward compatibility. +// +// Deprecated: Use auth.ExtractSessionID directly instead. func extractSessionFromAuth(authHeader string) string { - if strings.HasPrefix(authHeader, "Bearer ") { - // Bearer token format (for backward compatibility when no API key) - sessionID := strings.TrimPrefix(authHeader, "Bearer ") - return strings.TrimSpace(sessionID) - } else if authHeader != "" { - // Plain format (per spec 7.1 - API key is session ID) - return authHeader - } - return "" + return auth.ExtractSessionID(authHeader) } // logRuntimeError logs runtime errors to stdout per spec section 9.2