Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func (s *MCPServer) AddNotificationHandler(
func (s *MCPServer) handleInitialize(
ctx context.Context,
_ any,
_ mcp.InitializeRequest,
request mcp.InitializeRequest,
) (*mcp.InitializeResult, *requestError) {
capabilities := mcp.ServerCapabilities{}

Expand Down Expand Up @@ -541,6 +541,11 @@ func (s *MCPServer) handleInitialize(

if session := ClientSessionFromContext(ctx); session != nil {
session.Initialize()

// Store client info if the session supports it
if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok {
sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo)
}
}
return &result, nil
}
Expand Down
9 changes: 9 additions & 0 deletions server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ type SessionWithTools interface {
SetSessionTools(tools map[string]ServerTool)
}

// SessionWithClientInfo is an extension of ClientSession that can store client info
type SessionWithClientInfo interface {
ClientSession
// GetClientInfo returns the client information for this session
GetClientInfo() mcp.Implementation
// SetClientInfo sets the client information for this session
SetClientInfo(clientInfo mcp.Implementation)
}

// clientSessionKey is the context key for storing current client notification channel.
type clientSessionKey struct{}

Expand Down
96 changes: 90 additions & 6 deletions server/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import (
"testing"
"time"

"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/mark3labs/mcp-go/mcp"
)

// sessionTestClient implements the basic ClientSession interface for testing
Expand Down Expand Up @@ -99,12 +100,49 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool
f.sessionTools = toolsCopy
}

// sessionTestClientWithClientInfo implements the SessionWithClientInfo interface for testing
type sessionTestClientWithClientInfo struct {
sessionID string
notificationChannel chan mcp.JSONRPCNotification
initialized bool
clientInfo atomic.Value
}

func (f *sessionTestClientWithClientInfo) SessionID() string {
return f.sessionID
}

func (f *sessionTestClientWithClientInfo) NotificationChannel() chan<- mcp.JSONRPCNotification {
return f.notificationChannel
}

func (f *sessionTestClientWithClientInfo) Initialize() {
f.initialized = true
}

func (f *sessionTestClientWithClientInfo) Initialized() bool {
return f.initialized
}

func (f *sessionTestClientWithClientInfo) GetClientInfo() mcp.Implementation {
if value := f.clientInfo.Load(); value != nil {
if clientInfo, ok := value.(mcp.Implementation); ok {
return clientInfo
}
}
return mcp.Implementation{}
}

func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implementation) {
f.clientInfo.Store(clientInfo)
}

// sessionTestClientWithTools implements the SessionWithLogging interface for testing
type sessionTestClientWithLogging struct {
sessionID string
notificationChannel chan mcp.JSONRPCNotification
initialized bool
loggingLevel atomic.Value
loggingLevel atomic.Value
}

func (f *sessionTestClientWithLogging) SessionID() string {
Expand Down Expand Up @@ -136,9 +174,10 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {

// Verify that all implementations satisfy their respective interfaces
var (
_ ClientSession = (*sessionTestClient)(nil)
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
_ ClientSession = (*sessionTestClient)(nil)
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
_ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil)
)

func TestSessionWithTools_Integration(t *testing.T) {
Expand Down Expand Up @@ -1041,4 +1080,49 @@ func TestMCPServer_SetLevel(t *testing.T) {
if session.GetLogLevel() != mcp.LoggingLevelCritical {
t.Errorf("Expected critical level, got %v", session.GetLogLevel())
}
}
}

func TestSessionWithClientInfo_Integration(t *testing.T) {
server := NewMCPServer("test-server", "1.0.0")

session := &sessionTestClientWithClientInfo{
sessionID: "session-1",
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
initialized: false,
}

err := server.RegisterSession(context.Background(), session)
require.NoError(t, err)

clientInfo := mcp.Implementation{
Name: "test-client",
Version: "1.0.0",
}

initRequest := mcp.InitializeRequest{}
initRequest.Params.ClientInfo = clientInfo
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.Capabilities = mcp.ClientCapabilities{}

sessionCtx := server.WithContext(context.Background(), session)

// Retrieve the session from context
retrievedSession := ClientSessionFromContext(sessionCtx)
require.NotNil(t, retrievedSession, "Session should be available from context")
assert.Equal(t, session.SessionID(), retrievedSession.SessionID(), "Session ID should match")

result, reqErr := server.handleInitialize(sessionCtx, 1, initRequest)
require.Nil(t, reqErr)
require.NotNil(t, result)

// Check if the session can be cast to SessionWithClientInfo
sessionWithClientInfo, ok := retrievedSession.(SessionWithClientInfo)
require.True(t, ok, "Session should implement SessionWithClientInfo")

assert.True(t, sessionWithClientInfo.Initialized(), "Session should be initialized")

storedClientInfo := sessionWithClientInfo.GetClientInfo()

assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match")
assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match")
}
24 changes: 20 additions & 4 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/google/uuid"

"github.com/mark3labs/mcp-go/mcp"
)

Expand All @@ -27,7 +28,8 @@ type sseSession struct {
notificationChannel chan mcp.JSONRPCNotification
initialized atomic.Bool
loggingLevel atomic.Value
tools sync.Map // stores session-specific tools
tools sync.Map // stores session-specific tools
clientInfo atomic.Value // stores session-specific client info
}

// SSEContextFunc is a function that takes an existing context and the current
Expand Down Expand Up @@ -86,10 +88,24 @@ func (s *sseSession) SetSessionTools(tools map[string]ServerTool) {
}
}

func (s *sseSession) GetClientInfo() mcp.Implementation {
if value := s.clientInfo.Load(); value != nil {
if clientInfo, ok := value.(mcp.Implementation); ok {
return clientInfo
}
}
return mcp.Implementation{}
}

func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) {
s.clientInfo.Store(clientInfo)
}

var (
_ ClientSession = (*sseSession)(nil)
_ SessionWithTools = (*sseSession)(nil)
_ SessionWithLogging = (*sseSession)(nil)
_ ClientSession = (*sseSession)(nil)
_ SessionWithTools = (*sseSession)(nil)
_ SessionWithLogging = (*sseSession)(nil)
_ SessionWithClientInfo = (*sseSession)(nil)
)

// SSEServer implements a Server-Sent Events (SSE) based MCP server.
Expand Down
29 changes: 22 additions & 7 deletions server/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption {

// stdioSession is a static client session, since stdio has only one client.
type stdioSession struct {
notifications chan mcp.JSONRPCNotification
initialized atomic.Bool
loggingLevel atomic.Value
notifications chan mcp.JSONRPCNotification
initialized atomic.Bool
loggingLevel atomic.Value
clientInfo atomic.Value // stores session-specific client info
}

func (s *stdioSession) SessionID() string {
Expand All @@ -74,11 +75,24 @@ func (s *stdioSession) Initialized() bool {
return s.initialized.Load()
}

func(s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
func (s *stdioSession) GetClientInfo() mcp.Implementation {
if value := s.clientInfo.Load(); value != nil {
if clientInfo, ok := value.(mcp.Implementation); ok {
return clientInfo
}
}
return mcp.Implementation{}
}

func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
s.clientInfo.Store(clientInfo)
}

func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
s.loggingLevel.Store(level)
}

func(s *stdioSession) GetLogLevel() mcp.LoggingLevel {
func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
level := s.loggingLevel.Load()
if level == nil {
return mcp.LoggingLevelError
Expand All @@ -87,8 +101,9 @@ func(s *stdioSession) GetLogLevel() mcp.LoggingLevel {
}

var (
_ ClientSession = (*stdioSession)(nil)
_ SessionWithLogging = (*stdioSession)(nil)
_ ClientSession = (*stdioSession)(nil)
_ SessionWithLogging = (*stdioSession)(nil)
_ SessionWithClientInfo = (*stdioSession)(nil)
)

var stdioSessionInstance = stdioSession{
Expand Down