Skip to content

Commit 6a05fc6

Browse files
committed
impl WithContinuousListening
1 parent 6449b15 commit 6a05fc6

File tree

1 file changed

+95
-11
lines changed

1 file changed

+95
-11
lines changed

client/transport/streamable_http.go

Lines changed: 95 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"bytes"
66
"context"
77
"encoding/json"
8+
"errors"
89
"fmt"
910
"io"
1011
"mime"
@@ -16,10 +17,23 @@ import (
1617
"time"
1718

1819
"github.com/mark3labs/mcp-go/mcp"
20+
"github.com/mark3labs/mcp-go/util"
1921
)
2022

2123
type StreamableHTTPCOption func(*StreamableHTTP)
2224

25+
// WithContinuousListening enables receiving server-to-client notifications when no request is in flight.
26+
// In particular, if you want to receive global notifications (like ToolListChangedNotification)
27+
// from the server, you should enable this option.
28+
// It will establish a standalone long-live GET HTTP connection to the server.
29+
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
30+
// NOTICE: Even enabled, the server may not support this feature.
31+
func WithContinuousListening() StreamableHTTPCOption {
32+
return func(sc *StreamableHTTP) {
33+
sc.getListeningEnabled = true
34+
}
35+
}
36+
2337
func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
2438
return func(sc *StreamableHTTP) {
2539
sc.headers = headers
@@ -61,13 +75,15 @@ func WithLogger(logger util.Logger) StreamableHTTPCOption {
6175
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
6276
// - server -> client request
6377
type StreamableHTTP struct {
64-
baseURL *url.URL
65-
httpClient *http.Client
66-
headers map[string]string
67-
headerFunc HTTPHeaderFunc
78+
baseURL *url.URL
79+
httpClient *http.Client
80+
headers map[string]string
81+
headerFunc HTTPHeaderFunc
6882
logger util.Logger
83+
getListeningEnabled bool
6984

70-
sessionID atomic.Value // string
85+
initialized chan struct{}
86+
sessionID atomic.Value // string
7187

7288
notificationHandler func(mcp.JSONRPCNotification)
7389
notifyMu sync.RWMutex
@@ -84,11 +100,12 @@ func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*Strea
84100
}
85101

86102
smc := &StreamableHTTP{
87-
baseURL: parsedURL,
88-
httpClient: &http.Client{},
89-
headers: make(map[string]string),
90-
closed: make(chan struct{}),
91-
logger: util.DefaultLogger(),
103+
baseURL: parsedURL,
104+
httpClient: &http.Client{},
105+
headers: make(map[string]string),
106+
closed: make(chan struct{}),
107+
logger: util.DefaultLogger(),
108+
initialized: make(chan struct{}),
92109
}
93110
smc.sessionID.Store("") // set initial value to simplify later usage
94111

@@ -101,7 +118,14 @@ func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*Strea
101118

102119
// Start initiates the HTTP connection to the server.
103120
func (c *StreamableHTTP) Start(ctx context.Context) error {
104-
// For Streamable HTTP, we don't need to establish a persistent connection
121+
// For Streamable HTTP, we don't need to establish a persistent connection by default
122+
if c.getListeningEnabled {
123+
go func() {
124+
<-c.initialized
125+
c.listenForever()
126+
}()
127+
}
128+
105129
return nil
106130
}
107131

@@ -182,6 +206,8 @@ func (c *StreamableHTTP) SendRequest(
182206
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
183207
c.sessionID.Store(sessionID)
184208
}
209+
210+
close(c.initialized)
185211
}
186212

187213
// Handle different response types
@@ -410,3 +436,61 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica
410436
func (c *StreamableHTTP) GetSessionId() string {
411437
return c.sessionID.Load().(string)
412438
}
439+
440+
func (c *StreamableHTTP) listenForever() {
441+
c.logger.Infof("listening to server forever")
442+
for {
443+
err := c.createGETConnectionToServer()
444+
if errors.Is(err, errGetMethodNotAllowed) {
445+
// server does not support listening
446+
c.logger.Errorf("server does not support listening")
447+
return
448+
}
449+
450+
select {
451+
case <-c.closed:
452+
return
453+
default:
454+
}
455+
456+
if err != nil {
457+
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
458+
}
459+
time.Sleep(1 * time.Second)
460+
}
461+
}
462+
463+
var errGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
464+
465+
func (c *StreamableHTTP) createGETConnectionToServer() error {
466+
467+
ctx := context.Background() // the sendHTTP will be automatically canceled when the client is closed
468+
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
469+
if err != nil {
470+
return fmt.Errorf("failed to send request: %w", err)
471+
}
472+
defer resp.Body.Close()
473+
474+
// Check if we got an error response
475+
if resp.StatusCode == http.StatusMethodNotAllowed {
476+
return errGetMethodNotAllowed
477+
}
478+
479+
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
480+
body, _ := io.ReadAll(resp.Body)
481+
return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
482+
}
483+
484+
// handle SSE response
485+
contentType := resp.Header.Get("Content-Type")
486+
if contentType != "text/event-stream" {
487+
return fmt.Errorf("unexpected content type: %s", contentType)
488+
}
489+
490+
_, err = c.handleSSEResponse(ctx, resp.Body)
491+
if err != nil {
492+
return fmt.Errorf("failed to handle SSE response: %w", err)
493+
}
494+
495+
return nil
496+
}

0 commit comments

Comments
 (0)