Skip to content

Commit f96dacc

Browse files
committed
refactor(server): replace sync.Map with Sessionizer interface for session management
1 parent 051cda5 commit f96dacc

File tree

5 files changed

+59
-15
lines changed

5 files changed

+59
-15
lines changed

server/server.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ type MCPServer struct {
147147
tools map[string]ServerTool
148148
notificationHandlers map[string]NotificationHandlerFunc
149149
capabilities serverCapabilities
150-
sessions sync.Map
150+
sessionizer Sessionizer
151151
hooks *Hooks
152152
}
153153

@@ -175,7 +175,7 @@ func (s *MCPServer) RegisterSession(
175175
session ClientSession,
176176
) error {
177177
sessionID := session.SessionID()
178-
if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
178+
if _, exists := s.sessionizer.LoadOrStore(sessionID, session); exists {
179179
return fmt.Errorf("session %s is already registered", sessionID)
180180
}
181181
return nil
@@ -185,7 +185,7 @@ func (s *MCPServer) RegisterSession(
185185
func (s *MCPServer) UnregisterSession(
186186
sessionID string,
187187
) {
188-
s.sessions.Delete(sessionID)
188+
s.sessionizer.Delete(sessionID)
189189
}
190190

191191
// sendNotificationToAllClients sends a notification to all the currently active clients.
@@ -203,16 +203,15 @@ func (s *MCPServer) sendNotificationToAllClients(
203203
},
204204
}
205205

206-
s.sessions.Range(func(k, v any) bool {
207-
if session, ok := v.(ClientSession); ok && session.Initialized() {
206+
for _, session := range s.sessionizer.All() {
207+
if session.Initialized() {
208208
select {
209209
case session.NotificationChannel() <- notification:
210210
default:
211211
// TODO: log blocked channel in the future versions
212212
}
213213
}
214-
return true
215-
})
214+
}
216215
}
217216

218217
// SendNotificationToClient sends a notification to the current client
@@ -322,6 +321,12 @@ func WithInstructions(instructions string) ServerOption {
322321
}
323322
}
324323

324+
func WithSessionizer(sessionizer Sessionizer) ServerOption {
325+
return func(s *MCPServer) {
326+
s.sessionizer = sessionizer
327+
}
328+
}
329+
325330
// NewMCPServer creates a new MCP server instance with the given name, version and options
326331
func NewMCPServer(
327332
name, version string,
@@ -342,6 +347,7 @@ func NewMCPServer(
342347
prompts: nil,
343348
logging: false,
344349
},
350+
sessionizer: &SyncMapSessionizer{},
345351
}
346352

347353
for _, opt := range opts {
@@ -410,7 +416,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) {
410416
}
411417
s.mu.Unlock()
412418

413-
// Send notification to all initialized sessions
419+
// Send notification to all initialized sessionizer
414420
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
415421
}
416422

@@ -430,7 +436,7 @@ func (s *MCPServer) DeleteTools(names ...string) {
430436
}
431437
s.mu.Unlock()
432438

433-
// Send notification to all initialized sessions
439+
// Send notification to all initialized sessionizer
434440
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
435441
}
436442

server/server_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func TestMCPServer_Tools(t *testing.T) {
153153
validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage)
154154
}{
155155
{
156-
name: "SetTools sends no notifications/tools/list_changed without active sessions",
156+
name: "SetTools sends no notifications/tools/list_changed without active sessionizer",
157157
action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) {
158158
server.SetTools(ServerTool{
159159
Tool: mcp.NewTool("test-tool-1"),
@@ -216,7 +216,7 @@ func TestMCPServer_Tools(t *testing.T) {
216216
})
217217
require.NoError(t, err)
218218
}
219-
// also let's register inactive sessions
219+
// also let's register inactive sessionizer
220220
for i := range 5 {
221221
err := server.RegisterSession(&fakeSession{
222222
sessionID: fmt.Sprintf("test%d", i+5),

server/sessionizer.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package server
2+
3+
import "sync"
4+
5+
type Sessionizer interface {
6+
LoadOrStore(sessionID string, session ClientSession) (ClientSession, bool)
7+
8+
Delete(sessionID string)
9+
10+
All() []ClientSession
11+
}
12+
13+
type SyncMapSessionizer struct {
14+
sessions sync.Map
15+
}
16+
17+
var _ Sessionizer = (*SyncMapSessionizer)(nil)
18+
19+
func (s *SyncMapSessionizer) LoadOrStore(sessionID string, session ClientSession) (ClientSession, bool) {
20+
actual, ok := s.sessions.LoadOrStore(sessionID, session)
21+
if ok {
22+
return actual.(ClientSession), true
23+
}
24+
return session, false
25+
}
26+
27+
func (s *SyncMapSessionizer) Delete(sessionID string) {
28+
s.sessions.Delete(sessionID)
29+
}
30+
31+
func (s *SyncMapSessionizer) All() []ClientSession {
32+
var sessions []ClientSession
33+
s.sessions.Range(func(key, value any) bool {
34+
sessions = append(sessions, value.(ClientSession))
35+
return true
36+
})
37+
return sessions
38+
}

server/sse.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func (s *SSEServer) Start(addr string) error {
167167
return s.srv.ListenAndServe()
168168
}
169169

170-
// Shutdown gracefully stops the SSE server, closing all active sessions
170+
// Shutdown gracefully stops the SSE server, closing all active sessionizer
171171
// and shutting down the HTTP server.
172172
func (s *SSEServer) Shutdown(ctx context.Context) error {
173173
if s.srv != nil {

server/sse_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func TestSSEServer(t *testing.T) {
122122
}
123123
})
124124

125-
t.Run("Can handle multiple sessions", func(t *testing.T) {
125+
t.Run("Can handle multiple sessionizer", func(t *testing.T) {
126126
mcpServer := NewMCPServer("test", "1.0.0",
127127
WithResourceCapabilities(true, true),
128128
)
@@ -238,9 +238,9 @@ func TestSSEServer(t *testing.T) {
238238

239239
select {
240240
case <-done:
241-
// All sessions completed successfully
241+
// All sessionizer completed successfully
242242
case <-time.After(5 * time.Second):
243-
t.Fatal("Timeout waiting for sessions to complete")
243+
t.Fatal("Timeout waiting for sessionizer to complete")
244244
}
245245
})
246246

0 commit comments

Comments
 (0)