|
7 | 7 | "fmt" |
8 | 8 | "net/http" |
9 | 9 | "net/http/httptest" |
| 10 | + "strings" |
10 | 11 | "sync" |
11 | 12 | "testing" |
12 | 13 | "time" |
@@ -448,3 +449,259 @@ func TestStreamableHTTPErrors(t *testing.T) { |
448 | 449 | }) |
449 | 450 |
|
450 | 451 | } |
| 452 | + |
| 453 | +// ---- continuous listening tests ---- |
| 454 | + |
| 455 | +// startMockStreamableWithGETSupport starts a test HTTP server that implements |
| 456 | +// a minimal Streamable HTTP server for testing purposes with support for GET requests |
| 457 | +// to test the continuous listening feature. |
| 458 | +func startMockStreamableWithGETSupport(getSupport bool) (string, func(), chan bool, int) { |
| 459 | + var sessionID string |
| 460 | + var mu sync.Mutex |
| 461 | + disconnectCh := make(chan bool, 1) |
| 462 | + notificationCount := 0 |
| 463 | + var notificationMu sync.Mutex |
| 464 | + |
| 465 | + sendNotification := func() { |
| 466 | + notificationMu.Lock() |
| 467 | + notificationCount++ |
| 468 | + notificationMu.Unlock() |
| 469 | + } |
| 470 | + |
| 471 | + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 472 | + // Handle POST requests for initialization |
| 473 | + if r.Method == http.MethodPost { |
| 474 | + // Parse incoming JSON-RPC request |
| 475 | + var request map[string]any |
| 476 | + decoder := json.NewDecoder(r.Body) |
| 477 | + if err := decoder.Decode(&request); err != nil { |
| 478 | + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) |
| 479 | + return |
| 480 | + } |
| 481 | + |
| 482 | + method := request["method"] |
| 483 | + if method == "initialize" { |
| 484 | + // Generate a new session ID |
| 485 | + mu.Lock() |
| 486 | + sessionID = fmt.Sprintf("test-session-%d", time.Now().UnixNano()) |
| 487 | + mu.Unlock() |
| 488 | + w.Header().Set("Mcp-Session-Id", sessionID) |
| 489 | + w.Header().Set("Content-Type", "application/json") |
| 490 | + w.WriteHeader(http.StatusAccepted) |
| 491 | + if err := json.NewEncoder(w).Encode(map[string]any{ |
| 492 | + "jsonrpc": "2.0", |
| 493 | + "id": request["id"], |
| 494 | + "result": "initialized", |
| 495 | + }); err != nil { |
| 496 | + http.Error(w, "Failed to encode response", http.StatusInternalServerError) |
| 497 | + return |
| 498 | + } |
| 499 | + } |
| 500 | + return |
| 501 | + } |
| 502 | + |
| 503 | + // Handle GET requests for continuous listening |
| 504 | + if r.Method == http.MethodGet { |
| 505 | + if !getSupport { |
| 506 | + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) |
| 507 | + return |
| 508 | + } |
| 509 | + |
| 510 | + // Check session ID |
| 511 | + if recvSessionID := r.Header.Get("Mcp-Session-Id"); recvSessionID != sessionID { |
| 512 | + http.Error(w, "Invalid session ID", http.StatusNotFound) |
| 513 | + return |
| 514 | + } |
| 515 | + |
| 516 | + // Setup SSE connection |
| 517 | + w.Header().Set("Content-Type", "text/event-stream") |
| 518 | + w.WriteHeader(http.StatusOK) |
| 519 | + flusher, ok := w.(http.Flusher) |
| 520 | + if !ok { |
| 521 | + http.Error(w, "Streaming not supported", http.StatusInternalServerError) |
| 522 | + return |
| 523 | + } |
| 524 | + |
| 525 | + // Send a notification |
| 526 | + notification := map[string]any{ |
| 527 | + "jsonrpc": "2.0", |
| 528 | + "method": "test/notification", |
| 529 | + "params": map[string]any{"message": "Hello from server"}, |
| 530 | + } |
| 531 | + notificationData, _ := json.Marshal(notification) |
| 532 | + fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData) |
| 533 | + flusher.Flush() |
| 534 | + sendNotification() |
| 535 | + |
| 536 | + // Keep the connection open or disconnect as requested |
| 537 | + select { |
| 538 | + case <-disconnectCh: |
| 539 | + // Force disconnect |
| 540 | + return |
| 541 | + case <-r.Context().Done(): |
| 542 | + // Client disconnected |
| 543 | + return |
| 544 | + case <-time.After(50 * time.Millisecond): |
| 545 | + // Send another notification |
| 546 | + notification = map[string]any{ |
| 547 | + "jsonrpc": "2.0", |
| 548 | + "method": "test/notification", |
| 549 | + "params": map[string]any{"message": "Second notification"}, |
| 550 | + } |
| 551 | + notificationData, _ = json.Marshal(notification) |
| 552 | + fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData) |
| 553 | + flusher.Flush() |
| 554 | + sendNotification() |
| 555 | + return |
| 556 | + } |
| 557 | + } else { |
| 558 | + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) |
| 559 | + return |
| 560 | + } |
| 561 | + }) |
| 562 | + |
| 563 | + // Start test server |
| 564 | + testServer := httptest.NewServer(handler) |
| 565 | + |
| 566 | + notificationMu.Lock() |
| 567 | + defer notificationMu.Unlock() |
| 568 | + |
| 569 | + return testServer.URL, testServer.Close, disconnectCh, notificationCount |
| 570 | +} |
| 571 | + |
| 572 | +func TestContinuousListening(t *testing.T) { |
| 573 | + retryInterval = 10 * time.Millisecond |
| 574 | + // Start mock server with GET support |
| 575 | + url, closeServer, disconnectCh, _ := startMockStreamableWithGETSupport(true) |
| 576 | + |
| 577 | + // Create transport with continuous listening enabled |
| 578 | + trans, err := NewStreamableHTTP(url, WithContinuousListening()) |
| 579 | + if err != nil { |
| 580 | + t.Fatal(err) |
| 581 | + } |
| 582 | + |
| 583 | + // Ensure transport is closed before server to avoid connection refused errors |
| 584 | + defer func() { |
| 585 | + trans.Close() |
| 586 | + closeServer() |
| 587 | + }() |
| 588 | + |
| 589 | + // Setup notification handler |
| 590 | + notificationReceived := make(chan struct{}, 10) |
| 591 | + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { |
| 592 | + notificationReceived <- struct{}{} |
| 593 | + }) |
| 594 | + |
| 595 | + // Initialize the transport first |
| 596 | + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
| 597 | + defer cancel() |
| 598 | + |
| 599 | + initRequest := JSONRPCRequest{ |
| 600 | + JSONRPC: "2.0", |
| 601 | + ID: mcp.NewRequestId(int64(0)), |
| 602 | + Method: "initialize", |
| 603 | + } |
| 604 | + |
| 605 | + _, err = trans.SendRequest(ctx, initRequest) |
| 606 | + if err != nil { |
| 607 | + t.Fatal(err) |
| 608 | + } |
| 609 | + |
| 610 | + // Start the transport - this will launch listenForever in a goroutine |
| 611 | + if err := trans.Start(context.Background()); err != nil { |
| 612 | + t.Fatal(err) |
| 613 | + } |
| 614 | + |
| 615 | + // Wait for notifications to be received |
| 616 | + notificationCount := 0 |
| 617 | + for notificationCount < 2 { |
| 618 | + select { |
| 619 | + case <-notificationReceived: |
| 620 | + notificationCount++ |
| 621 | + case <-time.After(3 * time.Second): |
| 622 | + t.Fatalf("Timed out waiting for notifications, received %d", notificationCount) |
| 623 | + return |
| 624 | + } |
| 625 | + } |
| 626 | + |
| 627 | + // Test server disconnect and reconnect |
| 628 | + disconnectCh <- true |
| 629 | + time.Sleep(50 * time.Millisecond) // Allow time for reconnection |
| 630 | + |
| 631 | + // Verify reconnect occurred by receiving more notifications |
| 632 | + reconnectNotificationCount := 0 |
| 633 | + for reconnectNotificationCount < 2 { |
| 634 | + select { |
| 635 | + case <-notificationReceived: |
| 636 | + reconnectNotificationCount++ |
| 637 | + case <-time.After(3 * time.Second): |
| 638 | + t.Fatalf("Timed out waiting for notifications after reconnect") |
| 639 | + return |
| 640 | + } |
| 641 | + } |
| 642 | +} |
| 643 | + |
| 644 | +func TestContinuousListeningMethodNotAllowed(t *testing.T) { |
| 645 | + |
| 646 | + // Start a server that doesn't support GET |
| 647 | + url, closeServer, _, _ := startMockStreamableWithGETSupport(false) |
| 648 | + |
| 649 | + // Setup logger to capture log messages |
| 650 | + logChan := make(chan string, 10) |
| 651 | + testLogger := &testLogger{logChan: logChan} |
| 652 | + |
| 653 | + // Create transport with continuous listening enabled and custom logger |
| 654 | + trans, err := NewStreamableHTTP(url, WithContinuousListening(), WithLogger(testLogger)) |
| 655 | + if err != nil { |
| 656 | + t.Fatal(err) |
| 657 | + } |
| 658 | + |
| 659 | + // Ensure transport is closed before server to avoid connection refused errors |
| 660 | + defer func() { |
| 661 | + trans.Close() |
| 662 | + closeServer() |
| 663 | + }() |
| 664 | + |
| 665 | + // Initialize the transport first |
| 666 | + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
| 667 | + defer cancel() |
| 668 | + |
| 669 | + // Start the transport |
| 670 | + if err := trans.Start(context.Background()); err != nil { |
| 671 | + t.Fatal(err) |
| 672 | + } |
| 673 | + |
| 674 | + initRequest := JSONRPCRequest{ |
| 675 | + JSONRPC: "2.0", |
| 676 | + ID: mcp.NewRequestId(int64(0)), |
| 677 | + Method: "initialize", |
| 678 | + } |
| 679 | + |
| 680 | + _, err = trans.SendRequest(ctx, initRequest) |
| 681 | + if err != nil { |
| 682 | + t.Fatal(err) |
| 683 | + } |
| 684 | + |
| 685 | + // Wait for the error log message that server doesn't support listening |
| 686 | + select { |
| 687 | + case logMsg := <-logChan: |
| 688 | + if !strings.Contains(logMsg, "server does not support listening") { |
| 689 | + t.Errorf("Expected error log about server not supporting listening, got: %s", logMsg) |
| 690 | + } |
| 691 | + case <-time.After(5 * time.Second): |
| 692 | + t.Fatal("Timeout waiting for log message") |
| 693 | + } |
| 694 | +} |
| 695 | + |
| 696 | +// testLogger is a simple logger for testing |
| 697 | +type testLogger struct { |
| 698 | + logChan chan string |
| 699 | +} |
| 700 | + |
| 701 | +func (l *testLogger) Infof(format string, args ...any) { |
| 702 | + // Intentionally left empty |
| 703 | +} |
| 704 | + |
| 705 | +func (l *testLogger) Errorf(format string, args ...any) { |
| 706 | + l.logChan <- fmt.Sprintf(format, args...) |
| 707 | +} |
0 commit comments