From 074bba41d2683aa714c9257a9b006d1df3d949b4 Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Tue, 20 May 2025 10:48:29 +0200 Subject: [PATCH 1/5] feat(server): persist client info in sessions Add SessionWithClientInfo interface and implementations to store and retrieve client information provided during initialization. This allows servers to access client implementation details throughout the session lifecycle. --- server/server.go | 5 +++ server/session.go | 9 +++++ server/session_test.go | 88 +++++++++++++++++++++++++++++++++++++++++- server/sse.go | 20 +++++++++- server/stdio.go | 16 ++++++++ 5 files changed, 134 insertions(+), 4 deletions(-) diff --git a/server/server.go b/server/server.go index b31b48659..da708d39a 100644 --- a/server/server.go +++ b/server/server.go @@ -537,6 +537,11 @@ func (s *MCPServer) handleInitialize( if session := ClientSessionFromContext(ctx); session != nil { session.Initialize() + + // Store client info if the session supports it + if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok { + sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo) + } } return &result, nil } diff --git a/server/session.go b/server/session.go index 3a4206a76..95ee0771c 100644 --- a/server/session.go +++ b/server/session.go @@ -30,6 +30,15 @@ type SessionWithTools interface { SetSessionTools(tools map[string]ServerTool) } +// SessionWithClientInfo is an extension of ClientSession that can store client implementation info +type SessionWithClientInfo interface { + ClientSession + // GetClientInfo returns the client information for this session + GetClientInfo() mcp.Implementation + // SetClientInfo sets the client information for this session + SetClientInfo(clientInfo mcp.Implementation) +} + // clientSessionKey is the context key for storing current client notification channel. type clientSessionKey struct{} diff --git a/server/session_test.go b/server/session_test.go index 54a781709..1d03dffe1 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -8,9 +8,10 @@ import ( "testing" "time" - "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/mark3labs/mcp-go/mcp" ) // sessionTestClient implements the basic ClientSession interface for testing @@ -98,9 +99,47 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool f.sessionTools = toolsCopy } -// Verify that both implementations satisfy their respective interfaces +// sessionTestClientWithClientInfo implements the SessionWithClientInfo interface for testing +type sessionTestClientWithClientInfo struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized bool + clientInfoMu sync.RWMutex + clientInfo mcp.Implementation +} + +func (f *sessionTestClientWithClientInfo) SessionID() string { + return f.sessionID +} + +func (f *sessionTestClientWithClientInfo) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +func (f *sessionTestClientWithClientInfo) Initialize() { + f.initialized = true +} + +func (f *sessionTestClientWithClientInfo) Initialized() bool { + return f.initialized +} + +func (f *sessionTestClientWithClientInfo) GetClientInfo() mcp.Implementation { + f.clientInfoMu.RLock() + defer f.clientInfoMu.RUnlock() + return f.clientInfo +} + +func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implementation) { + f.clientInfoMu.Lock() + defer f.clientInfoMu.Unlock() + f.clientInfo = clientInfo +} + +// Verify that all implementations satisfy their respective interfaces var _ ClientSession = &sessionTestClient{} var _ SessionWithTools = &sessionTestClientWithTools{} +var _ SessionWithClientInfo = &sessionTestClientWithClientInfo{} func TestSessionWithTools_Integration(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) @@ -917,3 +956,48 @@ func TestMCPServer_ToolNotificationsDisabled(t *testing.T) { // Verify tool was deleted from session assert.Len(t, session.GetSessionTools(), 0) } + +func TestSessionWithClientInfo_Integration(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + session := &sessionTestClientWithClientInfo{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: false, + } + + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + clientInfo := mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + initRequest := mcp.InitializeRequest{} + initRequest.Params.ClientInfo = clientInfo + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.Capabilities = mcp.ClientCapabilities{} + + sessionCtx := server.WithContext(context.Background(), session) + + // Retrieve the session from context + retrievedSession := ClientSessionFromContext(sessionCtx) + require.NotNil(t, retrievedSession, "Session should be available from context") + assert.Equal(t, session.SessionID(), retrievedSession.SessionID(), "Session ID should match") + + // Check if the session can be cast to SessionWithClientInfo + sessionWithClientInfo, ok := retrievedSession.(SessionWithClientInfo) + require.True(t, ok, "Session should implement SessionWithClientInfo") + + result, reqErr := server.handleInitialize(sessionCtx, 1, initRequest) + require.Nil(t, reqErr) + require.NotNil(t, result) + + assert.True(t, sessionWithClientInfo.Initialized(), "Session should be initialized") + + storedClientInfo := sessionWithClientInfo.GetClientInfo() + + assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match") + assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match") +} diff --git a/server/sse.go b/server/sse.go index 630927d15..50d32a72d 100644 --- a/server/sse.go +++ b/server/sse.go @@ -15,6 +15,7 @@ import ( "time" "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" ) @@ -29,6 +30,8 @@ type sseSession struct { notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool tools sync.Map // stores session-specific tools + mu sync.RWMutex + clientInfo mcp.Implementation } // SSEContextFunc is a function that takes an existing context and the current @@ -76,9 +79,22 @@ func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { } } +func (s *sseSession) GetClientInfo() mcp.Implementation { + s.mu.RLock() + defer s.mu.RUnlock() + return s.clientInfo +} + +func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { + s.mu.Lock() + defer s.mu.Unlock() + s.clientInfo = clientInfo +} + var ( - _ ClientSession = (*sseSession)(nil) - _ SessionWithTools = (*sseSession)(nil) + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) + _ SessionWithClientInfo = (*sseSession)(nil) ) // SSEServer implements a Server-Sent Events (SSE) based MCP server. diff --git a/server/stdio.go b/server/stdio.go index c4fe1bf6d..0de062e34 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -9,6 +9,7 @@ import ( "log" "os" "os/signal" + "sync" "sync/atomic" "syscall" @@ -53,6 +54,8 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { type stdioSession struct { notifications chan mcp.JSONRPCNotification initialized atomic.Bool + mu sync.RWMutex + clientInfo mcp.Implementation } func (s *stdioSession) SessionID() string { @@ -71,7 +74,20 @@ func (s *stdioSession) Initialized() bool { return s.initialized.Load() } +func (s *stdioSession) GetClientInfo() mcp.Implementation { + s.mu.RLock() + defer s.mu.RUnlock() + return s.clientInfo +} + +func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { + s.mu.Lock() + defer s.mu.Unlock() + s.clientInfo = clientInfo +} + var _ ClientSession = (*stdioSession)(nil) +var _ SessionWithClientInfo = (*stdioSession)(nil) var stdioSessionInstance = stdioSession{ notifications: make(chan mcp.JSONRPCNotification, 100), From 07fc3bfa0d67e2408423e9b6f201e162267c1523 Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Tue, 20 May 2025 10:57:42 +0200 Subject: [PATCH 2/5] refactor: use atomic.Value instead of mutex --- server/session_test.go | 29 +++++++++++++---------------- server/sse.go | 24 ++++++++++++------------ server/stdio.go | 34 +++++++++++++++++----------------- 3 files changed, 42 insertions(+), 45 deletions(-) diff --git a/server/session_test.go b/server/session_test.go index b7cb53c7e..aff90f7de 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -105,8 +105,7 @@ type sessionTestClientWithClientInfo struct { sessionID string notificationChannel chan mcp.JSONRPCNotification initialized bool - clientInfoMu sync.RWMutex - clientInfo mcp.Implementation + clientInfo atomic.Value } func (f *sessionTestClientWithClientInfo) SessionID() string { @@ -126,27 +125,24 @@ func (f *sessionTestClientWithClientInfo) Initialized() bool { } func (f *sessionTestClientWithClientInfo) GetClientInfo() mcp.Implementation { - f.clientInfoMu.RLock() - defer f.clientInfoMu.RUnlock() - return f.clientInfo + if value := f.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} } func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implementation) { - f.clientInfoMu.Lock() - defer f.clientInfoMu.Unlock() - f.clientInfo = clientInfo + f.clientInfo.Store(clientInfo) } -// Verify that all implementations satisfy their respective interfaces -var _ ClientSession = &sessionTestClient{} -var _ SessionWithTools = &sessionTestClientWithTools{} -var _ SessionWithClientInfo = &sessionTestClientWithClientInfo{} // sessionTestClientWithTools implements the SessionWithLogging interface for testing type sessionTestClientWithLogging struct { sessionID string notificationChannel chan mcp.JSONRPCNotification initialized bool - loggingLevel atomic.Value + loggingLevel atomic.Value } func (f *sessionTestClientWithLogging) SessionID() string { @@ -178,9 +174,10 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel { // Verify that all implementations satisfy their respective interfaces var ( - _ ClientSession = (*sessionTestClient)(nil) - _ SessionWithTools = (*sessionTestClientWithTools)(nil) - _ SessionWithLogging = (*sessionTestClientWithLogging)(nil) + _ ClientSession = (*sessionTestClient)(nil) + _ SessionWithTools = (*sessionTestClientWithTools)(nil) + _ SessionWithLogging = (*sessionTestClientWithLogging)(nil) + _ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil) ) func TestSessionWithTools_Integration(t *testing.T) { diff --git a/server/sse.go b/server/sse.go index 5b0723a9c..04aa73ed6 100644 --- a/server/sse.go +++ b/server/sse.go @@ -28,9 +28,8 @@ type sseSession struct { notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool loggingLevel atomic.Value - tools sync.Map // stores session-specific tools - mu sync.RWMutex - clientInfo mcp.Implementation + tools sync.Map // stores session-specific tools + clientInfo atomic.Value // stores session-specific client info } // SSEContextFunc is a function that takes an existing context and the current @@ -90,21 +89,22 @@ func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { } func (s *sseSession) GetClientInfo() mcp.Implementation { - s.mu.RLock() - defer s.mu.RUnlock() - return s.clientInfo + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} } func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { - s.mu.Lock() - defer s.mu.Unlock() - s.clientInfo = clientInfo + s.clientInfo.Store(clientInfo) } var ( - _ ClientSession = (*sseSession)(nil) - _ SessionWithTools = (*sseSession)(nil) - _ SessionWithLogging = (*sseSession)(nil) + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) + _ SessionWithLogging = (*sseSession)(nil) _ SessionWithClientInfo = (*sseSession)(nil) ) diff --git a/server/stdio.go b/server/stdio.go index 63a8e6bba..e6bf9f2ee 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -9,7 +9,6 @@ import ( "log" "os" "os/signal" - "sync" "sync/atomic" "syscall" @@ -52,11 +51,10 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - mu sync.RWMutex - clientInfo mcp.Implementation + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info } func (s *stdioSession) SessionID() string { @@ -78,22 +76,23 @@ func (s *stdioSession) Initialized() bool { } func (s *stdioSession) GetClientInfo() mcp.Implementation { - s.mu.RLock() - defer s.mu.RUnlock() - return s.clientInfo + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} } func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { - s.mu.Lock() - defer s.mu.Unlock() - s.clientInfo = clientInfo + s.clientInfo.Store(clientInfo) } -func(s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { +func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { s.loggingLevel.Store(level) } -func(s *stdioSession) GetLogLevel() mcp.LoggingLevel { +func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { level := s.loggingLevel.Load() if level == nil { return mcp.LoggingLevelError @@ -102,13 +101,14 @@ func(s *stdioSession) GetLogLevel() mcp.LoggingLevel { } var ( - _ ClientSession = (*stdioSession)(nil) - _ SessionWithLogging = (*stdioSession)(nil) - _ SessionWithClientInfo = (*stdioSession)(nil) + _ ClientSession = (*stdioSession)(nil) + _ SessionWithLogging = (*stdioSession)(nil) + _ SessionWithClientInfo = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{ notifications: make(chan mcp.JSONRPCNotification, 100), + clientInfo: atomic.Value{}, } // NewStdioServer creates a new stdio server wrapper around an MCPServer. From 14269f19695e94c2362cae1e190f255ff371fdaf Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Tue, 20 May 2025 11:03:39 +0200 Subject: [PATCH 3/5] chore: cleanup --- server/session.go | 2 +- server/stdio.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/server/session.go b/server/session.go index aa3091e62..7b1e12fee 100644 --- a/server/session.go +++ b/server/session.go @@ -39,7 +39,7 @@ type SessionWithTools interface { SetSessionTools(tools map[string]ServerTool) } -// SessionWithClientInfo is an extension of ClientSession that can store client implementation info +// SessionWithClientInfo is an extension of ClientSession that can store client info type SessionWithClientInfo interface { ClientSession // GetClientInfo returns the client information for this session diff --git a/server/stdio.go b/server/stdio.go index e6bf9f2ee..34556cd7b 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -108,7 +108,6 @@ var ( var stdioSessionInstance = stdioSession{ notifications: make(chan mcp.JSONRPCNotification, 100), - clientInfo: atomic.Value{}, } // NewStdioServer creates a new stdio server wrapper around an MCPServer. From 83de439b0715f32d6b104b971af2032cfb64855f Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Tue, 20 May 2025 12:02:51 +0200 Subject: [PATCH 4/5] fix: restore named parameter in handleInitialize method --- server/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index e0e9dc6c6..f80f837aa 100644 --- a/server/server.go +++ b/server/server.go @@ -492,7 +492,7 @@ func (s *MCPServer) AddNotificationHandler( func (s *MCPServer) handleInitialize( ctx context.Context, _ any, - _ mcp.InitializeRequest, + request mcp.InitializeRequest, ) (*mcp.InitializeResult, *requestError) { capabilities := mcp.ServerCapabilities{} From 03cb5bf3d78d74e9cd8e0137dc8a8c60f55f26e7 Mon Sep 17 00:00:00 2001 From: Abdelrahman Shawki Hassan Date: Tue, 20 May 2025 12:05:25 +0200 Subject: [PATCH 5/5] chore: test order --- server/session_test.go | 89 +++++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 44 deletions(-) diff --git a/server/session_test.go b/server/session_test.go index aff90f7de..3067f4e9c 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -996,50 +996,6 @@ func TestMCPServer_ToolNotificationsDisabled(t *testing.T) { assert.Len(t, session.GetSessionTools(), 0) } -func TestSessionWithClientInfo_Integration(t *testing.T) { - server := NewMCPServer("test-server", "1.0.0") - - session := &sessionTestClientWithClientInfo{ - sessionID: "session-1", - notificationChannel: make(chan mcp.JSONRPCNotification, 10), - initialized: false, - } - - err := server.RegisterSession(context.Background(), session) - require.NoError(t, err) - - clientInfo := mcp.Implementation{ - Name: "test-client", - Version: "1.0.0", - } - - initRequest := mcp.InitializeRequest{} - initRequest.Params.ClientInfo = clientInfo - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.Capabilities = mcp.ClientCapabilities{} - - sessionCtx := server.WithContext(context.Background(), session) - - // Retrieve the session from context - retrievedSession := ClientSessionFromContext(sessionCtx) - require.NotNil(t, retrievedSession, "Session should be available from context") - assert.Equal(t, session.SessionID(), retrievedSession.SessionID(), "Session ID should match") - - // Check if the session can be cast to SessionWithClientInfo - sessionWithClientInfo, ok := retrievedSession.(SessionWithClientInfo) - require.True(t, ok, "Session should implement SessionWithClientInfo") - - result, reqErr := server.handleInitialize(sessionCtx, 1, initRequest) - require.Nil(t, reqErr) - require.NotNil(t, result) - - assert.True(t, sessionWithClientInfo.Initialized(), "Session should be initialized") - - storedClientInfo := sessionWithClientInfo.GetClientInfo() - - assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match") - assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match") -} func TestMCPServer_SetLevelNotEnabled(t *testing.T) { // Create server without logging capability server := NewMCPServer("test-server", "1.0.0") @@ -1125,3 +1081,48 @@ func TestMCPServer_SetLevel(t *testing.T) { t.Errorf("Expected critical level, got %v", session.GetLogLevel()) } } + +func TestSessionWithClientInfo_Integration(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + session := &sessionTestClientWithClientInfo{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: false, + } + + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + clientInfo := mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + initRequest := mcp.InitializeRequest{} + initRequest.Params.ClientInfo = clientInfo + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.Capabilities = mcp.ClientCapabilities{} + + sessionCtx := server.WithContext(context.Background(), session) + + // Retrieve the session from context + retrievedSession := ClientSessionFromContext(sessionCtx) + require.NotNil(t, retrievedSession, "Session should be available from context") + assert.Equal(t, session.SessionID(), retrievedSession.SessionID(), "Session ID should match") + + result, reqErr := server.handleInitialize(sessionCtx, 1, initRequest) + require.Nil(t, reqErr) + require.NotNil(t, result) + + // Check if the session can be cast to SessionWithClientInfo + sessionWithClientInfo, ok := retrievedSession.(SessionWithClientInfo) + require.True(t, ok, "Session should implement SessionWithClientInfo") + + assert.True(t, sessionWithClientInfo.Initialized(), "Session should be initialized") + + storedClientInfo := sessionWithClientInfo.GetClientInfo() + + assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match") + assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match") +}