Skip to content

Commit 0a1a9e9

Browse files
committed
add test for continuesly listening
1 parent 6a05fc6 commit 0a1a9e9

File tree

2 files changed

+262
-2
lines changed

2 files changed

+262
-2
lines changed

client/transport/streamable_http.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,11 +456,14 @@ func (c *StreamableHTTP) listenForever() {
456456
if err != nil {
457457
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
458458
}
459-
time.Sleep(1 * time.Second)
459+
time.Sleep(retryInterval)
460460
}
461461
}
462462

463-
var errGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
463+
var (
464+
errGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
465+
retryInterval = 1 * time.Second // a variable is convenient for testing
466+
)
464467

465468
func (c *StreamableHTTP) createGETConnectionToServer() error {
466469

client/transport/streamable_http_test.go

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"net/http"
99
"net/http/httptest"
10+
"strings"
1011
"sync"
1112
"testing"
1213
"time"
@@ -448,3 +449,259 @@ func TestStreamableHTTPErrors(t *testing.T) {
448449
})
449450

450451
}
452+
453+
// ---- continuous listening tests ----
454+
455+
// startMockStreamableWithGETSupport starts a test HTTP server that implements
456+
// a minimal Streamable HTTP server for testing purposes with support for GET requests
457+
// to test the continuous listening feature.
458+
func startMockStreamableWithGETSupport(getSupport bool) (string, func(), chan bool, int) {
459+
var sessionID string
460+
var mu sync.Mutex
461+
disconnectCh := make(chan bool, 1)
462+
notificationCount := 0
463+
var notificationMu sync.Mutex
464+
465+
sendNotification := func() {
466+
notificationMu.Lock()
467+
notificationCount++
468+
notificationMu.Unlock()
469+
}
470+
471+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
472+
// Handle POST requests for initialization
473+
if r.Method == http.MethodPost {
474+
// Parse incoming JSON-RPC request
475+
var request map[string]any
476+
decoder := json.NewDecoder(r.Body)
477+
if err := decoder.Decode(&request); err != nil {
478+
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
479+
return
480+
}
481+
482+
method := request["method"]
483+
if method == "initialize" {
484+
// Generate a new session ID
485+
mu.Lock()
486+
sessionID = fmt.Sprintf("test-session-%d", time.Now().UnixNano())
487+
mu.Unlock()
488+
w.Header().Set("Mcp-Session-Id", sessionID)
489+
w.Header().Set("Content-Type", "application/json")
490+
w.WriteHeader(http.StatusAccepted)
491+
if err := json.NewEncoder(w).Encode(map[string]any{
492+
"jsonrpc": "2.0",
493+
"id": request["id"],
494+
"result": "initialized",
495+
}); err != nil {
496+
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
497+
return
498+
}
499+
}
500+
return
501+
}
502+
503+
// Handle GET requests for continuous listening
504+
if r.Method == http.MethodGet {
505+
if !getSupport {
506+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
507+
return
508+
}
509+
510+
// Check session ID
511+
if recvSessionID := r.Header.Get("Mcp-Session-Id"); recvSessionID != sessionID {
512+
http.Error(w, "Invalid session ID", http.StatusNotFound)
513+
return
514+
}
515+
516+
// Setup SSE connection
517+
w.Header().Set("Content-Type", "text/event-stream")
518+
w.WriteHeader(http.StatusOK)
519+
flusher, ok := w.(http.Flusher)
520+
if !ok {
521+
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
522+
return
523+
}
524+
525+
// Send a notification
526+
notification := map[string]any{
527+
"jsonrpc": "2.0",
528+
"method": "test/notification",
529+
"params": map[string]any{"message": "Hello from server"},
530+
}
531+
notificationData, _ := json.Marshal(notification)
532+
fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData)
533+
flusher.Flush()
534+
sendNotification()
535+
536+
// Keep the connection open or disconnect as requested
537+
select {
538+
case <-disconnectCh:
539+
// Force disconnect
540+
return
541+
case <-r.Context().Done():
542+
// Client disconnected
543+
return
544+
case <-time.After(50 * time.Millisecond):
545+
// Send another notification
546+
notification = map[string]any{
547+
"jsonrpc": "2.0",
548+
"method": "test/notification",
549+
"params": map[string]any{"message": "Second notification"},
550+
}
551+
notificationData, _ = json.Marshal(notification)
552+
fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData)
553+
flusher.Flush()
554+
sendNotification()
555+
return
556+
}
557+
} else {
558+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
559+
return
560+
}
561+
})
562+
563+
// Start test server
564+
testServer := httptest.NewServer(handler)
565+
566+
notificationMu.Lock()
567+
defer notificationMu.Unlock()
568+
569+
return testServer.URL, testServer.Close, disconnectCh, notificationCount
570+
}
571+
572+
func TestContinuousListening(t *testing.T) {
573+
retryInterval = 10 * time.Millisecond
574+
// Start mock server with GET support
575+
url, closeServer, disconnectCh, _ := startMockStreamableWithGETSupport(true)
576+
577+
// Create transport with continuous listening enabled
578+
trans, err := NewStreamableHTTP(url, WithContinuousListening())
579+
if err != nil {
580+
t.Fatal(err)
581+
}
582+
583+
// Ensure transport is closed before server to avoid connection refused errors
584+
defer func() {
585+
trans.Close()
586+
closeServer()
587+
}()
588+
589+
// Setup notification handler
590+
notificationReceived := make(chan struct{}, 10)
591+
trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) {
592+
notificationReceived <- struct{}{}
593+
})
594+
595+
// Initialize the transport first
596+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
597+
defer cancel()
598+
599+
initRequest := JSONRPCRequest{
600+
JSONRPC: "2.0",
601+
ID: mcp.NewRequestId(int64(0)),
602+
Method: "initialize",
603+
}
604+
605+
_, err = trans.SendRequest(ctx, initRequest)
606+
if err != nil {
607+
t.Fatal(err)
608+
}
609+
610+
// Start the transport - this will launch listenForever in a goroutine
611+
if err := trans.Start(context.Background()); err != nil {
612+
t.Fatal(err)
613+
}
614+
615+
// Wait for notifications to be received
616+
notificationCount := 0
617+
for notificationCount < 2 {
618+
select {
619+
case <-notificationReceived:
620+
notificationCount++
621+
case <-time.After(3 * time.Second):
622+
t.Fatalf("Timed out waiting for notifications, received %d", notificationCount)
623+
return
624+
}
625+
}
626+
627+
// Test server disconnect and reconnect
628+
disconnectCh <- true
629+
time.Sleep(50 * time.Millisecond) // Allow time for reconnection
630+
631+
// Verify reconnect occurred by receiving more notifications
632+
reconnectNotificationCount := 0
633+
for reconnectNotificationCount < 2 {
634+
select {
635+
case <-notificationReceived:
636+
reconnectNotificationCount++
637+
case <-time.After(3 * time.Second):
638+
t.Fatalf("Timed out waiting for notifications after reconnect")
639+
return
640+
}
641+
}
642+
}
643+
644+
func TestContinuousListeningMethodNotAllowed(t *testing.T) {
645+
646+
// Start a server that doesn't support GET
647+
url, closeServer, _, _ := startMockStreamableWithGETSupport(false)
648+
649+
// Setup logger to capture log messages
650+
logChan := make(chan string, 10)
651+
testLogger := &testLogger{logChan: logChan}
652+
653+
// Create transport with continuous listening enabled and custom logger
654+
trans, err := NewStreamableHTTP(url, WithContinuousListening(), WithLogger(testLogger))
655+
if err != nil {
656+
t.Fatal(err)
657+
}
658+
659+
// Ensure transport is closed before server to avoid connection refused errors
660+
defer func() {
661+
trans.Close()
662+
closeServer()
663+
}()
664+
665+
// Initialize the transport first
666+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
667+
defer cancel()
668+
669+
// Start the transport
670+
if err := trans.Start(context.Background()); err != nil {
671+
t.Fatal(err)
672+
}
673+
674+
initRequest := JSONRPCRequest{
675+
JSONRPC: "2.0",
676+
ID: mcp.NewRequestId(int64(0)),
677+
Method: "initialize",
678+
}
679+
680+
_, err = trans.SendRequest(ctx, initRequest)
681+
if err != nil {
682+
t.Fatal(err)
683+
}
684+
685+
// Wait for the error log message that server doesn't support listening
686+
select {
687+
case logMsg := <-logChan:
688+
if !strings.Contains(logMsg, "server does not support listening") {
689+
t.Errorf("Expected error log about server not supporting listening, got: %s", logMsg)
690+
}
691+
case <-time.After(5 * time.Second):
692+
t.Fatal("Timeout waiting for log message")
693+
}
694+
}
695+
696+
// testLogger is a simple logger for testing
697+
type testLogger struct {
698+
logChan chan string
699+
}
700+
701+
func (l *testLogger) Infof(format string, args ...any) {
702+
// Intentionally left empty
703+
}
704+
705+
func (l *testLogger) Errorf(format string, args ...any) {
706+
l.logChan <- fmt.Sprintf(format, args...)
707+
}

0 commit comments

Comments
 (0)