Skip to content

Commit 178bd07

Browse files
add SessionWithPromps on sse server
1 parent 0e24d31 commit 178bd07

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

server/sse.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type sseSession struct {
2929
initialized atomic.Bool
3030
loggingLevel atomic.Value
3131
tools sync.Map // stores session-specific tools
32+
prompts sync.Map // stores session-specific prompts
3233
clientInfo atomic.Value // stores session-specific client info
3334
}
3435

@@ -74,6 +75,17 @@ func (s *sseSession) GetLogLevel() mcp.LoggingLevel {
7475
return level.(mcp.LoggingLevel)
7576
}
7677

78+
func (s *sseSession) GetSessionPrompts() map[string]ServerPrompt {
79+
prompts := make(map[string]ServerPrompt)
80+
s.prompts.Range(func(key, value any) bool {
81+
if prompt, ok := value.(ServerPrompt); ok {
82+
prompts[key.(string)] = prompt
83+
}
84+
return true
85+
})
86+
return prompts
87+
}
88+
7789
func (s *sseSession) GetSessionTools() map[string]ServerTool {
7890
tools := make(map[string]ServerTool)
7991
s.tools.Range(func(key, value any) bool {
@@ -85,6 +97,16 @@ func (s *sseSession) GetSessionTools() map[string]ServerTool {
8597
return tools
8698
}
8799

100+
func (s *sseSession) SetSessionPrompts(prompts map[string]ServerPrompt) {
101+
// Clear existing prompts
102+
s.prompts.Clear()
103+
104+
// Set new prompts
105+
for name, prompt := range prompts {
106+
s.prompts.Store(name, prompt)
107+
}
108+
}
109+
88110
func (s *sseSession) SetSessionTools(tools map[string]ServerTool) {
89111
// Clear existing tools
90112
s.tools.Clear()

server/sse_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,112 @@ func TestSSEServer(t *testing.T) {
11411141
}
11421142
})
11431143

1144+
t.Run("TestSessionWithPrompts", func(t *testing.T) {
1145+
// Create hooks to track sessions
1146+
hooks := &Hooks{}
1147+
var registeredSession *sseSession
1148+
hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) {
1149+
if s, ok := session.(*sseSession); ok {
1150+
registeredSession = s
1151+
}
1152+
})
1153+
1154+
mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks))
1155+
testServer := NewTestServer(mcpServer)
1156+
defer testServer.Close()
1157+
1158+
// Connect to SSE endpoint
1159+
sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL))
1160+
if err != nil {
1161+
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
1162+
}
1163+
defer sseResp.Body.Close()
1164+
1165+
// Read the endpoint event to ensure session is established
1166+
_, err = readSSEEvent(sseResp)
1167+
if err != nil {
1168+
t.Fatalf("Failed to read SSE response: %v", err)
1169+
}
1170+
1171+
// Verify we got a session
1172+
if registeredSession == nil {
1173+
t.Fatal("Session was not registered via hook")
1174+
}
1175+
1176+
// Test setting and getting prompts
1177+
prompts := map[string]ServerPrompt{
1178+
"test_prompt": {
1179+
Prompt: mcp.Prompt{
1180+
Name: "test_prompt",
1181+
Description: "A test prompt",
1182+
},
1183+
Handler: func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
1184+
return mcp.NewGetPromptResult("test", []mcp.PromptMessage{
1185+
{
1186+
Role: mcp.RoleUser,
1187+
Content: mcp.TextContent{Text: "test"},
1188+
},
1189+
}), nil
1190+
},
1191+
},
1192+
}
1193+
1194+
// Test SetSessionPrompts
1195+
registeredSession.SetSessionPrompts(prompts)
1196+
1197+
// Test GetSessionPrompts
1198+
retrievedPrompts := registeredSession.GetSessionPrompts()
1199+
if len(retrievedPrompts) != 1 {
1200+
t.Errorf("Expected 1 prompt, got %d", len(retrievedPrompts))
1201+
}
1202+
if prompt, exists := retrievedPrompts["test_prompt"]; !exists {
1203+
t.Error("Expected test_prompt to exist")
1204+
} else if prompt.Prompt.Name != "test_prompt" {
1205+
t.Errorf("Expected prompt name test_prompt, got %s", prompt.Prompt.Name)
1206+
}
1207+
1208+
// Test concurrent access
1209+
var wg sync.WaitGroup
1210+
for i := 0; i < 10; i++ {
1211+
wg.Add(2)
1212+
go func(i int) {
1213+
defer wg.Done()
1214+
prompts := map[string]ServerPrompt{
1215+
fmt.Sprintf("prompt_%d", i): {
1216+
Prompt: mcp.Prompt{
1217+
Name: fmt.Sprintf("prompt_%d", i),
1218+
Description: fmt.Sprintf("Prompt %d", i),
1219+
},
1220+
},
1221+
}
1222+
registeredSession.SetSessionPrompts(prompts)
1223+
}(i)
1224+
go func() {
1225+
defer wg.Done()
1226+
_ = registeredSession.GetSessionTools()
1227+
}()
1228+
}
1229+
wg.Wait()
1230+
1231+
// Verify we can still get and set tools after concurrent access
1232+
finalPrompts := map[string]ServerPrompt{
1233+
"final_prompt": {
1234+
Prompt: mcp.Prompt{
1235+
Name: "final_prompt",
1236+
Description: "Final Prompt",
1237+
},
1238+
},
1239+
}
1240+
registeredSession.SetSessionPrompts(finalPrompts)
1241+
retrievedPrompts = registeredSession.GetSessionPrompts()
1242+
if len(retrievedPrompts) != 1 {
1243+
t.Errorf("Expected 1 prompt, got %d", len(retrievedPrompts))
1244+
}
1245+
if _, exists := retrievedPrompts["final_prompt"]; !exists {
1246+
t.Error("Expected final_prompt to exist")
1247+
}
1248+
})
1249+
11441250
t.Run("SessionWithTools implementation", func(t *testing.T) {
11451251
// Create hooks to track sessions
11461252
hooks := &Hooks{}

0 commit comments

Comments
 (0)