diff --git a/internal/server/auth.go b/internal/server/auth.go index d7d450f3..975a98b9 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -3,6 +3,7 @@ package server import ( "log" "net/http" + "strings" "time" "github.com/githubnext/gh-aw-mcpg/internal/logger" @@ -41,6 +42,21 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc { } } +// extractSessionFromAuth extracts session ID from Authorization header +// Per spec 7.1: When API key is configured, Authorization contains plain API key +// When API key is not configured, supports Bearer token for backward compatibility +func extractSessionFromAuth(authHeader string) string { + if strings.HasPrefix(authHeader, "Bearer ") { + // Bearer token format (for backward compatibility when no API key) + sessionID := strings.TrimPrefix(authHeader, "Bearer ") + return strings.TrimSpace(sessionID) + } else if authHeader != "" { + // Plain format (per spec 7.1 - API key is session ID) + return authHeader + } + return "" +} + // logRuntimeError logs runtime errors to stdout per spec section 9.2 func logRuntimeError(errorType, detail string, r *http.Request, serverName *string) { timestamp := time.Now().UTC().Format(time.RFC3339) diff --git a/internal/server/handlers.go b/internal/server/handlers.go new file mode 100644 index 00000000..e9d5944e --- /dev/null +++ b/internal/server/handlers.go @@ -0,0 +1,71 @@ +package server + +import ( + "encoding/json" + "log" + "net/http" + "os" + "time" + + "github.com/githubnext/gh-aw-mcpg/internal/logger" +) + +// handleOAuthDiscovery returns a handler for OAuth discovery endpoint +// Returns 404 since the gateway doesn't use OAuth +func handleOAuthDiscovery() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("[%s] %s %s - OAuth discovery (not supported)", r.RemoteAddr, r.Method, r.URL.Path) + http.NotFound(w, r) + }) +} + +// handleClose returns a handler for graceful shutdown endpoint (spec 5.1.3) +func handleClose(unifiedServer *UnifiedServer) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("[%s] %s %s", r.RemoteAddr, r.Method, r.URL.Path) + logger.LogInfo("shutdown", "Close endpoint called, remote=%s", r.RemoteAddr) + + // Only accept POST requests + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Check if already closed (idempotency - spec 5.1.3) + if unifiedServer.IsShutdown() { + logger.LogWarn("shutdown", "Close endpoint called but gateway already closed, remote=%s", r.RemoteAddr) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusGone) // 410 Gone + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": "Gateway has already been closed", + }) + return + } + + // Initiate shutdown and get server count + serversTerminated := unifiedServer.InitiateShutdown() + + // Return success response (spec 5.1.3) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + response := map[string]interface{}{ + "status": "closed", + "message": "Gateway shutdown initiated", + "serversTerminated": serversTerminated, + } + json.NewEncoder(w).Encode(response) + + logger.LogInfo("shutdown", "Close endpoint response sent, servers_terminated=%d", serversTerminated) + log.Printf("Gateway shutdown initiated. Terminated %d server(s)", serversTerminated) + + // Exit the process after a brief delay to ensure response is sent + // Skip exit in test mode + if unifiedServer.ShouldExit() { + go func() { + time.Sleep(100 * time.Millisecond) + logger.LogInfo("shutdown", "Gateway process exiting with status 0") + os.Exit(0) + }() + } + }) +} diff --git a/internal/server/response_writer.go b/internal/server/response_writer.go new file mode 100644 index 00000000..f3ddf29b --- /dev/null +++ b/internal/server/response_writer.go @@ -0,0 +1,42 @@ +package server + +import ( + "bytes" + "net/http" +) + +// responseWriter wraps http.ResponseWriter to capture response body and status code +// This unified implementation replaces loggingResponseWriter and sdkLoggingResponseWriter +type responseWriter struct { + http.ResponseWriter + body bytes.Buffer + statusCode int +} + +// newResponseWriter creates a new responseWriter with default status code +func newResponseWriter(w http.ResponseWriter) *responseWriter { + return &responseWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + } +} + +func (w *responseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *responseWriter) Write(b []byte) (int, error) { + w.body.Write(b) + return w.ResponseWriter.Write(b) +} + +// Body returns the captured response body as bytes +func (w *responseWriter) Body() []byte { + return w.body.Bytes() +} + +// StatusCode returns the captured HTTP status code +func (w *responseWriter) StatusCode() int { + return w.statusCode +} diff --git a/internal/server/routed.go b/internal/server/routed.go index 5d55b5da..edbe582d 100644 --- a/internal/server/routed.go +++ b/internal/server/routed.go @@ -3,15 +3,11 @@ package server import ( "bytes" "context" - "encoding/json" "fmt" "io" "log" "net/http" - "os" - "strings" "sync" - "time" "github.com/githubnext/gh-aw-mcpg/internal/logger" "github.com/githubnext/gh-aw-mcpg/internal/mcp" @@ -75,11 +71,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap mux := http.NewServeMux() // OAuth discovery endpoint - return 404 since we don't use OAuth - oauthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("[%s] %s %s - OAuth discovery (not supported)", r.RemoteAddr, r.Method, r.URL.Path) - http.NotFound(w, r) - }) - mux.Handle("/mcp/.well-known/oauth-authorization-server", withResponseLogging(oauthHandler)) + mux.Handle("/mcp/.well-known/oauth-authorization-server", withResponseLogging(handleOAuthDiscovery())) // Create routes for all backends, plus sys only if DIFC is enabled allBackends := unifiedServer.GetServerIDs() @@ -103,19 +95,8 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap // Create StreamableHTTP handler for this route routeHandler := sdk.NewStreamableHTTPHandler(func(r *http.Request) *sdk.Server { // Extract session ID from Authorization header - // Per spec 7.1: When API key is configured, Authorization contains plain API key - // When API key is not configured, supports Bearer token for backward compatibility authHeader := r.Header.Get("Authorization") - var sessionID string - - if strings.HasPrefix(authHeader, "Bearer ") { - // Bearer token format (for backward compatibility when no API key) - sessionID = strings.TrimPrefix(authHeader, "Bearer ") - sessionID = strings.TrimSpace(sessionID) - } else if authHeader != "" { - // Plain format (per spec 7.1 - API key is session ID) - sessionID = authHeader - } + sessionID := extractSessionFromAuth(authHeader) // Reject requests without Authorization header if sessionID == "" { @@ -177,56 +158,10 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap mux.Handle("/health", withResponseLogging(healthHandler)) // Close endpoint for graceful shutdown (spec 5.1.3) - closeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("[%s] %s %s", r.RemoteAddr, r.Method, r.URL.Path) - logger.LogInfo("shutdown", "Close endpoint called, remote=%s", r.RemoteAddr) - - // Only accept POST requests - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - // Check if already closed (idempotency - spec 5.1.3) - if unifiedServer.IsShutdown() { - logger.LogWarn("shutdown", "Close endpoint called but gateway already closed, remote=%s", r.RemoteAddr) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusGone) // 410 Gone - json.NewEncoder(w).Encode(map[string]interface{}{ - "error": "Gateway has already been closed", - }) - return - } - - // Initiate shutdown and get server count - serversTerminated := unifiedServer.InitiateShutdown() - - // Return success response (spec 5.1.3) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - response := map[string]interface{}{ - "status": "closed", - "message": "Gateway shutdown initiated", - "serversTerminated": serversTerminated, - } - json.NewEncoder(w).Encode(response) - - logger.LogInfo("shutdown", "Close endpoint response sent, servers_terminated=%d", serversTerminated) - log.Printf("Gateway shutdown initiated. Terminated %d server(s)", serversTerminated) - - // Exit the process after a brief delay to ensure response is sent - // Skip exit in test mode - if unifiedServer.ShouldExit() { - go func() { - time.Sleep(100 * time.Millisecond) - logger.LogInfo("shutdown", "Gateway process exiting with status 0") - os.Exit(0) - }() - } - }) + closeHandler := handleClose(unifiedServer) // Apply auth middleware if API key is configured (spec 7.1) - var finalCloseHandler http.Handler = closeHandler + finalCloseHandler := closeHandler if apiKey != "" { finalCloseHandler = authMiddleware(apiKey, closeHandler.ServeHTTP) } diff --git a/internal/server/sdk_logging.go b/internal/server/sdk_logging.go index 9332ac36..228407dc 100644 --- a/internal/server/sdk_logging.go +++ b/internal/server/sdk_logging.go @@ -36,23 +36,6 @@ type JSONRPCError struct { Data json.RawMessage `json:"data,omitempty"` } -// sdkLoggingResponseWriter captures response for logging -type sdkLoggingResponseWriter struct { - http.ResponseWriter - body bytes.Buffer - statusCode int -} - -func (w *sdkLoggingResponseWriter) WriteHeader(statusCode int) { - w.statusCode = statusCode - w.ResponseWriter.WriteHeader(statusCode) -} - -func (w *sdkLoggingResponseWriter) Write(b []byte) (int, error) { - w.body.Write(b) - return w.ResponseWriter.Write(b) -} - // WithSDKLogging wraps an SDK StreamableHTTPHandler to log JSON-RPC translation results // This captures the request/response at the HTTP boundary to understand what the SDK // sees and what it returns, particularly for debugging protocol state issues @@ -62,7 +45,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { // Extract session info for logging context authHeader := r.Header.Get("Authorization") - sessionID := extractSessionID(authHeader) + sessionID := extractSessionFromAuth(authHeader) mcpSessionID := r.Header.Get("Mcp-Session-Id") // Log incoming request @@ -92,10 +75,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { } // Wrap response writer to capture output - lw := &sdkLoggingResponseWriter{ - ResponseWriter: w, - statusCode: http.StatusOK, - } + lw := newResponseWriter(w) // Call the actual SDK handler handler.ServeHTTP(lw, r) @@ -103,7 +83,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { duration := time.Since(startTime) // Parse and log response - responseBody := lw.body.Bytes() + responseBody := lw.Body() if len(responseBody) > 0 { // Try to parse as JSON-RPC response var jsonrpcResp JSONRPCResponse @@ -111,7 +91,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { if jsonrpcResp.Error != nil { // Error response - this is what we're particularly interested in logSDK.Printf("<<< SDK Response [%s] ERROR status=%d duration=%v", - mode, lw.statusCode, duration) + mode, lw.StatusCode(), duration) logSDK.Printf(" JSON-RPC Error: code=%d message=%q", jsonrpcResp.Error.Code, jsonrpcResp.Error.Message) @@ -136,7 +116,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { } else { // Success response logSDK.Printf("<<< SDK Response [%s] SUCCESS status=%d duration=%v", - mode, lw.statusCode, duration) + mode, lw.StatusCode(), duration) logSDK.Printf(" JSON-RPC Response id=%v has result=%v", jsonrpcResp.ID, jsonrpcResp.Result != nil) @@ -147,7 +127,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { } else { // Could be SSE stream or other format logSDK.Printf("<<< SDK Response [%s] status=%d duration=%v (non-JSON or stream)", - mode, lw.statusCode, duration) + mode, lw.StatusCode(), duration) if len(responseBody) < 500 { logSDK.Printf(" Raw response: %s", string(responseBody)) } else { @@ -156,19 +136,11 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler { } } else { logSDK.Printf("<<< SDK Response [%s] status=%d duration=%v (empty body)", - mode, lw.statusCode, duration) + mode, lw.StatusCode(), duration) } }) } -// extractSessionID extracts session ID from Authorization header -func extractSessionID(authHeader string) string { - if strings.HasPrefix(authHeader, "Bearer ") { - return strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) - } - return authHeader -} - // truncateSession returns a truncated session ID for logging (first 8 chars) func truncateSession(s string) string { if s == "" { diff --git a/internal/server/transport.go b/internal/server/transport.go index 49595730..eccbd8b5 100644 --- a/internal/server/transport.go +++ b/internal/server/transport.go @@ -3,13 +3,9 @@ package server import ( "bytes" "context" - "encoding/json" "io" "log" "net/http" - "os" - "strings" - "time" "github.com/githubnext/gh-aw-mcpg/internal/logger" sdk "github.com/modelcontextprotocol/go-sdk/mcp" @@ -48,30 +44,13 @@ func (t *HTTPTransport) Close() error { return nil } -// loggingResponseWriter wraps http.ResponseWriter to capture response body -type loggingResponseWriter struct { - http.ResponseWriter - body []byte - statusCode int -} - -func (w *loggingResponseWriter) WriteHeader(statusCode int) { - w.statusCode = statusCode - w.ResponseWriter.WriteHeader(statusCode) -} - -func (w *loggingResponseWriter) Write(b []byte) (int, error) { - w.body = append(w.body, b...) - return w.ResponseWriter.Write(b) -} - // withResponseLogging wraps an http.Handler to log response bodies func withResponseLogging(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - lw := &loggingResponseWriter{ResponseWriter: w, body: []byte{}, statusCode: http.StatusOK} + lw := newResponseWriter(w) handler.ServeHTTP(lw, r) - if len(lw.body) > 0 { - log.Printf("[%s] %s %s - Status: %d, Response: %s", r.RemoteAddr, r.Method, r.URL.Path, lw.statusCode, string(lw.body)) + if len(lw.Body()) > 0 { + log.Printf("[%s] %s %s - Status: %d, Response: %s", r.RemoteAddr, r.Method, r.URL.Path, lw.StatusCode(), string(lw.Body())) } }) } @@ -83,11 +62,7 @@ func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer, apiKey st mux := http.NewServeMux() // OAuth discovery endpoint - return 404 since we don't use OAuth - oauthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("[%s] %s %s - OAuth discovery (not supported)", r.RemoteAddr, r.Method, r.URL.Path) - http.NotFound(w, r) - }) - mux.Handle("/mcp/.well-known/oauth-authorization-server", withResponseLogging(oauthHandler)) + mux.Handle("/mcp/.well-known/oauth-authorization-server", withResponseLogging(handleOAuthDiscovery())) logTransport.Print("Registering streamable HTTP handler for MCP protocol") // Create StreamableHTTP handler for MCP protocol (supports POST requests) @@ -99,19 +74,8 @@ func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer, apiKey st // This groups all requests from the same agent (same auth value) into one session // Extract session ID from Authorization header - // Per spec 7.1: When API key is configured, Authorization contains plain API key - // When API key is not configured, supports Bearer token for backward compatibility authHeader := r.Header.Get("Authorization") - var sessionID string - - if strings.HasPrefix(authHeader, "Bearer ") { - // Bearer token format (for backward compatibility when no API key) - sessionID = strings.TrimPrefix(authHeader, "Bearer ") - sessionID = strings.TrimSpace(sessionID) - } else if authHeader != "" { - // Plain format (per spec 7.1 - API key is session ID) - sessionID = authHeader - } + sessionID := extractSessionFromAuth(authHeader) // Reject requests without Authorization header if sessionID == "" { @@ -171,58 +135,10 @@ func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer, apiKey st mux.Handle("/health", withResponseLogging(healthHandler)) // Close endpoint for graceful shutdown (spec 5.1.3) - closeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logTransport.Printf("Close endpoint called: method=%s, remote=%s", r.Method, r.RemoteAddr) - log.Printf("[%s] %s %s", r.RemoteAddr, r.Method, r.URL.Path) - logger.LogInfo("shutdown", "Close endpoint called, remote=%s", r.RemoteAddr) - - // Only accept POST requests - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - // Check if already closed (idempotency - spec 5.1.3) - if unifiedServer.IsShutdown() { - logger.LogWarn("shutdown", "Close endpoint called but gateway already closed, remote=%s", r.RemoteAddr) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusGone) // 410 Gone - json.NewEncoder(w).Encode(map[string]interface{}{ - "error": "Gateway has already been closed", - }) - return - } - - // Initiate shutdown and get server count - serversTerminated := unifiedServer.InitiateShutdown() - logTransport.Printf("Shutdown initiated: servers_terminated=%d", serversTerminated) - - // Return success response (spec 5.1.3) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - response := map[string]interface{}{ - "status": "closed", - "message": "Gateway shutdown initiated", - "serversTerminated": serversTerminated, - } - json.NewEncoder(w).Encode(response) - - logger.LogInfo("shutdown", "Close endpoint response sent, servers_terminated=%d", serversTerminated) - log.Printf("Gateway shutdown initiated. Terminated %d server(s)", serversTerminated) - - // Exit the process after a brief delay to ensure response is sent - // Skip exit in test mode - if unifiedServer.ShouldExit() { - go func() { - time.Sleep(100 * time.Millisecond) - logger.LogInfo("shutdown", "Gateway process exiting with status 0") - os.Exit(0) - }() - } - }) + closeHandler := handleClose(unifiedServer) // Apply auth middleware if API key is configured (spec 7.1) - var finalCloseHandler http.Handler = closeHandler + finalCloseHandler := closeHandler if apiKey != "" { finalCloseHandler = authMiddleware(apiKey, closeHandler.ServeHTTP) }