diff --git a/server/streamable_http.go b/server/streamable_http.go index 8c31d1762..10e8e7262 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -309,12 +309,23 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } } + // For non-initialize requests, try to reuse existing registered session + var session *streamableHttpSession + if !isInitializeRequest { + if sessionValue, ok := s.server.sessions.Load(sessionID); ok { + if existingSession, ok := sessionValue.(*streamableHttpSession); ok { + session = existingSession + } + } + } + // Check if a persistent session exists (for sampling support), otherwise create ephemeral session // Persistent sessions are created by GET (continuous listening) connections - var session *streamableHttpSession - if sessionInterface, exists := s.activeSessions.Load(sessionID); exists { - if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok { - session = persistentSession + if session == nil { + if sessionInterface, exists := s.activeSessions.Load(sessionID); exists { + if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok { + session = persistentSession + } } } @@ -417,6 +428,21 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request s.logger.Errorf("Failed to write response: %v", err) } } + + // Register session after successful initialization + // Only register if not already registered (e.g., by a GET connection) + if isInitializeRequest && sessionID != "" { + if _, exists := s.server.sessions.Load(sessionID); !exists { + // Store in activeSessions to prevent duplicate registration from GET + s.activeSessions.Store(sessionID, session) + // Register the session with the MCPServer for notification support + if err := s.server.RegisterSession(ctx, session); err != nil { + s.logger.Errorf("Failed to register POST session: %v", err) + s.activeSessions.Delete(sessionID) + // Don't fail the request, just log the error + } + } + } } func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) { diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index c2647f8a1..175ec7dd8 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -1314,3 +1314,170 @@ func TestInsecureStatefulSessionIdManager(t *testing.T) { } }) } + +func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) { + t.Run("POST session registration enables SendNotificationToSpecificClient", func(t *testing.T) { + hooks := &Hooks{} + var registeredSessionID string + var mu sync.Mutex + var sessionRegistered sync.WaitGroup + sessionRegistered.Add(1) + + hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) { + mu.Lock() + registeredSessionID = session.SessionID() + mu.Unlock() + sessionRegistered.Done() + }) + + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) + testServer := NewTestStreamableHTTPServer(mcpServer) + defer testServer.Close() + + // Send initialize request to register session + resp, err := postJSON(testServer.URL, initRequest) + if err != nil { + t.Fatalf("Failed to send initialize request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + + // Get session ID from response header + sessionID := resp.Header.Get(HeaderKeySessionID) + if sessionID == "" { + t.Fatal("Expected session ID in response header") + } + + // Wait for session registration + done := make(chan struct{}) + go func() { + sessionRegistered.Wait() + close(done) + }() + + select { + case <-done: + // Session registered successfully + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for session registration") + } + + mu.Lock() + if registeredSessionID != sessionID { + t.Errorf("Expected registered session ID %s, got %s", sessionID, registeredSessionID) + } + mu.Unlock() + + // Now test SendNotificationToSpecificClient + err = mcpServer.SendNotificationToSpecificClient(sessionID, "test/notification", map[string]any{ + "message": "test notification", + }) + if err != nil { + t.Errorf("SendNotificationToSpecificClient failed: %v", err) + } + }) + + t.Run("Session reuse for non-initialize requests", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + + // Add a tool that sends a notification + mcpServer.AddTool(mcp.NewTool("notify_tool"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return mcp.NewToolResultError("no session in context"), nil + } + + // Try to send notification to specific client + server := ServerFromContext(ctx) + err := server.SendNotificationToSpecificClient(session.SessionID(), "tool/notification", map[string]any{ + "from": "tool", + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("notification failed: %v", err)), nil + } + + return mcp.NewToolResultText("notification sent"), nil + }) + + testServer := NewTestStreamableHTTPServer(mcpServer) + defer testServer.Close() + + // Initialize session + resp, err := postJSON(testServer.URL, initRequest) + if err != nil { + t.Fatalf("Failed to send initialize request: %v", err) + } + sessionID := resp.Header.Get(HeaderKeySessionID) + resp.Body.Close() + + if sessionID == "" { + t.Fatal("Expected session ID in response header") + } + + // Give time for registration to complete + time.Sleep(100 * time.Millisecond) + + // Call tool with the session ID + toolCallRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]any{ + "name": "notify_tool", + }, + } + + jsonBody, _ := json.Marshal(toolCallRequest) + req, _ := http.NewRequest(http.MethodPost, testServer.URL, bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(HeaderKeySessionID, sessionID) + + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to call tool: %v", err) + } + defer resp.Body.Close() + + bodyBytes, _ := io.ReadAll(resp.Body) + bodyStr := string(bodyBytes) + + // Response might be SSE format if notification was sent + var toolResponse jsonRPCResponse + if strings.HasPrefix(bodyStr, "event: message") { + // Parse SSE format + lines := strings.Split(bodyStr, "\n") + for _, line := range lines { + if strings.HasPrefix(line, "data: ") { + jsonData := strings.TrimPrefix(line, "data: ") + if err := json.Unmarshal([]byte(jsonData), &toolResponse); err == nil { + break + } + } + } + } else { + if err := json.Unmarshal(bodyBytes, &toolResponse); err != nil { + t.Fatalf("Failed to unmarshal response: %v. Body: %s", err, bodyStr) + } + } + + if toolResponse.Error != nil { + t.Errorf("Tool call failed: %v", toolResponse.Error) + } + + // Verify the tool result indicates success + if result, ok := toolResponse.Result["content"].([]any); ok { + if len(result) > 0 { + if content, ok := result[0].(map[string]any); ok { + if text, ok := content["text"].(string); ok { + if text != "notification sent" { + t.Errorf("Expected 'notification sent', got %s", text) + } + } + } + } + } + }) +}