Skip to content

Commit 3cddf41

Browse files
committed
Add per-session notifications handling
1 parent 258aee9 commit 3cddf41

File tree

5 files changed

+321
-102
lines changed

5 files changed

+321
-102
lines changed

examples/everything/main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ func handleSendNotification(
300300
server := server.ServerFromContext(ctx)
301301

302302
err := server.SendNotificationToClient(
303+
ctx,
303304
"notifications/progress",
304305
map[string]interface{}{
305306
"progress": 10,
@@ -336,6 +337,7 @@ func handleLongRunningOperationTool(
336337
time.Sleep(time.Duration(stepDuration * float64(time.Second)))
337338
if progressToken != nil {
338339
server.SendNotificationToClient(
340+
ctx,
339341
"notifications/progress",
340342
map[string]interface{}{
341343
"progress": i,

server/server.go

Lines changed: 69 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,23 @@ type ServerTool struct {
4646
Handler ToolHandlerFunc
4747
}
4848

49-
// NotificationContext provides client identification for notifications
50-
type NotificationContext struct {
51-
ClientID string
52-
SessionID string
49+
// ClientSession represents an active session that can be used by MCPServer to interact with client.
50+
type ClientSession interface {
51+
// NotificationChannel provides a channel suitable for sending notifications to client.
52+
NotificationChannel() chan<- mcp.JSONRPCNotification
53+
// SessionID is a unique identifier used to track user session.
54+
SessionID() string
5355
}
5456

55-
// ServerNotification combines the notification with client context
56-
type ServerNotification struct {
57-
Context NotificationContext
58-
Notification mcp.JSONRPCNotification
57+
// clientSessionKey is the context key for storing current client notification channel.
58+
type clientSessionKey struct{}
59+
60+
// ClientSessionFromContext retrieves current client notification context from context.
61+
func ClientSessionFromContext(ctx context.Context) ClientSession {
62+
if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok {
63+
return session
64+
}
65+
return nil
5966
}
6067

6168
// NotificationHandlerFunc handles incoming notifications.
@@ -74,9 +81,7 @@ type MCPServer struct {
7481
tools map[string]ServerTool
7582
notificationHandlers map[string]NotificationHandlerFunc
7683
capabilities serverCapabilities
77-
notifications chan ServerNotification
78-
clientMu sync.Mutex // Separate mutex for client context
79-
currentClient NotificationContext
84+
sessions sync.Map
8085
initialized atomic.Bool // Use atomic for the initialized flag
8186
}
8287

@@ -91,30 +96,69 @@ func ServerFromContext(ctx context.Context) *MCPServer {
9196
return nil
9297
}
9398

94-
// WithContext sets the current client context and returns the provided context
99+
// WithContext sets the current client session and returns the provided context
95100
func (s *MCPServer) WithContext(
96101
ctx context.Context,
97-
notifCtx NotificationContext,
102+
session ClientSession,
98103
) context.Context {
99-
s.clientMu.Lock()
100-
s.currentClient = notifCtx
101-
s.clientMu.Unlock()
102-
return ctx
104+
return context.WithValue(ctx, clientSessionKey{}, session)
105+
}
106+
107+
// RegisterSession saves session that should be notified in case if some server attributes changed.
108+
func (s *MCPServer) RegisterSession(
109+
session ClientSession,
110+
) error {
111+
sessionID := session.SessionID()
112+
if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
113+
return fmt.Errorf("session %s is already registered", sessionID)
114+
}
115+
return nil
116+
}
117+
118+
// UnregisterSession removes from storage session that is shut down.
119+
func (s *MCPServer) UnregisterSession(
120+
sessionID string,
121+
) {
122+
s.sessions.Delete(sessionID)
123+
}
124+
125+
// sendNotificationToAllClients sends a notification to all the currently active clients.
126+
func (s *MCPServer) sendNotificationToAllClients(
127+
method string,
128+
params map[string]interface{},
129+
) {
130+
notification := mcp.JSONRPCNotification{
131+
JSONRPC: mcp.JSONRPC_VERSION,
132+
Notification: mcp.Notification{
133+
Method: method,
134+
Params: mcp.NotificationParams{
135+
AdditionalFields: params,
136+
},
137+
},
138+
}
139+
140+
s.sessions.Range(func(k, v any) bool {
141+
if session, ok := v.(ClientSession); ok {
142+
select {
143+
case session.NotificationChannel() <- notification:
144+
default:
145+
}
146+
}
147+
return true
148+
})
103149
}
104150

105151
// SendNotificationToClient sends a notification to the current client
106152
func (s *MCPServer) SendNotificationToClient(
153+
ctx context.Context,
107154
method string,
108155
params map[string]interface{},
109156
) error {
110-
if s.notifications == nil {
157+
session := ClientSessionFromContext(ctx)
158+
if session == nil {
111159
return fmt.Errorf("notification channel not initialized")
112160
}
113161

114-
s.clientMu.Lock()
115-
clientContext := s.currentClient
116-
s.clientMu.Unlock()
117-
118162
notification := mcp.JSONRPCNotification{
119163
JSONRPC: mcp.JSONRPC_VERSION,
120164
Notification: mcp.Notification{
@@ -126,10 +170,7 @@ func (s *MCPServer) SendNotificationToClient(
126170
}
127171

128172
select {
129-
case s.notifications <- ServerNotification{
130-
Context: clientContext,
131-
Notification: notification,
132-
}:
173+
case session.NotificationChannel() <- notification:
133174
return nil
134175
default:
135176
return fmt.Errorf("notification channel full or blocked")
@@ -212,7 +253,6 @@ func NewMCPServer(
212253
name: name,
213254
version: version,
214255
notificationHandlers: make(map[string]NotificationHandlerFunc),
215-
notifications: make(chan ServerNotification, 100),
216256
capabilities: serverCapabilities{
217257
tools: nil,
218258
resources: nil,
@@ -483,9 +523,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) {
483523

484524
// Send notification if server is already initialized
485525
if initialized {
486-
if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
487-
// We can't return the error, but in a future version we could log it
488-
}
526+
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
489527
}
490528
}
491529

@@ -508,9 +546,7 @@ func (s *MCPServer) DeleteTools(names ...string) {
508546

509547
// Send notification if server is already initialized
510548
if initialized {
511-
if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
512-
// We can't return the error, but in a future version we could log it
513-
}
549+
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
514550
}
515551
}
516552

0 commit comments

Comments
 (0)