diff --git a/server/server.go b/server/server.go index 6005738b3..f80f837aa 100644 --- a/server/server.go +++ b/server/server.go @@ -492,7 +492,7 @@ func (s *MCPServer) AddNotificationHandler( func (s *MCPServer) handleInitialize( ctx context.Context, _ any, - _ mcp.InitializeRequest, + request mcp.InitializeRequest, ) (*mcp.InitializeResult, *requestError) { capabilities := mcp.ServerCapabilities{} @@ -541,6 +541,11 @@ func (s *MCPServer) handleInitialize( if session := ClientSessionFromContext(ctx); session != nil { session.Initialize() + + // Store client info if the session supports it + if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok { + sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo) + } } return &result, nil } diff --git a/server/session.go b/server/session.go index 0c50a260e..7b1e12fee 100644 --- a/server/session.go +++ b/server/session.go @@ -39,6 +39,15 @@ type SessionWithTools interface { SetSessionTools(tools map[string]ServerTool) } +// SessionWithClientInfo is an extension of ClientSession that can store client info +type SessionWithClientInfo interface { + ClientSession + // GetClientInfo returns the client information for this session + GetClientInfo() mcp.Implementation + // SetClientInfo sets the client information for this session + SetClientInfo(clientInfo mcp.Implementation) +} + // clientSessionKey is the context key for storing current client notification channel. type clientSessionKey struct{} diff --git a/server/session_test.go b/server/session_test.go index 8f2cfa763..3067f4e9c 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -9,9 +9,10 @@ import ( "testing" "time" - "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/mark3labs/mcp-go/mcp" ) // sessionTestClient implements the basic ClientSession interface for testing @@ -99,12 +100,49 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool f.sessionTools = toolsCopy } +// sessionTestClientWithClientInfo implements the SessionWithClientInfo interface for testing +type sessionTestClientWithClientInfo struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized bool + clientInfo atomic.Value +} + +func (f *sessionTestClientWithClientInfo) SessionID() string { + return f.sessionID +} + +func (f *sessionTestClientWithClientInfo) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +func (f *sessionTestClientWithClientInfo) Initialize() { + f.initialized = true +} + +func (f *sessionTestClientWithClientInfo) Initialized() bool { + return f.initialized +} + +func (f *sessionTestClientWithClientInfo) GetClientInfo() mcp.Implementation { + if value := f.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implementation) { + f.clientInfo.Store(clientInfo) +} + // sessionTestClientWithTools implements the SessionWithLogging interface for testing type sessionTestClientWithLogging struct { sessionID string notificationChannel chan mcp.JSONRPCNotification initialized bool - loggingLevel atomic.Value + loggingLevel atomic.Value } func (f *sessionTestClientWithLogging) SessionID() string { @@ -136,9 +174,10 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel { // Verify that all implementations satisfy their respective interfaces var ( - _ ClientSession = (*sessionTestClient)(nil) - _ SessionWithTools = (*sessionTestClientWithTools)(nil) - _ SessionWithLogging = (*sessionTestClientWithLogging)(nil) + _ ClientSession = (*sessionTestClient)(nil) + _ SessionWithTools = (*sessionTestClientWithTools)(nil) + _ SessionWithLogging = (*sessionTestClientWithLogging)(nil) + _ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil) ) func TestSessionWithTools_Integration(t *testing.T) { @@ -1041,4 +1080,49 @@ func TestMCPServer_SetLevel(t *testing.T) { if session.GetLogLevel() != mcp.LoggingLevelCritical { t.Errorf("Expected critical level, got %v", session.GetLogLevel()) } -} \ No newline at end of file +} + +func TestSessionWithClientInfo_Integration(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + session := &sessionTestClientWithClientInfo{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: false, + } + + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + clientInfo := mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + initRequest := mcp.InitializeRequest{} + initRequest.Params.ClientInfo = clientInfo + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.Capabilities = mcp.ClientCapabilities{} + + sessionCtx := server.WithContext(context.Background(), session) + + // Retrieve the session from context + retrievedSession := ClientSessionFromContext(sessionCtx) + require.NotNil(t, retrievedSession, "Session should be available from context") + assert.Equal(t, session.SessionID(), retrievedSession.SessionID(), "Session ID should match") + + result, reqErr := server.handleInitialize(sessionCtx, 1, initRequest) + require.Nil(t, reqErr) + require.NotNil(t, result) + + // Check if the session can be cast to SessionWithClientInfo + sessionWithClientInfo, ok := retrievedSession.(SessionWithClientInfo) + require.True(t, ok, "Session should implement SessionWithClientInfo") + + assert.True(t, sessionWithClientInfo.Initialized(), "Session should be initialized") + + storedClientInfo := sessionWithClientInfo.GetClientInfo() + + assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match") + assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match") +} diff --git a/server/sse.go b/server/sse.go index c7aa72986..04aa73ed6 100644 --- a/server/sse.go +++ b/server/sse.go @@ -15,6 +15,7 @@ import ( "time" "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" ) @@ -27,7 +28,8 @@ type sseSession struct { notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool loggingLevel atomic.Value - tools sync.Map // stores session-specific tools + tools sync.Map // stores session-specific tools + clientInfo atomic.Value // stores session-specific client info } // SSEContextFunc is a function that takes an existing context and the current @@ -86,10 +88,24 @@ func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { } } +func (s *sseSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + var ( - _ ClientSession = (*sseSession)(nil) - _ SessionWithTools = (*sseSession)(nil) - _ SessionWithLogging = (*sseSession)(nil) + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) + _ SessionWithLogging = (*sseSession)(nil) + _ SessionWithClientInfo = (*sseSession)(nil) ) // SSEServer implements a Server-Sent Events (SSE) based MCP server. diff --git a/server/stdio.go b/server/stdio.go index 0ebed6a9a..34556cd7b 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -51,9 +51,10 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info } func (s *stdioSession) SessionID() string { @@ -74,11 +75,24 @@ func (s *stdioSession) Initialized() bool { return s.initialized.Load() } -func(s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { +func (s *stdioSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { s.loggingLevel.Store(level) } -func(s *stdioSession) GetLogLevel() mcp.LoggingLevel { +func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { level := s.loggingLevel.Load() if level == nil { return mcp.LoggingLevelError @@ -87,8 +101,9 @@ func(s *stdioSession) GetLogLevel() mcp.LoggingLevel { } var ( - _ ClientSession = (*stdioSession)(nil) - _ SessionWithLogging = (*stdioSession)(nil) + _ ClientSession = (*stdioSession)(nil) + _ SessionWithLogging = (*stdioSession)(nil) + _ SessionWithClientInfo = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{