@@ -1141,6 +1141,112 @@ func TestSSEServer(t *testing.T) {
11411141 }
11421142 })
11431143
1144+ t .Run ("TestSessionWithPrompts" , func (t * testing.T ) {
1145+ // Create hooks to track sessions
1146+ hooks := & Hooks {}
1147+ var registeredSession * sseSession
1148+ hooks .AddOnRegisterSession (func (ctx context.Context , session ClientSession ) {
1149+ if s , ok := session .(* sseSession ); ok {
1150+ registeredSession = s
1151+ }
1152+ })
1153+
1154+ mcpServer := NewMCPServer ("test" , "1.0.0" , WithHooks (hooks ))
1155+ testServer := NewTestServer (mcpServer )
1156+ defer testServer .Close ()
1157+
1158+ // Connect to SSE endpoint
1159+ sseResp , err := http .Get (fmt .Sprintf ("%s/sse" , testServer .URL ))
1160+ if err != nil {
1161+ t .Fatalf ("Failed to connect to SSE endpoint: %v" , err )
1162+ }
1163+ defer sseResp .Body .Close ()
1164+
1165+ // Read the endpoint event to ensure session is established
1166+ _ , err = readSSEEvent (sseResp )
1167+ if err != nil {
1168+ t .Fatalf ("Failed to read SSE response: %v" , err )
1169+ }
1170+
1171+ // Verify we got a session
1172+ if registeredSession == nil {
1173+ t .Fatal ("Session was not registered via hook" )
1174+ }
1175+
1176+ // Test setting and getting prompts
1177+ prompts := map [string ]ServerPrompt {
1178+ "test_prompt" : {
1179+ Prompt : mcp.Prompt {
1180+ Name : "test_prompt" ,
1181+ Description : "A test prompt" ,
1182+ },
1183+ Handler : func (ctx context.Context , request mcp.GetPromptRequest ) (* mcp.GetPromptResult , error ) {
1184+ return mcp .NewGetPromptResult ("test" , []mcp.PromptMessage {
1185+ {
1186+ Role : mcp .RoleUser ,
1187+ Content : mcp.TextContent {Text : "test" },
1188+ },
1189+ }), nil
1190+ },
1191+ },
1192+ }
1193+
1194+ // Test SetSessionPrompts
1195+ registeredSession .SetSessionPrompts (prompts )
1196+
1197+ // Test GetSessionPrompts
1198+ retrievedPrompts := registeredSession .GetSessionPrompts ()
1199+ if len (retrievedPrompts ) != 1 {
1200+ t .Errorf ("Expected 1 prompt, got %d" , len (retrievedPrompts ))
1201+ }
1202+ if prompt , exists := retrievedPrompts ["test_prompt" ]; ! exists {
1203+ t .Error ("Expected test_prompt to exist" )
1204+ } else if prompt .Prompt .Name != "test_prompt" {
1205+ t .Errorf ("Expected prompt name test_prompt, got %s" , prompt .Prompt .Name )
1206+ }
1207+
1208+ // Test concurrent access
1209+ var wg sync.WaitGroup
1210+ for i := 0 ; i < 10 ; i ++ {
1211+ wg .Add (2 )
1212+ go func (i int ) {
1213+ defer wg .Done ()
1214+ prompts := map [string ]ServerPrompt {
1215+ fmt .Sprintf ("prompt_%d" , i ): {
1216+ Prompt : mcp.Prompt {
1217+ Name : fmt .Sprintf ("prompt_%d" , i ),
1218+ Description : fmt .Sprintf ("Prompt %d" , i ),
1219+ },
1220+ },
1221+ }
1222+ registeredSession .SetSessionPrompts (prompts )
1223+ }(i )
1224+ go func () {
1225+ defer wg .Done ()
1226+ _ = registeredSession .GetSessionTools ()
1227+ }()
1228+ }
1229+ wg .Wait ()
1230+
1231+ // Verify we can still get and set tools after concurrent access
1232+ finalPrompts := map [string ]ServerPrompt {
1233+ "final_prompt" : {
1234+ Prompt : mcp.Prompt {
1235+ Name : "final_prompt" ,
1236+ Description : "Final Prompt" ,
1237+ },
1238+ },
1239+ }
1240+ registeredSession .SetSessionPrompts (finalPrompts )
1241+ retrievedPrompts = registeredSession .GetSessionPrompts ()
1242+ if len (retrievedPrompts ) != 1 {
1243+ t .Errorf ("Expected 1 prompt, got %d" , len (retrievedPrompts ))
1244+ }
1245+ if _ , exists := retrievedPrompts ["final_prompt" ]; ! exists {
1246+ t .Error ("Expected final_prompt to exist" )
1247+ }
1248+ })
1249+
11441250 t .Run ("SessionWithTools implementation" , func (t * testing.T ) {
11451251 // Create hooks to track sessions
11461252 hooks := & Hooks {}
0 commit comments