Skip to content

Commit 9de3423

Browse files
committed
feat: add SessionWithResourceTemplates for session-specific resource templates
Implements session-specific resource templates to achieve parity with SessionWithTools and SessionWithResources. This allows sessions to have their own resource templates that override global templates with the same URI pattern. Key changes: - Add SessionWithResourceTemplates interface to ClientSession hierarchy - Implement interface in both SSE and StreamableHTTP transports - Add AddSessionResourceTemplate(s) and DeleteSessionResourceTemplates methods - Update handleListResourceTemplates to merge session and global templates - Update handleReadResource to check session templates before global ones - Session templates trigger notifications/resources/list_changed when modified Closes #622
1 parent 8b7d60c commit 9de3423

File tree

6 files changed

+299
-46
lines changed

6 files changed

+299
-46
lines changed

server/errors.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ var (
1313
ErrToolNotFound = errors.New("tool not found")
1414

1515
// Session-related errors
16-
ErrSessionNotFound = errors.New("session not found")
17-
ErrSessionExists = errors.New("session already exists")
18-
ErrSessionNotInitialized = errors.New("session not properly initialized")
19-
ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools")
20-
ErrSessionDoesNotSupportResources = errors.New("session does not support per-session resources")
21-
ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level")
16+
ErrSessionNotFound = errors.New("session not found")
17+
ErrSessionExists = errors.New("session already exists")
18+
ErrSessionNotInitialized = errors.New("session not properly initialized")
19+
ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools")
20+
ErrSessionDoesNotSupportResources = errors.New("session does not support per-session resources")
21+
ErrSessionDoesNotSupportResourceTemplates = errors.New("session does not support resource templates")
22+
ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level")
2223

2324
// Notification-related errors
2425
ErrNotificationNotInitialized = errors.New("notification channel not initialized")

server/server.go

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -880,12 +880,34 @@ func (s *MCPServer) handleListResourceTemplates(
880880
id any,
881881
request mcp.ListResourceTemplatesRequest,
882882
) (*mcp.ListResourceTemplatesResult, *requestError) {
883+
// Get global templates
883884
s.resourcesMu.RLock()
884-
templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates))
885-
for _, entry := range s.resourceTemplates {
886-
templates = append(templates, entry.template)
885+
templateMap := make(map[string]mcp.ResourceTemplate, len(s.resourceTemplates))
886+
for uri, entry := range s.resourceTemplates {
887+
templateMap[uri] = entry.template
887888
}
888889
s.resourcesMu.RUnlock()
890+
891+
// Check if there are session-specific resource templates
892+
session := ClientSessionFromContext(ctx)
893+
if session != nil {
894+
if sessionWithTemplates, ok := session.(SessionWithResourceTemplates); ok {
895+
if sessionTemplates := sessionWithTemplates.GetSessionResourceTemplates(); sessionTemplates != nil {
896+
// Merge session-specific templates with global templates
897+
// Session templates override global ones
898+
for uriTemplate, serverTemplate := range sessionTemplates {
899+
templateMap[uriTemplate] = serverTemplate.Template
900+
}
901+
}
902+
}
903+
}
904+
905+
// Convert map to slice for sorting and pagination
906+
templates := make([]mcp.ResourceTemplate, 0, len(templateMap))
907+
for _, template := range templateMap {
908+
templates = append(templates, template)
909+
}
910+
889911
sort.Slice(templates, func(i, j int) bool {
890912
return templates[i].Name < templates[j].Name
891913
})
@@ -971,18 +993,42 @@ func (s *MCPServer) handleReadResource(
971993
// If no direct handler found, try matching against templates
972994
var matchedHandler ResourceTemplateHandlerFunc
973995
var matched bool
974-
for _, entry := range s.resourceTemplates {
975-
template := entry.template
976-
if matchesTemplate(request.Params.URI, template.URITemplate) {
977-
matchedHandler = entry.handler
978-
matched = true
979-
matchedVars := template.URITemplate.Match(request.Params.URI)
980-
// Convert matched variables to a map
981-
request.Params.Arguments = make(map[string]any, len(matchedVars))
982-
for name, value := range matchedVars {
983-
request.Params.Arguments[name] = value.V
996+
997+
// First check session templates if available
998+
if session != nil {
999+
if sessionWithTemplates, ok := session.(SessionWithResourceTemplates); ok {
1000+
sessionTemplates := sessionWithTemplates.GetSessionResourceTemplates()
1001+
for _, serverTemplate := range sessionTemplates {
1002+
if matchesTemplate(request.Params.URI, serverTemplate.Template.URITemplate) {
1003+
matchedHandler = serverTemplate.Handler
1004+
matched = true
1005+
matchedVars := serverTemplate.Template.URITemplate.Match(request.Params.URI)
1006+
// Convert matched variables to a map
1007+
request.Params.Arguments = make(map[string]any, len(matchedVars))
1008+
for name, value := range matchedVars {
1009+
request.Params.Arguments[name] = value.V
1010+
}
1011+
break
1012+
}
1013+
}
1014+
}
1015+
}
1016+
1017+
// If not found in session templates, check global templates
1018+
if !matched {
1019+
for _, entry := range s.resourceTemplates {
1020+
template := entry.template
1021+
if matchesTemplate(request.Params.URI, template.URITemplate) {
1022+
matchedHandler = entry.handler
1023+
matched = true
1024+
matchedVars := template.URITemplate.Match(request.Params.URI)
1025+
// Convert matched variables to a map
1026+
request.Params.Arguments = make(map[string]any, len(matchedVars))
1027+
for name, value := range matchedVars {
1028+
request.Params.Arguments[name] = value.V
1029+
}
1030+
break
9841031
}
985-
break
9861032
}
9871033
}
9881034
s.resourcesMu.RUnlock()

server/session.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@ type SessionWithResources interface {
5151
SetSessionResources(resources map[string]ServerResource)
5252
}
5353

54+
// SessionWithResourceTemplates is an extension of ClientSession that can store session-specific resource template data
55+
type SessionWithResourceTemplates interface {
56+
ClientSession
57+
// GetSessionResourceTemplates returns the resource templates specific to this session, if any
58+
// This method must be thread-safe for concurrent access
59+
GetSessionResourceTemplates() map[string]ServerResourceTemplate
60+
// SetSessionResourceTemplates sets resource templates specific to this session
61+
// This method must be thread-safe for concurrent access
62+
SetSessionResourceTemplates(templates map[string]ServerResourceTemplate)
63+
}
64+
5465
// SessionWithClientInfo is an extension of ClientSession that can store client info
5566
type SessionWithClientInfo interface {
5667
ClientSession
@@ -613,3 +624,128 @@ func (s *MCPServer) DeleteSessionResources(sessionID string, uris ...string) err
613624

614625
return nil
615626
}
627+
628+
// AddSessionResourceTemplate adds a resource template for a specific session
629+
func (s *MCPServer) AddSessionResourceTemplate(sessionID string, template mcp.ResourceTemplate, handler ResourceTemplateHandlerFunc) error {
630+
return s.AddSessionResourceTemplates(sessionID, ServerResourceTemplate{
631+
Template: template,
632+
Handler: handler,
633+
})
634+
}
635+
636+
// AddSessionResourceTemplates adds resource templates for a specific session
637+
func (s *MCPServer) AddSessionResourceTemplates(sessionID string, templates ...ServerResourceTemplate) error {
638+
sessionValue, ok := s.sessions.Load(sessionID)
639+
if !ok {
640+
return ErrSessionNotFound
641+
}
642+
643+
session, ok := sessionValue.(SessionWithResourceTemplates)
644+
if !ok {
645+
return ErrSessionDoesNotSupportResourceTemplates
646+
}
647+
648+
// For session resource templates, enable listChanged by default
649+
// This is the same behavior as session resources
650+
s.implicitlyRegisterCapabilities(
651+
func() bool { return s.capabilities.resources != nil },
652+
func() { s.capabilities.resources = &resourceCapabilities{listChanged: true} },
653+
)
654+
655+
// Get existing templates (this returns a thread-safe copy)
656+
sessionTemplates := session.GetSessionResourceTemplates()
657+
658+
// Create a new map to avoid modifying the returned copy
659+
newTemplates := make(map[string]ServerResourceTemplate, len(sessionTemplates)+len(templates))
660+
661+
// Copy existing templates
662+
for k, v := range sessionTemplates {
663+
newTemplates[k] = v
664+
}
665+
666+
// Add new templates
667+
for _, template := range templates {
668+
key := template.Template.URITemplate.Raw()
669+
newTemplates[key] = template
670+
}
671+
672+
// Set the new templates (this method must handle thread-safety)
673+
session.SetSessionResourceTemplates(newTemplates)
674+
675+
// Send notification if the session is initialized and listChanged is enabled
676+
if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged {
677+
if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil {
678+
// Log the error but don't fail the operation
679+
if s.hooks != nil && len(s.hooks.OnError) > 0 {
680+
hooks := s.hooks
681+
go func(sID string, hooks *Hooks) {
682+
ctx := context.Background()
683+
hooks.onError(ctx, nil, "notification", map[string]any{
684+
"method": "notifications/resources/list_changed",
685+
"sessionID": sID,
686+
}, fmt.Errorf("failed to send notification after adding resource templates: %w", err))
687+
}(sessionID, hooks)
688+
}
689+
}
690+
}
691+
692+
return nil
693+
}
694+
695+
// DeleteSessionResourceTemplates removes resource templates from a specific session
696+
func (s *MCPServer) DeleteSessionResourceTemplates(sessionID string, uriTemplates ...string) error {
697+
sessionValue, ok := s.sessions.Load(sessionID)
698+
if !ok {
699+
return ErrSessionNotFound
700+
}
701+
702+
session, ok := sessionValue.(SessionWithResourceTemplates)
703+
if !ok {
704+
return ErrSessionDoesNotSupportResourceTemplates
705+
}
706+
707+
// Get existing templates (this returns a thread-safe copy)
708+
sessionTemplates := session.GetSessionResourceTemplates()
709+
710+
// Track if any were actually deleted
711+
deletedAny := false
712+
713+
// Create a new map without the deleted templates
714+
newTemplates := make(map[string]ServerResourceTemplate, len(sessionTemplates))
715+
for k, v := range sessionTemplates {
716+
newTemplates[k] = v
717+
}
718+
719+
// Delete specified templates
720+
for _, uriTemplate := range uriTemplates {
721+
if _, exists := newTemplates[uriTemplate]; exists {
722+
delete(newTemplates, uriTemplate)
723+
deletedAny = true
724+
}
725+
}
726+
727+
// Only update if something was actually deleted
728+
if deletedAny {
729+
// Set the new templates (this method must handle thread-safety)
730+
session.SetSessionResourceTemplates(newTemplates)
731+
732+
// Send notification if the session is initialized and listChanged is enabled
733+
if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged {
734+
if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil {
735+
// Log the error but don't fail the operation
736+
if s.hooks != nil && len(s.hooks.OnError) > 0 {
737+
hooks := s.hooks
738+
go func(sID string, hooks *Hooks) {
739+
ctx := context.Background()
740+
hooks.onError(ctx, nil, "notification", map[string]any{
741+
"method": "notifications/resources/list_changed",
742+
"sessionID": sID,
743+
}, fmt.Errorf("failed to send notification after deleting resource templates: %w", err))
744+
}(sessionID, hooks)
745+
}
746+
}
747+
}
748+
}
749+
750+
return nil
751+
}

server/sse.go

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type sseSession struct {
3030
loggingLevel atomic.Value
3131
tools sync.Map // stores session-specific tools
3232
resources sync.Map // stores session-specific resources
33+
resourceTemplates sync.Map // stores session-specific resource templates
3334
clientInfo atomic.Value // stores session-specific client info
3435
clientCapabilities atomic.Value // stores session-specific client capabilities
3536
}
@@ -97,6 +98,27 @@ func (s *sseSession) SetSessionResources(resources map[string]ServerResource) {
9798
}
9899
}
99100

101+
func (s *sseSession) GetSessionResourceTemplates() map[string]ServerResourceTemplate {
102+
templates := make(map[string]ServerResourceTemplate)
103+
s.resourceTemplates.Range(func(key, value any) bool {
104+
if template, ok := value.(ServerResourceTemplate); ok {
105+
templates[key.(string)] = template
106+
}
107+
return true
108+
})
109+
return templates
110+
}
111+
112+
func (s *sseSession) SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) {
113+
// Clear existing templates
114+
s.resourceTemplates.Clear()
115+
116+
// Set new templates
117+
for uriTemplate, template := range templates {
118+
s.resourceTemplates.Store(uriTemplate, template)
119+
}
120+
}
121+
100122
func (s *sseSession) GetSessionTools() map[string]ServerTool {
101123
tools := make(map[string]ServerTool)
102124
s.tools.Range(func(key, value any) bool {
@@ -145,11 +167,12 @@ func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities {
145167
}
146168

147169
var (
148-
_ ClientSession = (*sseSession)(nil)
149-
_ SessionWithTools = (*sseSession)(nil)
150-
_ SessionWithResources = (*sseSession)(nil)
151-
_ SessionWithLogging = (*sseSession)(nil)
152-
_ SessionWithClientInfo = (*sseSession)(nil)
170+
_ ClientSession = (*sseSession)(nil)
171+
_ SessionWithTools = (*sseSession)(nil)
172+
_ SessionWithResources = (*sseSession)(nil)
173+
_ SessionWithResourceTemplates = (*sseSession)(nil)
174+
_ SessionWithLogging = (*sseSession)(nil)
175+
_ SessionWithClientInfo = (*sseSession)(nil)
153176
)
154177

155178
// SSEServer implements a Server-Sent Events (SSE) based MCP server.

0 commit comments

Comments
 (0)