Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,26 @@ type StreamableClientTransport struct {
// It defaults to 5. To disable retries, use a negative number.
MaxRetries int

// DisableStandaloneSSE controls whether the client establishes a standalone SSE stream
// for receiving server-initiated messages.
//
// When false (the default), after initialization the client sends an HTTP GET request
// to establish a persistent server-sent events (SSE) connection. This allows the server
// to send messages to the client at any time, such as ToolListChangedNotification or
// other server-initiated requests and notifications. The connection persists for the
// lifetime of the session and automatically reconnects if interrupted.
//
// When true, the client does not establish the standalone SSE stream. The client will
// only receive responses to its own POST requests. Server-initiated messages will not
// be received.
//
// According to the MCP specification, the standalone SSE stream is optional.
// Setting DisableStandaloneSSE to true is useful when:
// - You only need request-response communication and don't need server-initiated notifications
// - The server doesn't properly handle GET requests for SSE streams
// - You want to avoid maintaining a persistent connection
DisableStandaloneSSE bool

// TODO(rfindley): propose exporting these.
// If strict is set, the transport is in 'strict mode', where any violation
// of the MCP spec causes a failure.
Expand Down Expand Up @@ -1453,16 +1473,17 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
// middleware), yet only cancel the standalone stream when the connection is closed.
connCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
conn := &streamableClientConn{
url: t.Endpoint,
client: client,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: ensureLogger(t.logger), // must be non-nil for safe logging
ctx: connCtx,
cancel: cancel,
failed: make(chan struct{}),
url: t.Endpoint,
client: client,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: ensureLogger(t.logger), // must be non-nil for safe logging
ctx: connCtx,
cancel: cancel,
failed: make(chan struct{}),
disableStandaloneSSE: t.DisableStandaloneSSE,
}
return conn, nil
}
Expand All @@ -1477,6 +1498,10 @@ type streamableClientConn struct {
strict bool // from [StreamableClientTransport.strict]
logger *slog.Logger // from [StreamableClientTransport.logger]

// disableStandaloneSSE controls whether to disable the standalone SSE stream
// for receiving server-to-client notifications when no request is in flight.
disableStandaloneSSE bool // from [StreamableClientTransport.DisableStandaloneSSE]

// Guard calls to Close, as it may be called multiple times.
closeOnce sync.Once
closeErr error
Expand Down Expand Up @@ -1518,7 +1543,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
c.mu.Unlock()

// Start the standalone SSE stream as soon as we have the initialized
// result.
// result, if continuous listening is enabled.
//
// § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. This can be
// used to open an SSE stream, allowing the server to communicate to the
Expand All @@ -1528,9 +1553,11 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
// initialized, we don't know whether the server requires a sessionID.
//
// § 2.5: A server using the Streamable HTTP transport MAY assign a session
// ID at initialization time, by including it in an Mcp-Session-Id header
// ID at initialization time, by including it in a Mcp-Session-Id header
// on the HTTP response containing the InitializeResult.
c.connectStandaloneSSE()
if !c.disableStandaloneSSE {
c.connectStandaloneSSE()
}
}

func (c *streamableClientConn) connectStandaloneSSE() {
Expand Down
119 changes: 119 additions & 0 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"time"

"github.com/google/go-cmp/cmp"

"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)
Expand Down Expand Up @@ -693,3 +694,121 @@ func TestStreamableClientTransientErrors(t *testing.T) {
})
}
}

func TestStreamableClientDisableStandaloneSSE(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
disableStandaloneSSE bool
expectGETRequest bool
}{
{
name: "default behavior (standalone SSE enabled)",
disableStandaloneSSE: false,
expectGETRequest: true,
},
{
name: "standalone SSE disabled",
disableStandaloneSSE: true,
expectGETRequest: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
getRequestKey := streamableRequestKey{"GET", "123", "", ""}

fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodInitialize, ""}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, initResp),
},
{"POST", "123", notificationInitialized, ""}: {
status: http.StatusAccepted,
wantProtocolVersion: latestProtocolVersion,
},
getRequestKey: {
header: header{
"Content-Type": "text/event-stream",
},
wantProtocolVersion: latestProtocolVersion,
optional: !test.expectGETRequest,
},
{"DELETE", "123", "", ""}: {
optional: true,
},
},
}

httpServer := httptest.NewServer(fake)
defer httpServer.Close()

transport := &StreamableClientTransport{
Endpoint: httpServer.URL,
DisableStandaloneSSE: test.disableStandaloneSSE,
}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}

// Give some time for the standalone SSE connection to be established (if enabled)
time.Sleep(100 * time.Millisecond)

// Verify the connection state
streamableConn, ok := session.mcpConn.(*streamableClientConn)
if !ok {
t.Fatalf("Expected *streamableClientConn, got %T", session.mcpConn)
}

if got, want := streamableConn.disableStandaloneSSE, test.disableStandaloneSSE; got != want {
t.Errorf("disableStandaloneSSE field: got %v, want %v", got, want)
}

// Clean up
if err := session.Close(); err != nil {
t.Errorf("closing session: %v", err)
}

// Check if GET request was received
fake.calledMu.Lock()
getRequestReceived := false
if fake.called != nil {
getRequestReceived = fake.called[getRequestKey]
}
fake.calledMu.Unlock()

if got, want := getRequestReceived, test.expectGETRequest; got != want {
t.Errorf("GET request received: got %v, want %v", got, want)
}

// If we expected a GET request, verify it was actually received
if test.expectGETRequest {
if missing := fake.missingRequests(); len(missing) > 0 {
// Filter out optional requests
var requiredMissing []streamableRequestKey
for _, key := range missing {
if resp, ok := fake.responses[key]; ok && !resp.optional {
requiredMissing = append(requiredMissing, key)
}
}
if len(requiredMissing) > 0 {
t.Errorf("did not receive expected requests: %v", requiredMissing)
}
}
} else {
// If we didn't expect a GET request, verify it wasn't sent
if getRequestReceived {
t.Error("GET request was sent unexpectedly when DisableStandaloneSSE is true")
}
}
})
}
}