diff --git a/examples/everything/main.go b/examples/everything/main.go index a59e8c8fa..c2fcc7b23 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -300,6 +300,7 @@ func handleSendNotification( server := server.ServerFromContext(ctx) err := server.SendNotificationToClient( + ctx, "notifications/progress", map[string]interface{}{ "progress": 10, @@ -336,6 +337,7 @@ func handleLongRunningOperationTool( time.Sleep(time.Duration(stepDuration * float64(time.Second))) if progressToken != nil { server.SendNotificationToClient( + ctx, "notifications/progress", map[string]interface{}{ "progress": i, diff --git a/server/server.go b/server/server.go index 1af019263..f63e7e796 100644 --- a/server/server.go +++ b/server/server.go @@ -46,16 +46,23 @@ type ServerTool struct { Handler ToolHandlerFunc } -// NotificationContext provides client identification for notifications -type NotificationContext struct { - ClientID string - SessionID string +// ClientSession represents an active session that can be used by MCPServer to interact with client. +type ClientSession interface { + // NotificationChannel provides a channel suitable for sending notifications to client. + NotificationChannel() chan<- mcp.JSONRPCNotification + // SessionID is a unique identifier used to track user session. + SessionID() string } -// ServerNotification combines the notification with client context -type ServerNotification struct { - Context NotificationContext - Notification mcp.JSONRPCNotification +// clientSessionKey is the context key for storing current client notification channel. +type clientSessionKey struct{} + +// ClientSessionFromContext retrieves current client notification context from context. +func ClientSessionFromContext(ctx context.Context) ClientSession { + if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { + return session + } + return nil } // NotificationHandlerFunc handles incoming notifications. @@ -75,9 +82,7 @@ type MCPServer struct { tools map[string]ServerTool notificationHandlers map[string]NotificationHandlerFunc capabilities serverCapabilities - notifications chan ServerNotification - clientMu sync.Mutex // Separate mutex for client context - currentClient NotificationContext + sessions sync.Map initialized atomic.Bool // Use atomic for the initialized flag } @@ -92,30 +97,70 @@ func ServerFromContext(ctx context.Context) *MCPServer { return nil } -// WithContext sets the current client context and returns the provided context +// WithContext sets the current client session and returns the provided context func (s *MCPServer) WithContext( ctx context.Context, - notifCtx NotificationContext, + session ClientSession, ) context.Context { - s.clientMu.Lock() - s.currentClient = notifCtx - s.clientMu.Unlock() - return ctx + return context.WithValue(ctx, clientSessionKey{}, session) +} + +// RegisterSession saves session that should be notified in case if some server attributes changed. +func (s *MCPServer) RegisterSession( + session ClientSession, +) error { + sessionID := session.SessionID() + if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { + return fmt.Errorf("session %s is already registered", sessionID) + } + return nil +} + +// UnregisterSession removes from storage session that is shut down. +func (s *MCPServer) UnregisterSession( + sessionID string, +) { + s.sessions.Delete(sessionID) +} + +// sendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) sendNotificationToAllClients( + method string, + params map[string]any, +) { + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + s.sessions.Range(func(k, v any) bool { + if session, ok := v.(ClientSession); ok { + select { + case session.NotificationChannel() <- notification: + default: + // TODO: log blocked channel in the future versions + } + } + return true + }) } // SendNotificationToClient sends a notification to the current client func (s *MCPServer) SendNotificationToClient( + ctx context.Context, method string, - params map[string]interface{}, + params map[string]any, ) error { - if s.notifications == nil { + session := ClientSessionFromContext(ctx) + if session == nil { return fmt.Errorf("notification channel not initialized") } - s.clientMu.Lock() - clientContext := s.currentClient - s.clientMu.Unlock() - notification := mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ @@ -127,10 +172,7 @@ func (s *MCPServer) SendNotificationToClient( } select { - case s.notifications <- ServerNotification{ - Context: clientContext, - Notification: notification, - }: + case session.NotificationChannel() <- notification: return nil default: return fmt.Errorf("notification channel full or blocked") @@ -220,7 +262,6 @@ func NewMCPServer( name: name, version: version, notificationHandlers: make(map[string]NotificationHandlerFunc), - notifications: make(chan ServerNotification, 100), capabilities: serverCapabilities{ tools: nil, resources: nil, @@ -491,9 +532,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { // Send notification if server is already initialized if initialized { - if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil { - // We can't return the error, but in a future version we could log it - } + s.sendNotificationToAllClients("notifications/tools/list_changed", nil) } } @@ -516,9 +555,7 @@ func (s *MCPServer) DeleteTools(names ...string) { // Send notification if server is already initialized if initialized { - if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil { - // We can't return the error, but in a future version we could log it - } + s.sendNotificationToAllClients("notifications/tools/list_changed", nil) } } diff --git a/server/server_test.go b/server/server_test.go index 82085037d..3ffa2ccf6 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -3,11 +3,13 @@ package server import ( "context" "encoding/json" + "fmt" "testing" "time" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMCPServer_NewMCPServer(t *testing.T) { @@ -145,13 +147,41 @@ func TestMCPServer_Capabilities(t *testing.T) { func TestMCPServer_Tools(t *testing.T) { tests := []struct { name string - action func(*MCPServer) + action func(*testing.T, *MCPServer, chan mcp.JSONRPCNotification) expectedNotifications int - validate func(*testing.T, []ServerNotification, mcp.JSONRPCMessage) + validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage) }{ { - name: "SetTools sends single notifications/tools/list_changed", - action: func(server *MCPServer) { + name: "SetTools sends no notifications/tools/list_changed without active sessions", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + server.SetTools(ServerTool{ + Tool: mcp.NewTool("test-tool-1"), + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }, + }, ServerTool{ + Tool: mcp.NewTool("test-tool-2"), + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }, + }) + }, + expectedNotifications: 0, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { + tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools + assert.Len(t, tools, 2) + assert.Equal(t, "test-tool-1", tools[0].Name) + assert.Equal(t, "test-tool-2", tools[1].Name) + }, + }, + { + name: "SetTools sends single notifications/tools/list_changed with one active session", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(&fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + }) + require.NoError(t, err) server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -165,8 +195,41 @@ func TestMCPServer_Tools(t *testing.T) { }) }, expectedNotifications: 1, - validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) { - assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method) + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { + assert.Equal(t, "notifications/tools/list_changed", notifications[0].Method) + tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools + assert.Len(t, tools, 2) + assert.Equal(t, "test-tool-1", tools[0].Name) + assert.Equal(t, "test-tool-2", tools[1].Name) + }, + }, + { + name: "SetTools sends single notifications/tools/list_changed per each active session", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + for i := range 5 { + err := server.RegisterSession(&fakeSession{ + sessionID: fmt.Sprintf("test%d", i), + notificationChannel: notificationChannel, + }) + require.NoError(t, err) + } + server.SetTools(ServerTool{ + Tool: mcp.NewTool("test-tool-1"), + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }, + }, ServerTool{ + Tool: mcp.NewTool("test-tool-2"), + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }, + }) + }, + expectedNotifications: 5, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { + for _, notification := range notifications { + assert.Equal(t, "notifications/tools/list_changed", notification.Method) + } tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) assert.Equal(t, "test-tool-1", tools[0].Name) @@ -175,7 +238,12 @@ func TestMCPServer_Tools(t *testing.T) { }, { name: "AddTool sends multiple notifications/tools/list_changed", - action: func(server *MCPServer) { + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(&fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + }) + require.NoError(t, err) server.AddTool(mcp.NewTool("test-tool-1"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil @@ -186,8 +254,9 @@ func TestMCPServer_Tools(t *testing.T) { }) }, expectedNotifications: 2, - validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) { - assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method) + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { + assert.Equal(t, "notifications/tools/list_changed", notifications[0].Method) + assert.Equal(t, "notifications/tools/list_changed", notifications[1].Method) tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) assert.Equal(t, "test-tool-1", tools[0].Name) @@ -196,18 +265,23 @@ func TestMCPServer_Tools(t *testing.T) { }, { name: "DeleteTools sends single notifications/tools/list_changed", - action: func(server *MCPServer) { + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(&fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + }) + require.NoError(t, err) server.SetTools( ServerTool{Tool: mcp.NewTool("test-tool-1")}, ServerTool{Tool: mcp.NewTool("test-tool-2")}) server.DeleteTools("test-tool-1", "test-tool-2") }, expectedNotifications: 2, - validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) { + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { // One for SetTools - assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method) + assert.Equal(t, "notifications/tools/list_changed", notifications[0].Method) // One for DeleteTools - assert.Equal(t, "notifications/tools/list_changed", notifications[1].Notification.Method) + assert.Equal(t, "notifications/tools/list_changed", notifications[1].Method) // Expect a successful response with an empty list of tools resp, ok := toolsList.(mcp.JSONRPCResponse) @@ -225,15 +299,16 @@ func TestMCPServer_Tools(t *testing.T) { ctx := context.Background() server := NewMCPServer("test-server", "1.0.0") _ = server.HandleMessage(ctx, []byte(`{ - "jsonrpc": "2.0", - "id": 1, - "method": "initialize" - }`)) - notifications := make([]ServerNotification, 0) - tt.action(server) + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }`)) + notificationChannel := make(chan mcp.JSONRPCNotification, 100) + notifications := make([]mcp.JSONRPCNotification, 0) + tt.action(t, server, notificationChannel) for done := false; !done; { select { - case serverNotification := <-server.notifications: + case serverNotification := <-notificationChannel: notifications = append(notifications, serverNotification) if len(notifications) == tt.expectedNotifications { done = true @@ -244,10 +319,10 @@ func TestMCPServer_Tools(t *testing.T) { } assert.Len(t, notifications, tt.expectedNotifications) toolsList := server.HandleMessage(ctx, []byte(`{ - "jsonrpc": "2.0", - "id": 1, - "method": "tools/list" - }`)) + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list" + }`)) tt.validate(t, notifications, toolsList.(mcp.JSONRPCMessage)) }) @@ -398,6 +473,74 @@ func TestMCPServer_HandleNotifications(t *testing.T) { assert.True(t, notificationReceived) } +func TestMCPServer_SendNotificationToClient(t *testing.T) { + tests := []struct { + name string + contextPrepare func(context.Context, *MCPServer) context.Context + validate func(*testing.T, context.Context, *MCPServer) + }{ + { + name: "no active session", + contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context { + return ctx + }, + validate: func(t *testing.T, ctx context.Context, srv *MCPServer) { + require.Error(t, srv.SendNotificationToClient(ctx, "method", nil)) + }, + }, + { + name: "active session", + contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context { + return srv.WithContext(ctx, fakeSession{ + sessionID: "test", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + }) + }, + validate: func(t *testing.T, ctx context.Context, srv *MCPServer) { + for range 10 { + require.NoError(t, srv.SendNotificationToClient(ctx, "method", nil)) + } + session, ok := ClientSessionFromContext(ctx).(fakeSession) + require.True(t, ok, "session not found or of incorrect type") + for range 10 { + select { + case record := <-session.notificationChannel: + assert.Equal(t, "method", record.Method) + default: + t.Errorf("notification not sent") + } + } + }, + }, + { + name: "session with blocked channel", + contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context { + return srv.WithContext(ctx, fakeSession{ + sessionID: "test", + notificationChannel: make(chan mcp.JSONRPCNotification, 1), + }) + }, + validate: func(t *testing.T, ctx context.Context, srv *MCPServer) { + require.NoError(t, srv.SendNotificationToClient(ctx, "method", nil)) + require.Error(t, srv.SendNotificationToClient(ctx, "method", nil)) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + ctx := tt.contextPrepare(context.Background(), server) + _ = server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }`)) + + tt.validate(t, ctx, server) + }) + } +} + func TestMCPServer_PromptHandling(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true), @@ -818,3 +961,18 @@ func createTestServer() *MCPServer { return server } + +type fakeSession struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification +} + +func (f fakeSession) SessionID() string { + return f.sessionID +} + +func (f fakeSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +var _ ClientSession = fakeSession{} diff --git a/server/sse.go b/server/sse.go index 0e2f8b893..2e1dceb76 100644 --- a/server/sse.go +++ b/server/sse.go @@ -15,10 +15,12 @@ import ( // sseSession represents an active SSE connection. type sseSession struct { - writer http.ResponseWriter - flusher http.Flusher - done chan struct{} - eventQueue chan string // Channel for queuing events + writer http.ResponseWriter + flusher http.Flusher + done chan struct{} + eventQueue chan string // Channel for queuing events + sessionID string + notificationChannel chan mcp.JSONRPCNotification } // SSEContextFunc is a function that takes an existing context and the current @@ -26,6 +28,16 @@ type sseSession struct { // content. This can be used to inject context values from headers, for example. type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context +func (s *sseSession) SessionID() string { + return s.sessionID +} + +func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notificationChannel +} + +var _ ClientSession = (*sseSession)(nil) + // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. type SSEServer struct { @@ -168,30 +180,35 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { sessionID := uuid.New().String() session := &sseSession{ - writer: w, - flusher: flusher, - done: make(chan struct{}), - eventQueue: make(chan string, 100), // Buffer for events + writer: w, + flusher: flusher, + done: make(chan struct{}), + eventQueue: make(chan string, 100), // Buffer for events + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), } s.sessions.Store(sessionID, session) defer s.sessions.Delete(sessionID) + if err := s.server.RegisterSession(session); err != nil { + http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError) + return + } + defer s.server.UnregisterSession(sessionID) + // Start notification handler for this session go func() { for { select { - case serverNotification := <-s.server.notifications: - // Only forward notifications meant for this session - if serverNotification.Context.SessionID == sessionID { - eventData, err := json.Marshal(serverNotification.Notification) - if err == nil { - select { - case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): - // Event queued successfully - case <-session.done: - return - } + case notification := <-session.notificationChannel: + eventData, err := json.Marshal(notification) + if err == nil { + select { + case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): + // Event queued successfully + case <-session.done: + return } } case <-session.done: @@ -241,16 +258,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { return } - // Set the client context in the server before handling the message - ctx := s.server.WithContext(r.Context(), NotificationContext{ - ClientID: sessionID, - SessionID: sessionID, - }) - - if s.contextFunc != nil { - ctx = s.contextFunc(ctx, r) - } - sessionI, ok := s.sessions.Load(sessionID) if !ok { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") @@ -258,6 +265,12 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { } session := sessionI.(*sseSession) + // Set the client context before handling the message + ctx := s.server.WithContext(r.Context(), session) + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + // Parse message as raw JSON var rawMessage json.RawMessage if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil { diff --git a/server/stdio.go b/server/stdio.go index 5f9221c9e..441c50b99 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -48,6 +48,25 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { } } +// stdioSession is a static client session, since stdio has only one client. +type stdioSession struct { + notifications chan mcp.JSONRPCNotification +} + +func (s *stdioSession) SessionID() string { + return "stdio" +} + +func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notifications +} + +var _ ClientSession = (*stdioSession)(nil) + +var stdioSessionInstance = stdioSession{ + notifications: make(chan mcp.JSONRPCNotification, 100), +} + // NewStdioServer creates a new stdio server wrapper around an MCPServer. // It initializes the server with a default error logger that discards all output. func NewStdioServer(server *MCPServer) *StdioServer { @@ -83,10 +102,11 @@ func (s *StdioServer) Listen( stdout io.Writer, ) error { // Set a static client context since stdio only has one client - ctx = s.server.WithContext(ctx, NotificationContext{ - ClientID: "stdio", - SessionID: "stdio", - }) + if err := s.server.RegisterSession(&stdioSessionInstance); err != nil { + return fmt.Errorf("register session: %w", err) + } + defer s.server.UnregisterSession(stdioSessionInstance.SessionID()) + ctx = s.server.WithContext(ctx, &stdioSessionInstance) // Add in any custom context. if s.contextFunc != nil { @@ -99,19 +119,16 @@ func (s *StdioServer) Listen( go func() { for { select { - case serverNotification := <-s.server.notifications: - // Only handle notifications for stdio client - if serverNotification.Context.ClientID == "stdio" { - err := s.writeResponse( - serverNotification.Notification, - stdout, + case notification := <-stdioSessionInstance.notifications: + err := s.writeResponse( + notification, + stdout, + ) + if err != nil { + s.errLogger.Printf( + "Error writing notification: %v", + err, ) - if err != nil { - s.errLogger.Printf( - "Error writing notification: %v", - err, - ) - } } case <-ctx.Done(): return