Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions internal/auth/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,42 @@ func ExtractAgentID(authHeader string) string {

return agentID
}

// ExtractSessionID extracts session ID from Authorization header.
// Per spec 7.1: When API key is configured, Authorization contains plain API key.
// When API key is not configured, supports Bearer token for backward compatibility.
//
// This function is specifically designed for server connection handling where:
// - Empty auth headers should return "" (to allow rejection of unauthenticated requests)
// - Bearer tokens should have whitespace trimmed (for backward compatibility)
//
// Returns:
// - Empty string if authHeader is empty
// - Trimmed token value if Bearer format
// - Plain authHeader value otherwise
func ExtractSessionID(authHeader string) string {
log.Printf("Extracting session ID from auth header: sanitized=%s", sanitize.TruncateSecret(authHeader))

if authHeader == "" {
log.Print("Auth header empty, returning empty session ID")
return ""
}

// Handle "Bearer <token>" format (backward compatibility)
// Trim spaces for backward compatibility with older clients
if strings.HasPrefix(authHeader, "Bearer ") {
log.Print("Detected Bearer format, trimming spaces for backward compatibility")
sessionID := strings.TrimPrefix(authHeader, "Bearer ")
return strings.TrimSpace(sessionID)
}

// Handle "Agent <agent-id>" format
if strings.HasPrefix(authHeader, "Agent ") {
log.Print("Detected Agent format")
return strings.TrimPrefix(authHeader, "Agent ")
}

// Plain format (per spec 7.1 - API key is session ID)
log.Print("Using plain API key as session ID")
return authHeader
}
61 changes: 61 additions & 0 deletions internal/auth/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,64 @@ func TestExtractAgentID(t *testing.T) {
})
}
}

func TestExtractSessionID(t *testing.T) {
tests := []struct {
name string
authHeader string
want string
}{
{
name: "Empty header returns empty string",
authHeader: "",
want: "",
},
{
name: "Plain API key",
authHeader: "my-api-key",
want: "my-api-key",
},
{
name: "Bearer token",
authHeader: "Bearer my-token-123",
want: "my-token-123",
},
{
name: "Bearer token with trailing space (trimmed)",
authHeader: "Bearer my-token-123 ",
want: "my-token-123",
},
{
name: "Bearer token with leading and trailing spaces (trimmed)",
authHeader: "Bearer my-token-123 ",
want: "my-token-123",
},
{
name: "Agent format",
authHeader: "Agent agent-abc",
want: "agent-abc",
},
{
name: "Long API key",
authHeader: "my-super-long-api-key-with-many-characters",
want: "my-super-long-api-key-with-many-characters",
},
{
name: "API key with special characters",
authHeader: "key!@#$%^&*()",
want: "key!@#$%^&*()",
},
{
name: "Whitespace only header",
authHeader: " ",
want: " ",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ExtractSessionID(tt.authHeader)
assert.Equal(t, tt.want, got)
})
}
}
86 changes: 49 additions & 37 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,61 +313,73 @@ func NewHTTPConnection(ctx context.Context, url string, headers map[string]strin
return nil, fmt.Errorf("failed to connect using any HTTP transport (tried streamable, SSE, and plain JSON-RPC): last error: %w", err)
}

// tryStreamableHTTPTransport attempts to connect using the streamable HTTP transport (2025-03-26 spec)
func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) {
// transportConnector is a function that creates an SDK transport for a given URL and HTTP client
type transportConnector func(url string, httpClient *http.Client) sdk.Transport

// trySDKTransport is a generic function to attempt connection with any SDK-based transport
// It handles the common logic of creating a client, connecting with timeout, and returning a connection
func trySDKTransport(
ctx context.Context,
cancel context.CancelFunc,
url string,
headers map[string]string,
httpClient *http.Client,
transportType HTTPTransportType,
transportName string,
createTransport transportConnector,
) (*Connection, error) {
// Create MCP client
client := newMCPClient()

// Create streamable HTTP transport
transport := &sdk.StreamableClientTransport{
Endpoint: url,
HTTPClient: httpClient,
MaxRetries: 0, // Don't retry on failure - we'll try other transports
}
// Create transport using the provided connector
transport := createTransport(url, httpClient)

// Try to connect with a timeout - this will fail if the server doesn't support streamable HTTP
// Try to connect with a timeout - this will fail if the server doesn't support this transport
// Use a short timeout to fail fast and try other transports
connectCtx, connectCancel := context.WithTimeout(ctx, 5*time.Second)
defer connectCancel()

session, err := client.Connect(connectCtx, transport, nil)
if err != nil {
return nil, fmt.Errorf("streamable HTTP transport connect failed: %w", err)
return nil, fmt.Errorf("%s transport connect failed: %w", transportName, err)
}

conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, HTTPTransportStreamable)
conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, transportType)

logger.LogInfo("backend", "Streamable HTTP transport connected successfully")
logConn.Printf("Connected with streamable HTTP transport")
logger.LogInfo("backend", "%s transport connected successfully", transportName)
logConn.Printf("Connected with %s transport", transportName)
return conn, nil
}

// tryStreamableHTTPTransport attempts to connect using the streamable HTTP transport (2025-03-26 spec)
func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) {
return trySDKTransport(
ctx, cancel, url, headers, httpClient,
HTTPTransportStreamable,
"streamable HTTP",
func(url string, httpClient *http.Client) sdk.Transport {
return &sdk.StreamableClientTransport{
Endpoint: url,
HTTPClient: httpClient,
MaxRetries: 0, // Don't retry on failure - we'll try other transports
}
},
)
}

// trySSETransport attempts to connect using the SSE transport (2024-11-05 spec)
func trySSETransport(ctx context.Context, cancel context.CancelFunc, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) {
// Create MCP client
client := newMCPClient()

// Create SSE transport
transport := &sdk.SSEClientTransport{
Endpoint: url,
HTTPClient: httpClient,
}

// Try to connect with a timeout - this will fail if the server doesn't support SSE
// Use a short timeout to fail fast and try other transports
connectCtx, connectCancel := context.WithTimeout(ctx, 5*time.Second)
defer connectCancel()

session, err := client.Connect(connectCtx, transport, nil)
if err != nil {
return nil, fmt.Errorf("SSE transport connect failed: %w", err)
}

conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, HTTPTransportSSE)

logger.LogInfo("backend", "SSE transport connected successfully")
logConn.Printf("Connected with SSE transport")
return conn, nil
return trySDKTransport(
ctx, cancel, url, headers, httpClient,
HTTPTransportSSE,
"SSE",
func(url string, httpClient *http.Client) sdk.Transport {
return &sdk.SSEClientTransport{
Endpoint: url,
HTTPClient: httpClient,
}
},
)
}

// tryPlainJSONTransport attempts to connect using plain JSON-RPC 2.0 over HTTP POST (non-standard)
Expand Down
21 changes: 8 additions & 13 deletions internal/server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package server
import (
"log"
"net/http"
"strings"
"time"

"github.com/githubnext/gh-aw-mcpg/internal/auth"
"github.com/githubnext/gh-aw-mcpg/internal/logger"
)

Expand Down Expand Up @@ -42,19 +42,14 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc {
}
}

// extractSessionFromAuth extracts session ID from Authorization header
// Per spec 7.1: When API key is configured, Authorization contains plain API key
// When API key is not configured, supports Bearer token for backward compatibility
// extractSessionFromAuth extracts session ID from Authorization header.
// This function delegates to auth.ExtractSessionID for consistent session ID extraction.
// Per spec 7.1: When API key is configured, Authorization contains plain API key.
// When API key is not configured, supports Bearer token for backward compatibility.
//
// Deprecated: Use auth.ExtractSessionID directly instead.
func extractSessionFromAuth(authHeader string) string {
if strings.HasPrefix(authHeader, "Bearer ") {
// Bearer token format (for backward compatibility when no API key)
sessionID := strings.TrimPrefix(authHeader, "Bearer ")
return strings.TrimSpace(sessionID)
} else if authHeader != "" {
// Plain format (per spec 7.1 - API key is session ID)
return authHeader
}
return ""
return auth.ExtractSessionID(authHeader)
}

// logRuntimeError logs runtime errors to stdout per spec section 9.2
Expand Down