Skip to content

Commit c977042

Browse files
committed
fix: reuse sessions correctly in streamable HTTP transport
Fixes session management in the streamable HTTP server to properly reuse registered sessions for POST requests instead of always creating ephemeral sessions. This enables SendNotificationToSpecificClient and session-aware features to work correctly with POST-based interactions. Changes: - Check s.server.sessions for existing sessions before creating ephemeral ones - Register sessions after successful initialization from POST requests - Store sessions in both s.server.sessions and s.activeSessions for consistency - Add comprehensive tests for session reuse and notification delivery Fixes #614
1 parent 74a600b commit c977042

File tree

2 files changed

+197
-4
lines changed

2 files changed

+197
-4
lines changed

server/streamable_http.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,23 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
309309
}
310310
}
311311

312+
// For non-initialize requests, try to reuse existing registered session
313+
var session *streamableHttpSession
314+
if !isInitializeRequest {
315+
if sessionValue, ok := s.server.sessions.Load(sessionID); ok {
316+
if existingSession, ok := sessionValue.(*streamableHttpSession); ok {
317+
session = existingSession
318+
}
319+
}
320+
}
321+
312322
// Check if a persistent session exists (for sampling support), otherwise create ephemeral session
313323
// Persistent sessions are created by GET (continuous listening) connections
314-
var session *streamableHttpSession
315-
if sessionInterface, exists := s.activeSessions.Load(sessionID); exists {
316-
if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok {
317-
session = persistentSession
324+
if session == nil {
325+
if sessionInterface, exists := s.activeSessions.Load(sessionID); exists {
326+
if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok {
327+
session = persistentSession
328+
}
318329
}
319330
}
320331

@@ -417,6 +428,21 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
417428
s.logger.Errorf("Failed to write response: %v", err)
418429
}
419430
}
431+
432+
// Register session after successful initialization
433+
// Only register if not already registered (e.g., by a GET connection)
434+
if isInitializeRequest && sessionID != "" {
435+
if _, exists := s.server.sessions.Load(sessionID); !exists {
436+
// Store in activeSessions to prevent duplicate registration from GET
437+
s.activeSessions.Store(sessionID, session)
438+
// Register the session with the MCPServer for notification support
439+
if err := s.server.RegisterSession(ctx, session); err != nil {
440+
s.logger.Errorf("Failed to register POST session: %v", err)
441+
s.activeSessions.Delete(sessionID)
442+
// Don't fail the request, just log the error
443+
}
444+
}
445+
}
420446
}
421447

422448
func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {

server/streamable_http_test.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,3 +1314,170 @@ func TestInsecureStatefulSessionIdManager(t *testing.T) {
13141314
}
13151315
})
13161316
}
1317+
1318+
func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) {
1319+
t.Run("POST session registration enables SendNotificationToSpecificClient", func(t *testing.T) {
1320+
hooks := &Hooks{}
1321+
var registeredSessionID string
1322+
var mu sync.Mutex
1323+
var sessionRegistered sync.WaitGroup
1324+
sessionRegistered.Add(1)
1325+
1326+
hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) {
1327+
mu.Lock()
1328+
registeredSessionID = session.SessionID()
1329+
mu.Unlock()
1330+
sessionRegistered.Done()
1331+
})
1332+
1333+
mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks))
1334+
testServer := NewTestStreamableHTTPServer(mcpServer)
1335+
defer testServer.Close()
1336+
1337+
// Send initialize request to register session
1338+
resp, err := postJSON(testServer.URL, initRequest)
1339+
if err != nil {
1340+
t.Fatalf("Failed to send initialize request: %v", err)
1341+
}
1342+
defer resp.Body.Close()
1343+
1344+
if resp.StatusCode != http.StatusOK {
1345+
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
1346+
}
1347+
1348+
// Get session ID from response header
1349+
sessionID := resp.Header.Get(HeaderKeySessionID)
1350+
if sessionID == "" {
1351+
t.Fatal("Expected session ID in response header")
1352+
}
1353+
1354+
// Wait for session registration
1355+
done := make(chan struct{})
1356+
go func() {
1357+
sessionRegistered.Wait()
1358+
close(done)
1359+
}()
1360+
1361+
select {
1362+
case <-done:
1363+
// Session registered successfully
1364+
case <-time.After(2 * time.Second):
1365+
t.Fatal("Timeout waiting for session registration")
1366+
}
1367+
1368+
mu.Lock()
1369+
if registeredSessionID != sessionID {
1370+
t.Errorf("Expected registered session ID %s, got %s", sessionID, registeredSessionID)
1371+
}
1372+
mu.Unlock()
1373+
1374+
// Now test SendNotificationToSpecificClient
1375+
err = mcpServer.SendNotificationToSpecificClient(sessionID, "test/notification", map[string]any{
1376+
"message": "test notification",
1377+
})
1378+
if err != nil {
1379+
t.Errorf("SendNotificationToSpecificClient failed: %v", err)
1380+
}
1381+
})
1382+
1383+
t.Run("Session reuse for non-initialize requests", func(t *testing.T) {
1384+
mcpServer := NewMCPServer("test", "1.0.0")
1385+
1386+
// Add a tool that sends a notification
1387+
mcpServer.AddTool(mcp.NewTool("notify_tool"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1388+
session := ClientSessionFromContext(ctx)
1389+
if session == nil {
1390+
return mcp.NewToolResultError("no session in context"), nil
1391+
}
1392+
1393+
// Try to send notification to specific client
1394+
server := ServerFromContext(ctx)
1395+
err := server.SendNotificationToSpecificClient(session.SessionID(), "tool/notification", map[string]any{
1396+
"from": "tool",
1397+
})
1398+
if err != nil {
1399+
return mcp.NewToolResultError(fmt.Sprintf("notification failed: %v", err)), nil
1400+
}
1401+
1402+
return mcp.NewToolResultText("notification sent"), nil
1403+
})
1404+
1405+
testServer := NewTestStreamableHTTPServer(mcpServer)
1406+
defer testServer.Close()
1407+
1408+
// Initialize session
1409+
resp, err := postJSON(testServer.URL, initRequest)
1410+
if err != nil {
1411+
t.Fatalf("Failed to send initialize request: %v", err)
1412+
}
1413+
sessionID := resp.Header.Get(HeaderKeySessionID)
1414+
resp.Body.Close()
1415+
1416+
if sessionID == "" {
1417+
t.Fatal("Expected session ID in response header")
1418+
}
1419+
1420+
// Give time for registration to complete
1421+
time.Sleep(100 * time.Millisecond)
1422+
1423+
// Call tool with the session ID
1424+
toolCallRequest := map[string]any{
1425+
"jsonrpc": "2.0",
1426+
"id": 2,
1427+
"method": "tools/call",
1428+
"params": map[string]any{
1429+
"name": "notify_tool",
1430+
},
1431+
}
1432+
1433+
jsonBody, _ := json.Marshal(toolCallRequest)
1434+
req, _ := http.NewRequest(http.MethodPost, testServer.URL, bytes.NewBuffer(jsonBody))
1435+
req.Header.Set("Content-Type", "application/json")
1436+
req.Header.Set(HeaderKeySessionID, sessionID)
1437+
1438+
resp, err = http.DefaultClient.Do(req)
1439+
if err != nil {
1440+
t.Fatalf("Failed to call tool: %v", err)
1441+
}
1442+
defer resp.Body.Close()
1443+
1444+
bodyBytes, _ := io.ReadAll(resp.Body)
1445+
bodyStr := string(bodyBytes)
1446+
1447+
// Response might be SSE format if notification was sent
1448+
var toolResponse jsonRPCResponse
1449+
if strings.HasPrefix(bodyStr, "event: message") {
1450+
// Parse SSE format
1451+
lines := strings.Split(bodyStr, "\n")
1452+
for _, line := range lines {
1453+
if strings.HasPrefix(line, "data: ") {
1454+
jsonData := strings.TrimPrefix(line, "data: ")
1455+
if err := json.Unmarshal([]byte(jsonData), &toolResponse); err == nil {
1456+
break
1457+
}
1458+
}
1459+
}
1460+
} else {
1461+
if err := json.Unmarshal(bodyBytes, &toolResponse); err != nil {
1462+
t.Fatalf("Failed to unmarshal response: %v. Body: %s", err, bodyStr)
1463+
}
1464+
}
1465+
1466+
if toolResponse.Error != nil {
1467+
t.Errorf("Tool call failed: %v", toolResponse.Error)
1468+
}
1469+
1470+
// Verify the tool result indicates success
1471+
if result, ok := toolResponse.Result["content"].([]any); ok {
1472+
if len(result) > 0 {
1473+
if content, ok := result[0].(map[string]any); ok {
1474+
if text, ok := content["text"].(string); ok {
1475+
if text != "notification sent" {
1476+
t.Errorf("Expected 'notification sent', got %s", text)
1477+
}
1478+
}
1479+
}
1480+
}
1481+
}
1482+
})
1483+
}

0 commit comments

Comments
 (0)