diff --git a/server/errors.go b/server/errors.go index 4668e4591..5e65f0760 100644 --- a/server/errors.go +++ b/server/errors.go @@ -13,12 +13,13 @@ var ( ErrToolNotFound = errors.New("tool not found") // Session-related errors - ErrSessionNotFound = errors.New("session not found") - ErrSessionExists = errors.New("session already exists") - ErrSessionNotInitialized = errors.New("session not properly initialized") - ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") - ErrSessionDoesNotSupportResources = errors.New("session does not support per-session resources") - ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level") + ErrSessionNotFound = errors.New("session not found") + ErrSessionExists = errors.New("session already exists") + ErrSessionNotInitialized = errors.New("session not properly initialized") + ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") + ErrSessionDoesNotSupportResources = errors.New("session does not support per-session resources") + ErrSessionDoesNotSupportResourceTemplates = errors.New("session does not support resource templates") + ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level") // Notification-related errors ErrNotificationNotInitialized = errors.New("notification channel not initialized") diff --git a/server/server.go b/server/server.go index f45c03536..480e0e9f6 100644 --- a/server/server.go +++ b/server/server.go @@ -880,12 +880,34 @@ func (s *MCPServer) handleListResourceTemplates( id any, request mcp.ListResourceTemplatesRequest, ) (*mcp.ListResourceTemplatesResult, *requestError) { + // Get global templates s.resourcesMu.RLock() - templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates)) - for _, entry := range s.resourceTemplates { - templates = append(templates, entry.template) + templateMap := make(map[string]mcp.ResourceTemplate, len(s.resourceTemplates)) + for uri, entry := range s.resourceTemplates { + templateMap[uri] = entry.template } s.resourcesMu.RUnlock() + + // Check if there are session-specific resource templates + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTemplates, ok := session.(SessionWithResourceTemplates); ok { + if sessionTemplates := sessionWithTemplates.GetSessionResourceTemplates(); sessionTemplates != nil { + // Merge session-specific templates with global templates + // Session templates override global ones + for uriTemplate, serverTemplate := range sessionTemplates { + templateMap[uriTemplate] = serverTemplate.Template + } + } + } + } + + // Convert map to slice for sorting and pagination + templates := make([]mcp.ResourceTemplate, 0, len(templateMap)) + for _, template := range templateMap { + templates = append(templates, template) + } + sort.Slice(templates, func(i, j int) bool { return templates[i].Name < templates[j].Name }) @@ -971,18 +993,48 @@ func (s *MCPServer) handleReadResource( // If no direct handler found, try matching against templates var matchedHandler ResourceTemplateHandlerFunc var matched bool - for _, entry := range s.resourceTemplates { - template := entry.template - if matchesTemplate(request.Params.URI, template.URITemplate) { - matchedHandler = entry.handler - matched = true - matchedVars := template.URITemplate.Match(request.Params.URI) - // Convert matched variables to a map - request.Params.Arguments = make(map[string]any, len(matchedVars)) - for name, value := range matchedVars { - request.Params.Arguments[name] = value.V + + // First check session templates if available + if session != nil { + if sessionWithTemplates, ok := session.(SessionWithResourceTemplates); ok { + sessionTemplates := sessionWithTemplates.GetSessionResourceTemplates() + for _, serverTemplate := range sessionTemplates { + if serverTemplate.Template.URITemplate == nil { + continue + } + if matchesTemplate(request.Params.URI, serverTemplate.Template.URITemplate) { + matchedHandler = serverTemplate.Handler + matched = true + matchedVars := serverTemplate.Template.URITemplate.Match(request.Params.URI) + // Convert matched variables to a map + request.Params.Arguments = make(map[string]any, len(matchedVars)) + for name, value := range matchedVars { + request.Params.Arguments[name] = value.V + } + break + } + } + } + } + + // If not found in session templates, check global templates + if !matched { + for _, entry := range s.resourceTemplates { + template := entry.template + if template.URITemplate == nil { + continue + } + if matchesTemplate(request.Params.URI, template.URITemplate) { + matchedHandler = entry.handler + matched = true + matchedVars := template.URITemplate.Match(request.Params.URI) + // Convert matched variables to a map + request.Params.Arguments = make(map[string]any, len(matchedVars)) + for name, value := range matchedVars { + request.Params.Arguments[name] = value.V + } + break } - break } } s.resourcesMu.RUnlock() diff --git a/server/session.go b/server/session.go index 99d6db8d4..aa392ba75 100644 --- a/server/session.go +++ b/server/session.go @@ -51,6 +51,17 @@ type SessionWithResources interface { SetSessionResources(resources map[string]ServerResource) } +// SessionWithResourceTemplates is an extension of ClientSession that can store session-specific resource template data +type SessionWithResourceTemplates interface { + ClientSession + // GetSessionResourceTemplates returns the resource templates specific to this session, if any + // This method must be thread-safe for concurrent access + GetSessionResourceTemplates() map[string]ServerResourceTemplate + // SetSessionResourceTemplates sets resource templates specific to this session + // This method must be thread-safe for concurrent access + SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) +} + // SessionWithClientInfo is an extension of ClientSession that can store client info type SessionWithClientInfo interface { ClientSession @@ -613,3 +624,137 @@ func (s *MCPServer) DeleteSessionResources(sessionID string, uris ...string) err return nil } + +// AddSessionResourceTemplate adds a resource template for a specific session +func (s *MCPServer) AddSessionResourceTemplate(sessionID string, template mcp.ResourceTemplate, handler ResourceTemplateHandlerFunc) error { + return s.AddSessionResourceTemplates(sessionID, ServerResourceTemplate{ + Template: template, + Handler: handler, + }) +} + +// AddSessionResourceTemplates adds resource templates for a specific session +func (s *MCPServer) AddSessionResourceTemplates(sessionID string, templates ...ServerResourceTemplate) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithResourceTemplates) + if !ok { + return ErrSessionDoesNotSupportResourceTemplates + } + + // For session resource templates, enable listChanged by default + // This is the same behavior as session resources + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.resources != nil }, + func() { s.capabilities.resources = &resourceCapabilities{listChanged: true} }, + ) + + // Get existing templates (this returns a thread-safe copy) + sessionTemplates := session.GetSessionResourceTemplates() + + // Create a new map to avoid modifying the returned copy + newTemplates := make(map[string]ServerResourceTemplate, len(sessionTemplates)+len(templates)) + + // Copy existing templates + for k, v := range sessionTemplates { + newTemplates[k] = v + } + + // Validate and add new templates + for _, t := range templates { + if t.Template.URITemplate == nil { + return fmt.Errorf("resource template URITemplate cannot be nil") + } + raw := t.Template.URITemplate.Raw() + if raw == "" { + return fmt.Errorf("resource template URITemplate cannot be empty") + } + if t.Template.Name == "" { + return fmt.Errorf("resource template name cannot be empty") + } + newTemplates[raw] = t + } + + // Set the new templates (this method must handle thread-safety) + session.SetSessionResourceTemplates(newTemplates) + + // Send notification if the session is initialized and listChanged is enabled + if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil { + // Log the error but don't fail the operation + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/resources/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after adding resource templates: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} + +// DeleteSessionResourceTemplates removes resource templates from a specific session +func (s *MCPServer) DeleteSessionResourceTemplates(sessionID string, uriTemplates ...string) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithResourceTemplates) + if !ok { + return ErrSessionDoesNotSupportResourceTemplates + } + + // Get existing templates (this returns a thread-safe copy) + sessionTemplates := session.GetSessionResourceTemplates() + + // Track if any were actually deleted + deletedAny := false + + // Create a new map without the deleted templates + newTemplates := make(map[string]ServerResourceTemplate, len(sessionTemplates)) + for k, v := range sessionTemplates { + newTemplates[k] = v + } + + // Delete specified templates + for _, uriTemplate := range uriTemplates { + if _, exists := newTemplates[uriTemplate]; exists { + delete(newTemplates, uriTemplate) + deletedAny = true + } + } + + // Only update if something was actually deleted + if deletedAny { + // Set the new templates (this method must handle thread-safety) + session.SetSessionResourceTemplates(newTemplates) + + // Send notification if the session is initialized and listChanged is enabled + if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil { + // Log the error but don't fail the operation + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/resources/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after deleting resource templates: %w", err)) + }(sessionID, hooks) + } + } + } + } + + return nil +} diff --git a/server/session_resource_templates_test.go b/server/session_resource_templates_test.go new file mode 100644 index 000000000..b425d39f0 --- /dev/null +++ b/server/session_resource_templates_test.go @@ -0,0 +1,563 @@ +package server + +import ( + "context" + "encoding/json" + "maps" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mark3labs/mcp-go/mcp" +) + +type sessionTestClientWithResourceTemplates struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized atomic.Bool + sessionResourceTemplates map[string]ServerResourceTemplate + mu sync.RWMutex +} + +func (f *sessionTestClientWithResourceTemplates) SessionID() string { + return f.sessionID +} + +func (f *sessionTestClientWithResourceTemplates) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +func (f *sessionTestClientWithResourceTemplates) Initialize() { + f.initialized.Store(true) +} + +func (f *sessionTestClientWithResourceTemplates) Initialized() bool { + return f.initialized.Load() +} + +func (f *sessionTestClientWithResourceTemplates) GetSessionResourceTemplates() map[string]ServerResourceTemplate { + f.mu.RLock() + defer f.mu.RUnlock() + return maps.Clone(f.sessionResourceTemplates) +} + +func (f *sessionTestClientWithResourceTemplates) SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) { + f.mu.Lock() + defer f.mu.Unlock() + f.sessionResourceTemplates = maps.Clone(templates) +} + +var _ SessionWithResourceTemplates = (*sessionTestClientWithResourceTemplates)(nil) + +func TestSessionWithResourceTemplates_Integration(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + sessionTemplate := ServerResourceTemplate{ + Template: mcp.NewResourceTemplate("test://session/{id}", "session-template"), + Handler: func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{mcp.TextResourceContents{ + URI: request.Params.URI, + Text: "session-template result", + }}, nil + }, + } + + session := &sessionTestClientWithResourceTemplates{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + sessionResourceTemplates: map[string]ServerResourceTemplate{ + "test://session/{id}": sessionTemplate, + }, + } + session.initialized.Store(true) + + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + testReq := mcp.ReadResourceRequest{} + testReq.Params.URI = "test://session/123" + + sessionCtx := server.WithContext(context.Background(), session) + + s := ClientSessionFromContext(sessionCtx) + require.NotNil(t, s, "Session should be available from context") + assert.Equal(t, session.SessionID(), s.SessionID(), "Session ID should match") + + swrt, ok := s.(SessionWithResourceTemplates) + require.True(t, ok, "Session should implement SessionWithResourceTemplates") + + templates := swrt.GetSessionResourceTemplates() + require.NotNil(t, templates, "Session resource templates should be available") + require.Contains(t, templates, "test://session/{id}", "Session should have test://session/{id}") + + t.Run("test session resource template access", func(t *testing.T) { + template, exists := templates["test://session/{id}"] + require.True(t, exists, "Session resource template should exist in the map") + require.NotNil(t, template, "Session resource template should not be nil") + + result, err := template.Handler(sessionCtx, testReq) + require.NoError(t, err, "No error calling session resource template handler directly") + require.NotNil(t, result, "Result should not be nil") + require.Len(t, result, 1, "Result should have one content item") + + textContent, ok := result[0].(mcp.TextResourceContents) + require.True(t, ok, "Content should be TextResourceContents") + assert.Equal(t, "session-template result", textContent.Text, "Result text should match") + }) +} + +func TestMCPServer_ResourceTemplatesWithSessionResourceTemplates(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(false, true)) + + server.AddResourceTemplates( + ServerResourceTemplate{ + Template: mcp.NewResourceTemplate("test://global/{id}", "global-template-1"), + Handler: func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{mcp.TextResourceContents{ + URI: request.Params.URI, + Text: "global-template-1 result", + }}, nil + }, + }, + ServerResourceTemplate{ + Template: mcp.NewResourceTemplate("test://another/{id}", "global-template-2"), + Handler: func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{mcp.TextResourceContents{ + URI: request.Params.URI, + Text: "global-template-2 result", + }}, nil + }, + }, + ) + + session := &sessionTestClientWithResourceTemplates{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + sessionResourceTemplates: map[string]ServerResourceTemplate{ + "test://global/{id}": { + Template: mcp.NewResourceTemplate("test://global/{id}", "global-template-1-overridden"), + Handler: func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{mcp.TextResourceContents{ + URI: request.Params.URI, + Text: "session-overridden result", + }}, nil + }, + }, + }, + } + session.initialized.Store(true) + + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + sessionCtx := server.WithContext(context.Background(), session) + resp := server.HandleMessage(sessionCtx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/templates/list" + }`)) + + jsonResp, ok := resp.(mcp.JSONRPCResponse) + require.True(t, ok, "Response should be a JSONRPCResponse") + + result, ok := jsonResp.Result.(mcp.ListResourceTemplatesResult) + require.True(t, ok, "Result should be a ListResourceTemplatesResult") + + assert.Len(t, result.ResourceTemplates, 2, "Should have 2 resource templates") + + templateMap := make(map[string]mcp.ResourceTemplate) + for _, template := range result.ResourceTemplates { + templateMap[template.URITemplate.Raw()] = template + } + + require.Contains(t, templateMap, "test://another/{id}", "Should have non-overridden global template") + assert.Equal(t, "global-template-2", templateMap["test://another/{id}"].Name, "Global template name should match") + + require.Contains(t, templateMap, "test://global/{id}", "Should have overridden global template") + assert.Equal(t, "global-template-1-overridden", templateMap["test://global/{id}"].Name, "Overridden template name should match session version") + + t.Run("read overridden resource via HandleMessage", func(t *testing.T) { + readResp := server.HandleMessage(sessionCtx, []byte(`{ + "jsonrpc": "2.0", + "id": 2, + "method": "resources/read", + "params": { + "uri": "test://global/123" + } + }`)) + + readJSONResp, ok := readResp.(mcp.JSONRPCResponse) + require.True(t, ok, "Read response should be a JSONRPCResponse") + + readResult, ok := readJSONResp.Result.(mcp.ReadResourceResult) + require.True(t, ok, "Result should be a ReadResourceResult") + + require.Len(t, readResult.Contents, 1, "Should have one content item") + textContent, ok := readResult.Contents[0].(mcp.TextResourceContents) + require.True(t, ok, "Content should be TextResourceContents") + assert.Equal(t, "session-overridden result", textContent.Text, "Should return session handler's content") + }) +} + +func TestMCPServer_AddSessionResourceTemplates(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(false, true)) + ctx := context.Background() + + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithResourceTemplates{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + session.initialized.Store(true) + + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + err = server.AddSessionResourceTemplates(session.SessionID(), + ServerResourceTemplate{Template: mcp.NewResourceTemplate("test://session/{id}", "session-template")}, + ) + require.NoError(t, err) + + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/resources/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + assert.Len(t, session.GetSessionResourceTemplates(), 1) + assert.Contains(t, session.GetSessionResourceTemplates(), "test://session/{id}") +} + +func TestMCPServer_AddSessionResourceTemplate(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(false, true)) + ctx := context.Background() + + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithResourceTemplates{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + session.initialized.Store(true) + + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + err = server.AddSessionResourceTemplate( + session.SessionID(), + mcp.NewResourceTemplate("test://helper/{id}", "helper-template"), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{mcp.TextResourceContents{ + URI: request.Params.URI, + Text: "helper result", + }}, nil + }, + ) + require.NoError(t, err) + + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/resources/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + assert.Len(t, session.GetSessionResourceTemplates(), 1) + assert.Contains(t, session.GetSessionResourceTemplates(), "test://helper/{id}") +} + +func TestMCPServer_AddSessionResourceTemplatesUninitialized(t *testing.T) { + errorChan := make(chan error) + hooks := &Hooks{} + hooks.AddOnError( + func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + errorChan <- err + }, + ) + + server := NewMCPServer("test-server", "1.0.0", + WithResourceCapabilities(false, true), + WithHooks(hooks), + ) + ctx := context.Background() + + sessionChan := make(chan mcp.JSONRPCNotification, 1) + session := &sessionTestClientWithResourceTemplates{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + err = server.AddSessionResourceTemplates(session.SessionID(), + ServerResourceTemplate{Template: mcp.NewResourceTemplate("test://uninit/{id}", "uninitialized-template")}, + ) + require.NoError(t, err) + + select { + case err := <-errorChan: + t.Error("Expected no errors, but OnError called with: ", err) + case <-time.After(25 * time.Millisecond): + } + + select { + case <-sessionChan: + t.Error("Expected no notification to be sent for uninitialized session") + default: + } + + assert.Len(t, session.GetSessionResourceTemplates(), 1) + assert.Contains(t, session.GetSessionResourceTemplates(), "test://uninit/{id}") + + session.Initialize() + + err = server.AddSessionResourceTemplates(session.SessionID(), + ServerResourceTemplate{Template: mcp.NewResourceTemplate("test://initialized/{id}", "initialized-template")}, + ) + require.NoError(t, err) + + select { + case err := <-errorChan: + t.Error("Expected no errors, but OnError called with:", err) + case <-time.After(200 * time.Millisecond): + } + + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/resources/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Timeout waiting for expected notifications/resources/list_changed notification") + } + + assert.Len(t, session.GetSessionResourceTemplates(), 2) + assert.Contains(t, session.GetSessionResourceTemplates(), "test://uninit/{id}") + assert.Contains(t, session.GetSessionResourceTemplates(), "test://initialized/{id}") +} + +func TestMCPServer_DeleteSessionResourceTemplatesUninitialized(t *testing.T) { + errorChan := make(chan error) + hooks := &Hooks{} + hooks.AddOnError( + func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + errorChan <- err + }, + ) + + server := NewMCPServer("test-server", "1.0.0", + WithResourceCapabilities(false, true), + WithHooks(hooks), + ) + ctx := context.Background() + + sessionChan := make(chan mcp.JSONRPCNotification, 1) + session := &sessionTestClientWithResourceTemplates{ + sessionID: "uninitialized-session", + notificationChannel: sessionChan, + sessionResourceTemplates: map[string]ServerResourceTemplate{ + "test://delete/{id}": {Template: mcp.NewResourceTemplate("test://delete/{id}", "template-to-delete")}, + "test://keep/{id}": {Template: mcp.NewResourceTemplate("test://keep/{id}", "template-to-keep")}, + }, + } + + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + err = server.DeleteSessionResourceTemplates(session.SessionID(), "test://delete/{id}") + require.NoError(t, err) + + select { + case err := <-errorChan: + t.Errorf("Expected error hooks not to be called, got error: %v", err) + case <-time.After(25 * time.Millisecond): + } + + select { + case <-sessionChan: + t.Error("Expected no notification to be sent for uninitialized session") + default: + } + + assert.Len(t, session.GetSessionResourceTemplates(), 1) + assert.NotContains(t, session.GetSessionResourceTemplates(), "test://delete/{id}") + assert.Contains(t, session.GetSessionResourceTemplates(), "test://keep/{id}") + + session.Initialize() + + err = server.DeleteSessionResourceTemplates(session.SessionID(), "test://keep/{id}") + require.NoError(t, err) + + select { + case err := <-errorChan: + t.Errorf("Expected error hooks not to be called, got error: %v", err) + case <-time.After(200 * time.Millisecond): + } + + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/resources/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received for initialized session") + } + + assert.Len(t, session.GetSessionResourceTemplates(), 0) +} + +func TestMCPServer_CallSessionResourceTemplate(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(false, true)) + + server.AddResourceTemplate( + mcp.NewResourceTemplate("test://resource/{id}", "test_template"), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{mcp.TextResourceContents{ + URI: request.Params.URI, + Text: "global result", + }}, nil + }, + ) + + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithResourceTemplates{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + err = server.AddSessionResourceTemplate( + session.SessionID(), + mcp.NewResourceTemplate("test://resource/{id}", "test_template"), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{mcp.TextResourceContents{ + URI: request.Params.URI, + Text: "session result", + }}, nil + }, + ) + require.NoError(t, err) + + sessionCtx := server.WithContext(context.Background(), session) + resourceRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/read", + "params": map[string]any{ + "uri": "test://resource/123", + }, + } + requestBytes, err := json.Marshal(resourceRequest) + if err != nil { + t.Fatalf("Failed to marshal resource request: %v", err) + } + + response := server.HandleMessage(sessionCtx, requestBytes) + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + + readResourceResult, ok := resp.Result.(mcp.ReadResourceResult) + assert.True(t, ok) + + if text := readResourceResult.Contents[0].(mcp.TextResourceContents).Text; text != "session result" { + t.Errorf("Expected result 'session result', got %q", text) + } +} + +func TestMCPServer_DeleteSessionResourceTemplates(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(false, true)) + ctx := context.Background() + + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithResourceTemplates{ + sessionID: "session-1", + notificationChannel: sessionChan, + sessionResourceTemplates: map[string]ServerResourceTemplate{ + "test://template1/{id}": { + Template: mcp.NewResourceTemplate("test://template1/{id}", "session-template-1"), + }, + "test://template2/{id}": { + Template: mcp.NewResourceTemplate("test://template2/{id}", "session-template-2"), + }, + }, + } + session.initialized.Store(true) + + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + err = server.DeleteSessionResourceTemplates(session.SessionID(), "test://template1/{id}") + require.NoError(t, err) + + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/resources/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + assert.Len(t, session.GetSessionResourceTemplates(), 1) + assert.NotContains(t, session.GetSessionResourceTemplates(), "test://template1/{id}") + assert.Contains(t, session.GetSessionResourceTemplates(), "test://template2/{id}") +} + +func TestMCPServer_SessionResourceTemplateError(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + ctx := context.Background() + + session := &sessionTestClient{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + } + + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + err = server.AddSessionResourceTemplates(session.SessionID(), + ServerResourceTemplate{Template: mcp.NewResourceTemplate("test://template/{id}", "test-template")}, + ) + require.Error(t, err) + assert.Equal(t, ErrSessionDoesNotSupportResourceTemplates, err) +} + +func TestMCPServer_ResourceTemplatesNotificationsDisabled(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(false, false)) + ctx := context.Background() + + sessionChan := make(chan mcp.JSONRPCNotification, 1) + session := &sessionTestClientWithResourceTemplates{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + err = server.AddSessionResourceTemplates(session.SessionID(), + ServerResourceTemplate{Template: mcp.NewResourceTemplate("test://template/{id}", "test-template")}, + ) + require.NoError(t, err) + + select { + case <-sessionChan: + t.Error("Expected no notification to be sent when capabilities.resources.listChanged is false") + default: + } + + assert.Len(t, session.GetSessionResourceTemplates(), 1) + assert.Contains(t, session.GetSessionResourceTemplates(), "test://template/{id}") + + err = server.DeleteSessionResourceTemplates(session.SessionID(), "test://template/{id}") + require.NoError(t, err) + + select { + case <-sessionChan: + t.Error("Expected no notification to be sent when capabilities.resources.listChanged is false") + default: + } + + assert.Len(t, session.GetSessionResourceTemplates(), 0) +} diff --git a/server/sse.go b/server/sse.go index 250141ce4..97c765cc7 100644 --- a/server/sse.go +++ b/server/sse.go @@ -30,6 +30,7 @@ type sseSession struct { loggingLevel atomic.Value tools sync.Map // stores session-specific tools resources sync.Map // stores session-specific resources + resourceTemplates sync.Map // stores session-specific resource templates clientInfo atomic.Value // stores session-specific client info clientCapabilities atomic.Value // stores session-specific client capabilities } @@ -97,6 +98,27 @@ func (s *sseSession) SetSessionResources(resources map[string]ServerResource) { } } +func (s *sseSession) GetSessionResourceTemplates() map[string]ServerResourceTemplate { + templates := make(map[string]ServerResourceTemplate) + s.resourceTemplates.Range(func(key, value any) bool { + if template, ok := value.(ServerResourceTemplate); ok { + templates[key.(string)] = template + } + return true + }) + return templates +} + +func (s *sseSession) SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) { + // Clear existing templates + s.resourceTemplates.Clear() + + // Set new templates + for uriTemplate, template := range templates { + s.resourceTemplates.Store(uriTemplate, template) + } +} + func (s *sseSession) GetSessionTools() map[string]ServerTool { tools := make(map[string]ServerTool) s.tools.Range(func(key, value any) bool { @@ -145,11 +167,12 @@ func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities { } var ( - _ ClientSession = (*sseSession)(nil) - _ SessionWithTools = (*sseSession)(nil) - _ SessionWithResources = (*sseSession)(nil) - _ SessionWithLogging = (*sseSession)(nil) - _ SessionWithClientInfo = (*sseSession)(nil) + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) + _ SessionWithResources = (*sseSession)(nil) + _ SessionWithResourceTemplates = (*sseSession)(nil) + _ SessionWithLogging = (*sseSession)(nil) + _ SessionWithClientInfo = (*sseSession)(nil) ) // SSEServer implements a Server-Sent Events (SSE) based MCP server. diff --git a/server/streamable_http.go b/server/streamable_http.go index 8af6f1478..385f69e64 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -140,11 +140,12 @@ func WithTLSCert(certFile, keyFile string) StreamableHTTPOption { // The current implementation does not support the following features from the specification: // - Stream Resumability type StreamableHTTPServer struct { - server *MCPServer - sessionTools *sessionToolsStore - sessionResources *sessionResourcesStore - sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64) - activeSessions sync.Map // sessionId --> *streamableHttpSession (for sampling responses) + server *MCPServer + sessionTools *sessionToolsStore + sessionResources *sessionResourcesStore + sessionResourceTemplates *sessionResourceTemplatesStore + sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64) + activeSessions sync.Map // sessionId --> *streamableHttpSession (for sampling responses) httpServer *http.Server mu sync.RWMutex @@ -164,13 +165,14 @@ type StreamableHTTPServer struct { // NewStreamableHTTPServer creates a new streamable-http server instance func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer { s := &StreamableHTTPServer{ - server: server, - sessionTools: newSessionToolsStore(), - sessionLogLevels: newSessionLogLevelsStore(), - endpointPath: "/mcp", - sessionIdManager: &InsecureStatefulSessionIdManager{}, - logger: util.DefaultLogger(), - sessionResources: newSessionResourcesStore(), + server: server, + sessionTools: newSessionToolsStore(), + sessionLogLevels: newSessionLogLevelsStore(), + endpointPath: "/mcp", + sessionIdManager: &InsecureStatefulSessionIdManager{}, + logger: util.DefaultLogger(), + sessionResources: newSessionResourcesStore(), + sessionResourceTemplates: newSessionResourceTemplatesStore(), } // Apply all options @@ -345,7 +347,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request // Create ephemeral session if no persistent session exists if session == nil { - session = newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionLogLevels) + session = newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionResourceTemplates, s.sessionLogLevels) } // Set the client context before handling the message @@ -480,7 +482,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) // Get or create session atomically to prevent TOCTOU races // where concurrent GETs could both create and register duplicate sessions var session *streamableHttpSession - newSession := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionLogLevels) + newSession := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionResourceTemplates, s.sessionLogLevels) actual, loaded := s.activeSessions.LoadOrStore(sessionID, newSession) session = actual.(*streamableHttpSession) @@ -622,6 +624,7 @@ func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Reque // remove the session relateddata from the sessionToolsStore s.sessionTools.delete(sessionID) s.sessionResources.delete(sessionID) + s.sessionResourceTemplates.delete(sessionID) s.sessionLogLevels.delete(sessionID) // remove current session's requstID information s.sessionRequestIDs.Delete(sessionID) @@ -834,6 +837,39 @@ func (s *sessionResourcesStore) delete(sessionID string) { delete(s.resources, sessionID) } +type sessionResourceTemplatesStore struct { + mu sync.RWMutex + templates map[string]map[string]ServerResourceTemplate // sessionID -> uriTemplate -> template +} + +func newSessionResourceTemplatesStore() *sessionResourceTemplatesStore { + return &sessionResourceTemplatesStore{ + templates: make(map[string]map[string]ServerResourceTemplate), + } +} + +func (s *sessionResourceTemplatesStore) get(sessionID string) map[string]ServerResourceTemplate { + s.mu.RLock() + defer s.mu.RUnlock() + cloned := make(map[string]ServerResourceTemplate, len(s.templates[sessionID])) + maps.Copy(cloned, s.templates[sessionID]) + return cloned +} + +func (s *sessionResourceTemplatesStore) set(sessionID string, templates map[string]ServerResourceTemplate) { + s.mu.Lock() + defer s.mu.Unlock() + cloned := make(map[string]ServerResourceTemplate, len(templates)) + maps.Copy(cloned, templates) + s.templates[sessionID] = cloned +} + +func (s *sessionResourceTemplatesStore) delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.templates, sessionID) +} + type sessionToolsStore struct { mu sync.RWMutex tools map[string]map[string]ServerTool // sessionID -> toolName -> tool @@ -895,6 +931,7 @@ type streamableHttpSession struct { notificationChannel chan mcp.JSONRPCNotification // server -> client notifications tools *sessionToolsStore resources *sessionResourcesStore + resourceTemplates *sessionResourceTemplatesStore upgradeToSSE atomic.Bool logLevels *sessionLogLevelsStore @@ -906,12 +943,13 @@ type streamableHttpSession struct { requestIDCounter atomic.Int64 // for generating unique request IDs } -func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, resourcesStore *sessionResourcesStore, levels *sessionLogLevelsStore) *streamableHttpSession { +func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, resourcesStore *sessionResourcesStore, templatesStore *sessionResourceTemplatesStore, levels *sessionLogLevelsStore) *streamableHttpSession { s := &streamableHttpSession{ sessionID: sessionID, notificationChannel: make(chan mcp.JSONRPCNotification, 100), tools: toolStore, resources: resourcesStore, + resourceTemplates: templatesStore, logLevels: levels, samplingRequestChan: make(chan samplingRequestItem, 10), elicitationRequestChan: make(chan elicitationRequestItem, 10), @@ -963,10 +1001,19 @@ func (s *streamableHttpSession) SetSessionResources(resources map[string]ServerR s.resources.set(s.sessionID, resources) } +func (s *streamableHttpSession) GetSessionResourceTemplates() map[string]ServerResourceTemplate { + return s.resourceTemplates.get(s.sessionID) +} + +func (s *streamableHttpSession) SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) { + s.resourceTemplates.set(s.sessionID, templates) +} + var ( - _ SessionWithTools = (*streamableHttpSession)(nil) - _ SessionWithResources = (*streamableHttpSession)(nil) - _ SessionWithLogging = (*streamableHttpSession)(nil) + _ SessionWithTools = (*streamableHttpSession)(nil) + _ SessionWithResources = (*streamableHttpSession)(nil) + _ SessionWithResourceTemplates = (*streamableHttpSession)(nil) + _ SessionWithLogging = (*streamableHttpSession)(nil) ) func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index 5c9ed1e24..6f7af6833 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -26,7 +26,7 @@ func TestStreamableHTTPServer_SamplingBasic(t *testing.T) { // Test session creation and interface implementation sessionID := "test-session" - session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionResources, httpServer.sessionLogLevels) + session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionResources, httpServer.sessionResourceTemplates, httpServer.sessionLogLevels) // Verify it implements SessionWithSampling _, ok := any(session).(SessionWithSampling) @@ -139,7 +139,7 @@ func TestStreamableHTTPServer_SamplingInterface(t *testing.T) { // Create a session sessionID := "test-session" - session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionResources, httpServer.sessionLogLevels) + session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionResources, httpServer.sessionResourceTemplates, httpServer.sessionLogLevels) // Verify it implements SessionWithSampling _, ok := any(session).(SessionWithSampling) @@ -178,7 +178,7 @@ func TestStreamableHTTPServer_SamplingInterface(t *testing.T) { // TestStreamableHTTPServer_SamplingQueueFull tests queue overflow scenarios func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) { sessionID := "test-session" - session := newStreamableHttpSession(sessionID, nil, nil, nil) + session := newStreamableHttpSession(sessionID, nil, nil, nil, nil) // Fill the sampling request queue for i := 0; i < cap(session.samplingRequestChan); i++ {