Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions internal/server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"log"
"net/http"
"strings"
"time"

"github.com/githubnext/gh-aw-mcpg/internal/logger"
Expand Down Expand Up @@ -41,6 +42,21 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc {
}
}

// extractSessionFromAuth extracts session ID from Authorization header
// Per spec 7.1: When API key is configured, Authorization contains plain API key
// When API key is not configured, supports Bearer token for backward compatibility
func extractSessionFromAuth(authHeader string) string {
if strings.HasPrefix(authHeader, "Bearer ") {
// Bearer token format (for backward compatibility when no API key)
sessionID := strings.TrimPrefix(authHeader, "Bearer ")
return strings.TrimSpace(sessionID)
} else if authHeader != "" {
// Plain format (per spec 7.1 - API key is session ID)
return authHeader
}
return ""
}

// logRuntimeError logs runtime errors to stdout per spec section 9.2
func logRuntimeError(errorType, detail string, r *http.Request, serverName *string) {
timestamp := time.Now().UTC().Format(time.RFC3339)
Expand Down
71 changes: 71 additions & 0 deletions internal/server/handlers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package server

import (
"encoding/json"
"log"
"net/http"
"os"
"time"

"github.com/githubnext/gh-aw-mcpg/internal/logger"
)

// handleOAuthDiscovery returns a handler for OAuth discovery endpoint
// Returns 404 since the gateway doesn't use OAuth
func handleOAuthDiscovery() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("[%s] %s %s - OAuth discovery (not supported)", r.RemoteAddr, r.Method, r.URL.Path)
http.NotFound(w, r)
})
}

// handleClose returns a handler for graceful shutdown endpoint (spec 5.1.3)
func handleClose(unifiedServer *UnifiedServer) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("[%s] %s %s", r.RemoteAddr, r.Method, r.URL.Path)
logger.LogInfo("shutdown", "Close endpoint called, remote=%s", r.RemoteAddr)

// Only accept POST requests
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}

// Check if already closed (idempotency - spec 5.1.3)
if unifiedServer.IsShutdown() {
logger.LogWarn("shutdown", "Close endpoint called but gateway already closed, remote=%s", r.RemoteAddr)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusGone) // 410 Gone
json.NewEncoder(w).Encode(map[string]interface{}{
"error": "Gateway has already been closed",
})
return
}

// Initiate shutdown and get server count
serversTerminated := unifiedServer.InitiateShutdown()

// Return success response (spec 5.1.3)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
response := map[string]interface{}{
"status": "closed",
"message": "Gateway shutdown initiated",
"serversTerminated": serversTerminated,
}
json.NewEncoder(w).Encode(response)

logger.LogInfo("shutdown", "Close endpoint response sent, servers_terminated=%d", serversTerminated)
log.Printf("Gateway shutdown initiated. Terminated %d server(s)", serversTerminated)

// Exit the process after a brief delay to ensure response is sent
// Skip exit in test mode
if unifiedServer.ShouldExit() {
go func() {
time.Sleep(100 * time.Millisecond)
logger.LogInfo("shutdown", "Gateway process exiting with status 0")
os.Exit(0)
}()
}
})
}
42 changes: 42 additions & 0 deletions internal/server/response_writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package server

import (
"bytes"
"net/http"
)

// responseWriter wraps http.ResponseWriter to capture response body and status code
// This unified implementation replaces loggingResponseWriter and sdkLoggingResponseWriter
type responseWriter struct {
http.ResponseWriter
body bytes.Buffer
statusCode int
}

// newResponseWriter creates a new responseWriter with default status code
func newResponseWriter(w http.ResponseWriter) *responseWriter {
return &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
}

func (w *responseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}

func (w *responseWriter) Write(b []byte) (int, error) {
w.body.Write(b)
return w.ResponseWriter.Write(b)
}

// Body returns the captured response body as bytes
func (w *responseWriter) Body() []byte {
return w.body.Bytes()
}

// StatusCode returns the captured HTTP status code
func (w *responseWriter) StatusCode() int {
return w.statusCode
}
73 changes: 4 additions & 69 deletions internal/server/routed.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@ package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"sync"
"time"

"github.com/githubnext/gh-aw-mcpg/internal/logger"
"github.com/githubnext/gh-aw-mcpg/internal/mcp"
Expand Down Expand Up @@ -75,11 +71,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
mux := http.NewServeMux()

// OAuth discovery endpoint - return 404 since we don't use OAuth
oauthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("[%s] %s %s - OAuth discovery (not supported)", r.RemoteAddr, r.Method, r.URL.Path)
http.NotFound(w, r)
})
mux.Handle("/mcp/.well-known/oauth-authorization-server", withResponseLogging(oauthHandler))
mux.Handle("/mcp/.well-known/oauth-authorization-server", withResponseLogging(handleOAuthDiscovery()))

// Create routes for all backends, plus sys only if DIFC is enabled
allBackends := unifiedServer.GetServerIDs()
Expand All @@ -103,19 +95,8 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
// Create StreamableHTTP handler for this route
routeHandler := sdk.NewStreamableHTTPHandler(func(r *http.Request) *sdk.Server {
// Extract session ID from Authorization header
// Per spec 7.1: When API key is configured, Authorization contains plain API key
// When API key is not configured, supports Bearer token for backward compatibility
authHeader := r.Header.Get("Authorization")
var sessionID string

if strings.HasPrefix(authHeader, "Bearer ") {
// Bearer token format (for backward compatibility when no API key)
sessionID = strings.TrimPrefix(authHeader, "Bearer ")
sessionID = strings.TrimSpace(sessionID)
} else if authHeader != "" {
// Plain format (per spec 7.1 - API key is session ID)
sessionID = authHeader
}
sessionID := extractSessionFromAuth(authHeader)

// Reject requests without Authorization header
if sessionID == "" {
Expand Down Expand Up @@ -177,56 +158,10 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
mux.Handle("/health", withResponseLogging(healthHandler))

// Close endpoint for graceful shutdown (spec 5.1.3)
closeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("[%s] %s %s", r.RemoteAddr, r.Method, r.URL.Path)
logger.LogInfo("shutdown", "Close endpoint called, remote=%s", r.RemoteAddr)

// Only accept POST requests
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}

// Check if already closed (idempotency - spec 5.1.3)
if unifiedServer.IsShutdown() {
logger.LogWarn("shutdown", "Close endpoint called but gateway already closed, remote=%s", r.RemoteAddr)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusGone) // 410 Gone
json.NewEncoder(w).Encode(map[string]interface{}{
"error": "Gateway has already been closed",
})
return
}

// Initiate shutdown and get server count
serversTerminated := unifiedServer.InitiateShutdown()

// Return success response (spec 5.1.3)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
response := map[string]interface{}{
"status": "closed",
"message": "Gateway shutdown initiated",
"serversTerminated": serversTerminated,
}
json.NewEncoder(w).Encode(response)

logger.LogInfo("shutdown", "Close endpoint response sent, servers_terminated=%d", serversTerminated)
log.Printf("Gateway shutdown initiated. Terminated %d server(s)", serversTerminated)

// Exit the process after a brief delay to ensure response is sent
// Skip exit in test mode
if unifiedServer.ShouldExit() {
go func() {
time.Sleep(100 * time.Millisecond)
logger.LogInfo("shutdown", "Gateway process exiting with status 0")
os.Exit(0)
}()
}
})
closeHandler := handleClose(unifiedServer)

// Apply auth middleware if API key is configured (spec 7.1)
var finalCloseHandler http.Handler = closeHandler
finalCloseHandler := closeHandler
if apiKey != "" {
finalCloseHandler = authMiddleware(apiKey, closeHandler.ServeHTTP)
}
Expand Down
42 changes: 7 additions & 35 deletions internal/server/sdk_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,6 @@ type JSONRPCError struct {
Data json.RawMessage `json:"data,omitempty"`
}

// sdkLoggingResponseWriter captures response for logging
type sdkLoggingResponseWriter struct {
http.ResponseWriter
body bytes.Buffer
statusCode int
}

func (w *sdkLoggingResponseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}

func (w *sdkLoggingResponseWriter) Write(b []byte) (int, error) {
w.body.Write(b)
return w.ResponseWriter.Write(b)
}

// WithSDKLogging wraps an SDK StreamableHTTPHandler to log JSON-RPC translation results
// This captures the request/response at the HTTP boundary to understand what the SDK
// sees and what it returns, particularly for debugging protocol state issues
Expand All @@ -62,7 +45,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler {

// Extract session info for logging context
authHeader := r.Header.Get("Authorization")
sessionID := extractSessionID(authHeader)
sessionID := extractSessionFromAuth(authHeader)
mcpSessionID := r.Header.Get("Mcp-Session-Id")

// Log incoming request
Expand Down Expand Up @@ -92,26 +75,23 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler {
}

// Wrap response writer to capture output
lw := &sdkLoggingResponseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
lw := newResponseWriter(w)

// Call the actual SDK handler
handler.ServeHTTP(lw, r)

duration := time.Since(startTime)

// Parse and log response
responseBody := lw.body.Bytes()
responseBody := lw.Body()
if len(responseBody) > 0 {
// Try to parse as JSON-RPC response
var jsonrpcResp JSONRPCResponse
if err := json.Unmarshal(responseBody, &jsonrpcResp); err == nil {
if jsonrpcResp.Error != nil {
// Error response - this is what we're particularly interested in
logSDK.Printf("<<< SDK Response [%s] ERROR status=%d duration=%v",
mode, lw.statusCode, duration)
mode, lw.StatusCode(), duration)
logSDK.Printf(" JSON-RPC Error: code=%d message=%q",
jsonrpcResp.Error.Code, jsonrpcResp.Error.Message)

Expand All @@ -136,7 +116,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler {
} else {
// Success response
logSDK.Printf("<<< SDK Response [%s] SUCCESS status=%d duration=%v",
mode, lw.statusCode, duration)
mode, lw.StatusCode(), duration)
logSDK.Printf(" JSON-RPC Response id=%v has result=%v",
jsonrpcResp.ID, jsonrpcResp.Result != nil)

Expand All @@ -147,7 +127,7 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler {
} else {
// Could be SSE stream or other format
logSDK.Printf("<<< SDK Response [%s] status=%d duration=%v (non-JSON or stream)",
mode, lw.statusCode, duration)
mode, lw.StatusCode(), duration)
if len(responseBody) < 500 {
logSDK.Printf(" Raw response: %s", string(responseBody))
} else {
Expand All @@ -156,19 +136,11 @@ func WithSDKLogging(handler http.Handler, mode string) http.Handler {
}
} else {
logSDK.Printf("<<< SDK Response [%s] status=%d duration=%v (empty body)",
mode, lw.statusCode, duration)
mode, lw.StatusCode(), duration)
}
})
}

// extractSessionID extracts session ID from Authorization header
func extractSessionID(authHeader string) string {
if strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
}
return authHeader
}

// truncateSession returns a truncated session ID for logging (first 8 chars)
func truncateSession(s string) string {
if s == "" {
Expand Down
Loading