diff --git a/auth/auth.go b/auth/auth.go index 0eea1d87..6c345714 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -17,6 +17,11 @@ import ( type TokenInfo struct { Scopes []string Expiration time.Time + // UserID is an optional identifier for the authenticated user. + // If set by a TokenVerifier, it can be used by transports to prevent + // session hijacking by ensuring that all requests for a given session + // come from the same user. + UserID string // TODO: add standard JWT fields Extra map[string]any } diff --git a/mcp/streamable.go b/mcp/streamable.go index 9e210abb..9ab3303b 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -51,6 +51,10 @@ type StreamableHTTPHandler struct { type sessionInfo struct { session *ServerSession transport *StreamableServerTransport + // userID is the user ID from the TokenInfo when the session was created. + // If non-empty, subsequent requests must have the same user ID to prevent + // session hijacking. + userID string // If timeout is set, automatically close the session after an idle period. timeout time.Duration @@ -238,6 +242,15 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "session not found", http.StatusNotFound) return } + // Prevent session hijacking: if the session was created with a user ID, + // verify that subsequent requests come from the same user. + if sessInfo != nil && sessInfo.userID != "" { + tokenInfo := auth.TokenInfoFromContext(req.Context()) + if tokenInfo == nil || tokenInfo.UserID != sessInfo.userID { + http.Error(w, "session user mismatch", http.StatusForbidden) + return + } + } } if req.Method == http.MethodDelete { @@ -404,9 +417,16 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "failed connection", http.StatusInternalServerError) return } + // Capture the user ID from the token info to enable session hijacking + // prevention on subsequent requests. + var userID string + if tokenInfo := auth.TokenInfoFromContext(req.Context()); tokenInfo != nil { + userID = tokenInfo.UserID + } sessInfo = &sessionInfo{ session: session, transport: transport, + userID: userID, } if stateless { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d2234b8e..9a90d24f 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1649,11 +1649,87 @@ func TestTokenInfo(t *testing.T) { if !ok { t.Fatal("not TextContent") } - if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w { + if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w { t.Errorf("got %q, want %q", g, w) } } +func TestSessionHijackingPrevention(t *testing.T) { + // This test verifies that sessions bound to a user ID cannot be accessed + // by a different user (session hijacking prevention). + ctx := context.Background() + + server := NewServer(testImpl, nil) + streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + + // Use the bearer token directly as the user ID. This simulates how a real + // verifier might extract a user ID from a JWT "sub" claim or introspection. + verifier := func(_ context.Context, token string, _ *http.Request) (*auth.TokenInfo, error) { + return &auth.TokenInfo{ + Scopes: []string{"scope"}, + UserID: token, + Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC), + }, nil + } + handler := auth.RequireBearerToken(verifier, nil)(streamHandler) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + // Helper to send a JSON-RPC request as a given user. + doRequest := func(msg jsonrpc.Message, sessionID, userID string) *http.Response { + t.Helper() + data, _ := jsonrpc2.EncodeMessage(msg) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, httpServer.URL, bytes.NewReader(data)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Authorization", "Bearer "+userID) + if sessionID != "" { + req.Header.Set("Mcp-Session-Id", sessionID) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + return resp + } + + // Create a session as user1. + initReq := &jsonrpc.Request{Method: "initialize", ID: jsonrpc2.Int64ID(1)} + initReq.Params, _ = json.Marshal(&InitializeParams{ + ProtocolVersion: protocolVersion20250618, + ClientInfo: &Implementation{Name: "test", Version: "1.0"}, + }) + resp := doRequest(initReq, "", "user1") + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("initialize failed with status %d: %s", resp.StatusCode, body) + } + sessionID := resp.Header.Get("Mcp-Session-Id") + if sessionID == "" { + t.Fatal("no session ID in response") + } + + pingReq := &jsonrpc.Request{Method: "ping", ID: jsonrpc2.Int64ID(2)} + pingReq.Params, _ = json.Marshal(&PingParams{}) + + // Try to access the session as user2 - should fail. + resp2 := doRequest(pingReq, sessionID, "user2") + defer resp2.Body.Close() + if resp2.StatusCode != http.StatusForbidden { + body, _ := io.ReadAll(resp2.Body) + t.Errorf("expected status %d for user mismatch, got %d: %s", http.StatusForbidden, resp2.StatusCode, body) + } + + // Access as original user1 should succeed. + resp3 := doRequest(pingReq, sessionID, "user1") + defer resp3.Body.Close() + if resp3.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp3.Body) + t.Errorf("expected status %d for matching user, got %d: %s", http.StatusOK, resp3.StatusCode, body) + } +} + func TestStreamableGET(t *testing.T) { // This test checks the fix for problematic behavior described in #410: // Hanging GET headers should be written immediately, even if there are no