diff --git a/mcp/streamable.go b/mcp/streamable.go index b4b2fa31..ece319d7 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1389,6 +1389,26 @@ type StreamableClientTransport struct { // It defaults to 5. To disable retries, use a negative number. MaxRetries int + // DisableStandaloneSSE controls whether the client establishes a standalone SSE stream + // for receiving server-initiated messages. + // + // When false (the default), after initialization the client sends an HTTP GET request + // to establish a persistent server-sent events (SSE) connection. This allows the server + // to send messages to the client at any time, such as ToolListChangedNotification or + // other server-initiated requests and notifications. The connection persists for the + // lifetime of the session and automatically reconnects if interrupted. + // + // When true, the client does not establish the standalone SSE stream. The client will + // only receive responses to its own POST requests. Server-initiated messages will not + // be received. + // + // According to the MCP specification, the standalone SSE stream is optional. + // Setting DisableStandaloneSSE to true is useful when: + // - You only need request-response communication and don't need server-initiated notifications + // - The server doesn't properly handle GET requests for SSE streams + // - You want to avoid maintaining a persistent connection + DisableStandaloneSSE bool + // TODO(rfindley): propose exporting these. // If strict is set, the transport is in 'strict mode', where any violation // of the MCP spec causes a failure. @@ -1453,16 +1473,17 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // middleware), yet only cancel the standalone stream when the connection is closed. connCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) conn := &streamableClientConn{ - url: t.Endpoint, - client: client, - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - maxRetries: maxRetries, - strict: t.strict, - logger: ensureLogger(t.logger), // must be non-nil for safe logging - ctx: connCtx, - cancel: cancel, - failed: make(chan struct{}), + url: t.Endpoint, + client: client, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + maxRetries: maxRetries, + strict: t.strict, + logger: ensureLogger(t.logger), // must be non-nil for safe logging + ctx: connCtx, + cancel: cancel, + failed: make(chan struct{}), + disableStandaloneSSE: t.DisableStandaloneSSE, } return conn, nil } @@ -1477,6 +1498,10 @@ type streamableClientConn struct { strict bool // from [StreamableClientTransport.strict] logger *slog.Logger // from [StreamableClientTransport.logger] + // disableStandaloneSSE controls whether to disable the standalone SSE stream + // for receiving server-to-client notifications when no request is in flight. + disableStandaloneSSE bool // from [StreamableClientTransport.DisableStandaloneSSE] + // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once closeErr error @@ -1518,7 +1543,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { c.mu.Unlock() // Start the standalone SSE stream as soon as we have the initialized - // result. + // result, if continuous listening is enabled. // // § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. This can be // used to open an SSE stream, allowing the server to communicate to the @@ -1528,9 +1553,11 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // initialized, we don't know whether the server requires a sessionID. // // § 2.5: A server using the Streamable HTTP transport MAY assign a session - // ID at initialization time, by including it in an Mcp-Session-Id header + // ID at initialization time, by including it in a Mcp-Session-Id header // on the HTTP response containing the InitializeResult. - c.connectStandaloneSSE() + if !c.disableStandaloneSSE { + c.connectStandaloneSSE() + } } func (c *streamableClientConn) connectStandaloneSSE() { diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index e2923325..0674171d 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -693,3 +694,121 @@ func TestStreamableClientTransientErrors(t *testing.T) { }) } } + +func TestStreamableClientDisableStandaloneSSE(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + disableStandaloneSSE bool + expectGETRequest bool + }{ + { + name: "default behavior (standalone SSE enabled)", + disableStandaloneSSE: false, + expectGETRequest: true, + }, + { + name: "standalone SSE disabled", + disableStandaloneSSE: true, + expectGETRequest: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + getRequestKey := streamableRequestKey{"GET", "123", "", ""} + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + getRequestKey: { + header: header{ + "Content-Type": "text/event-stream", + }, + wantProtocolVersion: latestProtocolVersion, + optional: !test.expectGETRequest, + }, + {"DELETE", "123", "", ""}: { + optional: true, + }, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + DisableStandaloneSSE: test.disableStandaloneSSE, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + + // Give some time for the standalone SSE connection to be established (if enabled) + time.Sleep(100 * time.Millisecond) + + // Verify the connection state + streamableConn, ok := session.mcpConn.(*streamableClientConn) + if !ok { + t.Fatalf("Expected *streamableClientConn, got %T", session.mcpConn) + } + + if got, want := streamableConn.disableStandaloneSSE, test.disableStandaloneSSE; got != want { + t.Errorf("disableStandaloneSSE field: got %v, want %v", got, want) + } + + // Clean up + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + + // Check if GET request was received + fake.calledMu.Lock() + getRequestReceived := false + if fake.called != nil { + getRequestReceived = fake.called[getRequestKey] + } + fake.calledMu.Unlock() + + if got, want := getRequestReceived, test.expectGETRequest; got != want { + t.Errorf("GET request received: got %v, want %v", got, want) + } + + // If we expected a GET request, verify it was actually received + if test.expectGETRequest { + if missing := fake.missingRequests(); len(missing) > 0 { + // Filter out optional requests + var requiredMissing []streamableRequestKey + for _, key := range missing { + if resp, ok := fake.responses[key]; ok && !resp.optional { + requiredMissing = append(requiredMissing, key) + } + } + if len(requiredMissing) > 0 { + t.Errorf("did not receive expected requests: %v", requiredMissing) + } + } + } else { + // If we didn't expect a GET request, verify it wasn't sent + if getRequestReceived { + t.Error("GET request was sent unexpectedly when DisableStandaloneSSE is true") + } + } + }) + } +}