Skip to content

Commit 9a1f382

Browse files
add test GetSessionPrompt
1 parent 5f64c16 commit 9a1f382

File tree

2 files changed

+94
-3
lines changed

2 files changed

+94
-3
lines changed

server/server.go

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -891,9 +891,28 @@ func (s *MCPServer) handleGetPrompt(
891891
id any,
892892
request mcp.GetPromptRequest,
893893
) (*mcp.GetPromptResult, *requestError) {
894-
s.promptsMu.RLock()
895-
handler, ok := s.promptHandlers[request.Params.Name]
896-
s.promptsMu.RUnlock()
894+
// First check session-specific prompts
895+
var handler PromptHandlerFunc
896+
var ok bool
897+
898+
session := ClientSessionFromContext(ctx)
899+
if session != nil {
900+
if sessionWithPrompts, typeAssertOk := session.(SessionWithPrompts); typeAssertOk {
901+
if sessionPrompts := sessionWithPrompts.GetSessionPrompts(); sessionPrompts != nil {
902+
if serverPrompt, sessionOk := sessionPrompts[request.Params.Name]; sessionOk {
903+
handler = serverPrompt.Handler
904+
ok = true
905+
}
906+
}
907+
}
908+
}
909+
910+
// If not found in session prompts, check global prompts
911+
if !ok {
912+
s.promptsMu.RLock()
913+
handler, ok = s.promptHandlers[request.Params.Name]
914+
s.promptsMu.RUnlock()
915+
}
897916

898917
if !ok {
899918
return nil, &requestError{

server/session_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,78 @@ func TestMCPServer_CallSessionTool(t *testing.T) {
751751
}
752752
}
753753

754+
func TestMCPServer_GetSessionPrompt(t *testing.T) {
755+
server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true))
756+
757+
// Add global prompt
758+
server.AddPrompt(mcp.NewPrompt("test_prompt"), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
759+
return mcp.NewGetPromptResult("global result", []mcp.PromptMessage{
760+
{
761+
Role: mcp.RoleUser,
762+
Content: mcp.TextContent{Text: "global result"},
763+
},
764+
}), nil
765+
})
766+
767+
// Create a session
768+
sessionChan := make(chan mcp.JSONRPCNotification, 10)
769+
session := &sessionTestClientWithPrompts{
770+
sessionID: "session-1",
771+
notificationChannel: sessionChan,
772+
initialized: true,
773+
}
774+
775+
// Register the session
776+
err := server.RegisterSession(context.Background(), session)
777+
require.NoError(t, err)
778+
779+
// Add session-specific prompt with the same name to override the global prompt
780+
err = server.AddSessionPrompt(
781+
session.SessionID(),
782+
mcp.NewPrompt("test_prompt"),
783+
func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
784+
return mcp.NewGetPromptResult("session result", []mcp.PromptMessage{
785+
{
786+
Role: mcp.RoleUser,
787+
Content: mcp.TextContent{Text: "session result"},
788+
},
789+
}), nil
790+
},
791+
)
792+
require.NoError(t, err)
793+
794+
// Get the prompt using session context
795+
sessionCtx := server.WithContext(context.Background(), session)
796+
toolRequest := map[string]any{
797+
"jsonrpc": "2.0",
798+
"id": 1,
799+
"method": "prompts/get",
800+
"params": map[string]any{
801+
"name": "test_prompt",
802+
},
803+
}
804+
requestBytes, err := json.Marshal(toolRequest)
805+
if err != nil {
806+
t.Fatalf("Failed to marshal prompt request: %v", err)
807+
}
808+
809+
response := server.HandleMessage(sessionCtx, requestBytes)
810+
resp, ok := response.(mcp.JSONRPCResponse)
811+
assert.True(t, ok)
812+
813+
getPromptResult, ok := resp.Result.(mcp.GetPromptResult)
814+
assert.True(t, ok)
815+
816+
// Since we specify a prompt with the same name for current session, the expected text should be "session result"
817+
if textContent, ok := getPromptResult.Messages[0].Content.(mcp.TextContent); ok {
818+
if textContent.Text != "session result" {
819+
t.Errorf("Expected result 'session result', got %q", textContent.Text)
820+
}
821+
} else {
822+
t.Error("Expected TextContent")
823+
}
824+
}
825+
754826
func TestMCPServer_DeleteSessionTools(t *testing.T) {
755827
server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true))
756828
ctx := context.Background()

0 commit comments

Comments
 (0)