Skip to content

Commit e69716d

Browse files
andigclaude
andcommitted
test: add concurrent sampling requests test with response association
Add test verifying that concurrent sampling requests are handled correctly when the second request completes faster than the first. The test ensures: - Responses are correctly associated with their request IDs - Server processes requests concurrently without blocking - Completion order follows actual processing time, not submission order 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent bac5dad commit e69716d

File tree

1 file changed

+217
-0
lines changed

1 file changed

+217
-0
lines changed

client/transport/streamable_http_sampling_test.go

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,4 +276,221 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) {
276276
if !handlerSet {
277277
t.Error("Request handler was not properly set or called")
278278
}
279+
}
280+
281+
// TestStreamableHTTP_ConcurrentSamplingRequests tests concurrent sampling requests
282+
// where the second request completes faster than the first request
283+
func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) {
284+
var receivedResponses []map[string]interface{}
285+
var responseMutex sync.Mutex
286+
responseComplete := make(chan struct{}, 2)
287+
288+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
289+
if r.Method == http.MethodPost {
290+
var body map[string]interface{}
291+
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
292+
t.Logf("Failed to decode body: %v", err)
293+
w.WriteHeader(http.StatusBadRequest)
294+
return
295+
}
296+
297+
// Check if this is a response from client (not a request)
298+
if _, ok := body["result"]; ok {
299+
responseMutex.Lock()
300+
receivedResponses = append(receivedResponses, body)
301+
responseMutex.Unlock()
302+
responseComplete <- struct{}{}
303+
}
304+
}
305+
w.WriteHeader(http.StatusOK)
306+
}))
307+
defer server.Close()
308+
309+
client, err := NewStreamableHTTP(server.URL)
310+
if err != nil {
311+
t.Fatalf("Failed to create client: %v", err)
312+
}
313+
defer client.Close()
314+
315+
// Track which requests have been received and their completion order
316+
var requestOrder []int
317+
var orderMutex sync.Mutex
318+
319+
// Set up request handler that simulates different processing times
320+
client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) {
321+
// Extract request ID to determine processing time
322+
requestIDValue := request.ID.Value()
323+
324+
var delay time.Duration
325+
var responseText string
326+
var requestNum int
327+
328+
// First request (ID 1) takes longer, second request (ID 2) completes faster
329+
if requestIDValue == int64(1) {
330+
delay = 100 * time.Millisecond
331+
responseText = "Response from slow request 1"
332+
requestNum = 1
333+
} else if requestIDValue == int64(2) {
334+
delay = 10 * time.Millisecond
335+
responseText = "Response from fast request 2"
336+
requestNum = 2
337+
} else {
338+
t.Errorf("Unexpected request ID: %v", requestIDValue)
339+
return nil, fmt.Errorf("unexpected request ID")
340+
}
341+
342+
// Simulate processing time
343+
time.Sleep(delay)
344+
345+
// Record completion order
346+
orderMutex.Lock()
347+
requestOrder = append(requestOrder, requestNum)
348+
orderMutex.Unlock()
349+
350+
// Return response with correct request ID
351+
result := map[string]interface{}{
352+
"role": "assistant",
353+
"content": map[string]interface{}{
354+
"type": "text",
355+
"text": responseText,
356+
},
357+
"model": "test-model",
358+
"stopReason": "stop_sequence",
359+
}
360+
361+
resultBytes, _ := json.Marshal(result)
362+
363+
return &JSONRPCResponse{
364+
JSONRPC: "2.0",
365+
ID: request.ID,
366+
Result: resultBytes,
367+
}, nil
368+
})
369+
370+
// Start the client
371+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
372+
defer cancel()
373+
374+
err = client.Start(ctx)
375+
if err != nil {
376+
t.Fatalf("Failed to start client: %v", err)
377+
}
378+
379+
// Create two sampling requests with different IDs
380+
request1 := JSONRPCRequest{
381+
JSONRPC: "2.0",
382+
ID: mcp.NewRequestId(int64(1)),
383+
Method: string(mcp.MethodSamplingCreateMessage),
384+
Params: map[string]interface{}{
385+
"messages": []map[string]interface{}{
386+
{
387+
"role": "user",
388+
"content": map[string]interface{}{
389+
"type": "text",
390+
"text": "Slow request 1",
391+
},
392+
},
393+
},
394+
},
395+
}
396+
397+
request2 := JSONRPCRequest{
398+
JSONRPC: "2.0",
399+
ID: mcp.NewRequestId(int64(2)),
400+
Method: string(mcp.MethodSamplingCreateMessage),
401+
Params: map[string]interface{}{
402+
"messages": []map[string]interface{}{
403+
{
404+
"role": "user",
405+
"content": map[string]interface{}{
406+
"type": "text",
407+
"text": "Fast request 2",
408+
},
409+
},
410+
},
411+
},
412+
}
413+
414+
// Send both requests concurrently
415+
go client.handleIncomingRequest(ctx, request1)
416+
go client.handleIncomingRequest(ctx, request2)
417+
418+
// Wait for both responses to complete
419+
for i := 0; i < 2; i++ {
420+
select {
421+
case <-responseComplete:
422+
// Response received
423+
case <-time.After(2 * time.Second):
424+
t.Fatal("Timeout waiting for response")
425+
}
426+
}
427+
428+
// Verify completion order: request 2 should complete first
429+
orderMutex.Lock()
430+
defer orderMutex.Unlock()
431+
432+
if len(requestOrder) != 2 {
433+
t.Fatalf("Expected 2 completed requests, got %d", len(requestOrder))
434+
}
435+
436+
if requestOrder[0] != 2 {
437+
t.Errorf("Expected request 2 to complete first, but request %d completed first", requestOrder[0])
438+
}
439+
440+
if requestOrder[1] != 1 {
441+
t.Errorf("Expected request 1 to complete second, but request %d completed second", requestOrder[1])
442+
}
443+
444+
// Verify responses are correctly associated
445+
responseMutex.Lock()
446+
defer responseMutex.Unlock()
447+
448+
if len(receivedResponses) != 2 {
449+
t.Fatalf("Expected 2 responses, got %d", len(receivedResponses))
450+
}
451+
452+
// Find responses by ID
453+
var response1, response2 map[string]interface{}
454+
for _, resp := range receivedResponses {
455+
if id, ok := resp["id"]; ok {
456+
switch id {
457+
case int64(1), float64(1):
458+
response1 = resp
459+
case int64(2), float64(2):
460+
response2 = resp
461+
}
462+
}
463+
}
464+
465+
if response1 == nil {
466+
t.Error("Response for request 1 not found")
467+
}
468+
if response2 == nil {
469+
t.Error("Response for request 2 not found")
470+
}
471+
472+
// Verify each response contains the correct content
473+
if response1 != nil {
474+
if result, ok := response1["result"].(map[string]interface{}); ok {
475+
if content, ok := result["content"].(map[string]interface{}); ok {
476+
if text, ok := content["text"].(string); ok {
477+
if !strings.Contains(text, "slow request 1") {
478+
t.Errorf("Response 1 should contain 'slow request 1', got: %s", text)
479+
}
480+
}
481+
}
482+
}
483+
}
484+
485+
if response2 != nil {
486+
if result, ok := response2["result"].(map[string]interface{}); ok {
487+
if content, ok := result["content"].(map[string]interface{}); ok {
488+
if text, ok := content["text"].(string); ok {
489+
if !strings.Contains(text, "fast request 2") {
490+
t.Errorf("Response 2 should contain 'fast request 2', got: %s", text)
491+
}
492+
}
493+
}
494+
}
495+
}
279496
}

0 commit comments

Comments
 (0)