Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ import (
type TokenInfo struct {
Scopes []string
Expiration time.Time
// UserID is an optional identifier for the authenticated user.
// If set by a TokenVerifier, it can be used by transports to prevent
// session hijacking by ensuring that all requests for a given session
// come from the same user.
UserID string
// TODO: add standard JWT fields
Extra map[string]any
}
Expand Down
20 changes: 20 additions & 0 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ type StreamableHTTPHandler struct {
type sessionInfo struct {
session *ServerSession
transport *StreamableServerTransport
// userID is the user ID from the TokenInfo when the session was created.
// If non-empty, subsequent requests must have the same user ID to prevent
// session hijacking.
userID string

// If timeout is set, automatically close the session after an idle period.
timeout time.Duration
Expand Down Expand Up @@ -238,6 +242,15 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
http.Error(w, "session not found", http.StatusNotFound)
return
}
// Prevent session hijacking: if the session was created with a user ID,
// verify that subsequent requests come from the same user.
if sessInfo != nil && sessInfo.userID != "" {
tokenInfo := auth.TokenInfoFromContext(req.Context())
if tokenInfo == nil || tokenInfo.UserID != sessInfo.userID {
http.Error(w, "session user mismatch", http.StatusForbidden)
return
}
}
}

if req.Method == http.MethodDelete {
Expand Down Expand Up @@ -404,9 +417,16 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
http.Error(w, "failed connection", http.StatusInternalServerError)
return
}
// Capture the user ID from the token info to enable session hijacking
// prevention on subsequent requests.
var userID string
if tokenInfo := auth.TokenInfoFromContext(req.Context()); tokenInfo != nil {
userID = tokenInfo.UserID
}
sessInfo = &sessionInfo{
session: session,
transport: transport,
userID: userID,
}

if stateless {
Expand Down
78 changes: 77 additions & 1 deletion mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1649,11 +1649,87 @@ func TestTokenInfo(t *testing.T) {
if !ok {
t.Fatal("not TextContent")
}
if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w {
if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w {
t.Errorf("got %q, want %q", g, w)
}
}

func TestSessionHijackingPrevention(t *testing.T) {
// This test verifies that sessions bound to a user ID cannot be accessed
// by a different user (session hijacking prevention).
ctx := context.Background()

server := NewServer(testImpl, nil)
streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)

// Use the bearer token directly as the user ID. This simulates how a real
// verifier might extract a user ID from a JWT "sub" claim or introspection.
verifier := func(_ context.Context, token string, _ *http.Request) (*auth.TokenInfo, error) {
return &auth.TokenInfo{
Scopes: []string{"scope"},
UserID: token,
Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC),
}, nil
}
handler := auth.RequireBearerToken(verifier, nil)(streamHandler)
httpServer := httptest.NewServer(mustNotPanic(t, handler))
defer httpServer.Close()

// Helper to send a JSON-RPC request as a given user.
doRequest := func(msg jsonrpc.Message, sessionID, userID string) *http.Response {
t.Helper()
data, _ := jsonrpc2.EncodeMessage(msg)
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, httpServer.URL, bytes.NewReader(data))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
req.Header.Set("Authorization", "Bearer "+userID)
if sessionID != "" {
req.Header.Set("Mcp-Session-Id", sessionID)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
return resp
}

// Create a session as user1.
initReq := &jsonrpc.Request{Method: "initialize", ID: jsonrpc2.Int64ID(1)}
initReq.Params, _ = json.Marshal(&InitializeParams{
ProtocolVersion: protocolVersion20250618,
ClientInfo: &Implementation{Name: "test", Version: "1.0"},
})
resp := doRequest(initReq, "", "user1")
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("initialize failed with status %d: %s", resp.StatusCode, body)
}
sessionID := resp.Header.Get("Mcp-Session-Id")
if sessionID == "" {
t.Fatal("no session ID in response")
}

pingReq := &jsonrpc.Request{Method: "ping", ID: jsonrpc2.Int64ID(2)}
pingReq.Params, _ = json.Marshal(&PingParams{})

// Try to access the session as user2 - should fail.
resp2 := doRequest(pingReq, sessionID, "user2")
defer resp2.Body.Close()
if resp2.StatusCode != http.StatusForbidden {
body, _ := io.ReadAll(resp2.Body)
t.Errorf("expected status %d for user mismatch, got %d: %s", http.StatusForbidden, resp2.StatusCode, body)
}

// Access as original user1 should succeed.
resp3 := doRequest(pingReq, sessionID, "user1")
defer resp3.Body.Close()
if resp3.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp3.Body)
t.Errorf("expected status %d for matching user, got %d: %s", http.StatusOK, resp3.StatusCode, body)
}
}

func TestStreamableGET(t *testing.T) {
// This test checks the fix for problematic behavior described in #410:
// Hanging GET headers should be written immediately, even if there are no
Expand Down