@@ -276,4 +276,221 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) {
276276 if ! handlerSet {
277277 t .Error ("Request handler was not properly set or called" )
278278 }
279+ }
280+
281+ // TestStreamableHTTP_ConcurrentSamplingRequests tests concurrent sampling requests
282+ // where the second request completes faster than the first request
283+ func TestStreamableHTTP_ConcurrentSamplingRequests (t * testing.T ) {
284+ var receivedResponses []map [string ]interface {}
285+ var responseMutex sync.Mutex
286+ responseComplete := make (chan struct {}, 2 )
287+
288+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
289+ if r .Method == http .MethodPost {
290+ var body map [string ]interface {}
291+ if err := json .NewDecoder (r .Body ).Decode (& body ); err != nil {
292+ t .Logf ("Failed to decode body: %v" , err )
293+ w .WriteHeader (http .StatusBadRequest )
294+ return
295+ }
296+
297+ // Check if this is a response from client (not a request)
298+ if _ , ok := body ["result" ]; ok {
299+ responseMutex .Lock ()
300+ receivedResponses = append (receivedResponses , body )
301+ responseMutex .Unlock ()
302+ responseComplete <- struct {}{}
303+ }
304+ }
305+ w .WriteHeader (http .StatusOK )
306+ }))
307+ defer server .Close ()
308+
309+ client , err := NewStreamableHTTP (server .URL )
310+ if err != nil {
311+ t .Fatalf ("Failed to create client: %v" , err )
312+ }
313+ defer client .Close ()
314+
315+ // Track which requests have been received and their completion order
316+ var requestOrder []int
317+ var orderMutex sync.Mutex
318+
319+ // Set up request handler that simulates different processing times
320+ client .SetRequestHandler (func (ctx context.Context , request JSONRPCRequest ) (* JSONRPCResponse , error ) {
321+ // Extract request ID to determine processing time
322+ requestIDValue := request .ID .Value ()
323+
324+ var delay time.Duration
325+ var responseText string
326+ var requestNum int
327+
328+ // First request (ID 1) takes longer, second request (ID 2) completes faster
329+ if requestIDValue == int64 (1 ) {
330+ delay = 100 * time .Millisecond
331+ responseText = "Response from slow request 1"
332+ requestNum = 1
333+ } else if requestIDValue == int64 (2 ) {
334+ delay = 10 * time .Millisecond
335+ responseText = "Response from fast request 2"
336+ requestNum = 2
337+ } else {
338+ t .Errorf ("Unexpected request ID: %v" , requestIDValue )
339+ return nil , fmt .Errorf ("unexpected request ID" )
340+ }
341+
342+ // Simulate processing time
343+ time .Sleep (delay )
344+
345+ // Record completion order
346+ orderMutex .Lock ()
347+ requestOrder = append (requestOrder , requestNum )
348+ orderMutex .Unlock ()
349+
350+ // Return response with correct request ID
351+ result := map [string ]interface {}{
352+ "role" : "assistant" ,
353+ "content" : map [string ]interface {}{
354+ "type" : "text" ,
355+ "text" : responseText ,
356+ },
357+ "model" : "test-model" ,
358+ "stopReason" : "stop_sequence" ,
359+ }
360+
361+ resultBytes , _ := json .Marshal (result )
362+
363+ return & JSONRPCResponse {
364+ JSONRPC : "2.0" ,
365+ ID : request .ID ,
366+ Result : resultBytes ,
367+ }, nil
368+ })
369+
370+ // Start the client
371+ ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
372+ defer cancel ()
373+
374+ err = client .Start (ctx )
375+ if err != nil {
376+ t .Fatalf ("Failed to start client: %v" , err )
377+ }
378+
379+ // Create two sampling requests with different IDs
380+ request1 := JSONRPCRequest {
381+ JSONRPC : "2.0" ,
382+ ID : mcp .NewRequestId (int64 (1 )),
383+ Method : string (mcp .MethodSamplingCreateMessage ),
384+ Params : map [string ]interface {}{
385+ "messages" : []map [string ]interface {}{
386+ {
387+ "role" : "user" ,
388+ "content" : map [string ]interface {}{
389+ "type" : "text" ,
390+ "text" : "Slow request 1" ,
391+ },
392+ },
393+ },
394+ },
395+ }
396+
397+ request2 := JSONRPCRequest {
398+ JSONRPC : "2.0" ,
399+ ID : mcp .NewRequestId (int64 (2 )),
400+ Method : string (mcp .MethodSamplingCreateMessage ),
401+ Params : map [string ]interface {}{
402+ "messages" : []map [string ]interface {}{
403+ {
404+ "role" : "user" ,
405+ "content" : map [string ]interface {}{
406+ "type" : "text" ,
407+ "text" : "Fast request 2" ,
408+ },
409+ },
410+ },
411+ },
412+ }
413+
414+ // Send both requests concurrently
415+ go client .handleIncomingRequest (ctx , request1 )
416+ go client .handleIncomingRequest (ctx , request2 )
417+
418+ // Wait for both responses to complete
419+ for i := 0 ; i < 2 ; i ++ {
420+ select {
421+ case <- responseComplete :
422+ // Response received
423+ case <- time .After (2 * time .Second ):
424+ t .Fatal ("Timeout waiting for response" )
425+ }
426+ }
427+
428+ // Verify completion order: request 2 should complete first
429+ orderMutex .Lock ()
430+ defer orderMutex .Unlock ()
431+
432+ if len (requestOrder ) != 2 {
433+ t .Fatalf ("Expected 2 completed requests, got %d" , len (requestOrder ))
434+ }
435+
436+ if requestOrder [0 ] != 2 {
437+ t .Errorf ("Expected request 2 to complete first, but request %d completed first" , requestOrder [0 ])
438+ }
439+
440+ if requestOrder [1 ] != 1 {
441+ t .Errorf ("Expected request 1 to complete second, but request %d completed second" , requestOrder [1 ])
442+ }
443+
444+ // Verify responses are correctly associated
445+ responseMutex .Lock ()
446+ defer responseMutex .Unlock ()
447+
448+ if len (receivedResponses ) != 2 {
449+ t .Fatalf ("Expected 2 responses, got %d" , len (receivedResponses ))
450+ }
451+
452+ // Find responses by ID
453+ var response1 , response2 map [string ]interface {}
454+ for _ , resp := range receivedResponses {
455+ if id , ok := resp ["id" ]; ok {
456+ switch id {
457+ case int64 (1 ), float64 (1 ):
458+ response1 = resp
459+ case int64 (2 ), float64 (2 ):
460+ response2 = resp
461+ }
462+ }
463+ }
464+
465+ if response1 == nil {
466+ t .Error ("Response for request 1 not found" )
467+ }
468+ if response2 == nil {
469+ t .Error ("Response for request 2 not found" )
470+ }
471+
472+ // Verify each response contains the correct content
473+ if response1 != nil {
474+ if result , ok := response1 ["result" ].(map [string ]interface {}); ok {
475+ if content , ok := result ["content" ].(map [string ]interface {}); ok {
476+ if text , ok := content ["text" ].(string ); ok {
477+ if ! strings .Contains (text , "slow request 1" ) {
478+ t .Errorf ("Response 1 should contain 'slow request 1', got: %s" , text )
479+ }
480+ }
481+ }
482+ }
483+ }
484+
485+ if response2 != nil {
486+ if result , ok := response2 ["result" ].(map [string ]interface {}); ok {
487+ if content , ok := result ["content" ].(map [string ]interface {}); ok {
488+ if text , ok := content ["text" ].(string ); ok {
489+ if ! strings .Contains (text , "fast request 2" ) {
490+ t .Errorf ("Response 2 should contain 'fast request 2', got: %s" , text )
491+ }
492+ }
493+ }
494+ }
495+ }
279496}
0 commit comments