Skip to content

Commit d0fa06e

Browse files
authored
Add support for MCP host session management (#466)
* add session persistance * Address nitpick * Expose SessionID of the client * Address feedback, add docs about preconfigured sessions
1 parent 251da13 commit d0fa06e

File tree

9 files changed

+95
-8
lines changed

9 files changed

+95
-8
lines changed

client/client.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
3333
}
3434
}
3535

36+
// WithSession assumes a MCP Session has already been initialized
37+
func WithSession() ClientOption {
38+
return func(c *Client) {
39+
c.initialized = true
40+
}
41+
}
42+
3643
// NewClient creates a new MCP client with the given transport.
3744
// Usage:
3845
//
@@ -432,3 +439,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
432439
func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
433440
return c.clientCapabilities
434441
}
442+
443+
// GetSessionId returns the session ID of the transport.
444+
// If the transport does not support sessions, it returns an empty string.
445+
func (c *Client) GetSessionId() string {
446+
if c.transport == nil {
447+
return ""
448+
}
449+
return c.transport.GetSessionId()
450+
}
451+
452+
// IsInitialized returns true if the client has been initialized.
453+
func (c *Client) IsInitialized() bool {
454+
return c.initialized
455+
}

client/http.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP
1313
if err != nil {
1414
return nil, fmt.Errorf("failed to create SSE transport: %w", err)
1515
}
16-
return NewClient(trans), nil
16+
clientOptions := make([]ClientOption, 0)
17+
sessionID := trans.GetSessionId()
18+
if sessionID != "" {
19+
clientOptions = append(clientOptions, WithSession())
20+
}
21+
return NewClient(trans, clientOptions...), nil
1722
}

client/http_test.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ import (
77
"testing"
88
"time"
99

10+
"github.com/google/uuid"
1011
"github.com/mark3labs/mcp-go/client/transport"
1112
"github.com/mark3labs/mcp-go/mcp"
1213
"github.com/mark3labs/mcp-go/server"
1314
)
1415

15-
1616
func TestHTTPClient(t *testing.T) {
1717
hooks := &server.Hooks{}
1818
hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
@@ -81,6 +81,17 @@ func TestHTTPClient(t *testing.T) {
8181
},
8282
}
8383

84+
t.Run("Can Configure a server with a pre-existing session", func(t *testing.T) {
85+
sessionID := uuid.NewString()
86+
client, err := NewStreamableHttpClient(testServer.URL, transport.WithSession(sessionID))
87+
if err != nil {
88+
t.Fatalf("create client failed %v", err)
89+
}
90+
if client.IsInitialized() != true {
91+
t.Fatalf("Client is not initialized")
92+
}
93+
})
94+
8495
t.Run("Can receive notification from server", func(t *testing.T) {
8596
client, err := NewStreamableHttpClient(testServer.URL)
8697
if err != nil {

client/transport/inprocess.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,7 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc
6868
func (*InProcessTransport) Close() error {
6969
return nil
7070
}
71+
72+
func (c *InProcessTransport) GetSessionId() string {
73+
return ""
74+
}

client/transport/interface.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ type Interface interface {
2929

3030
// Close the connection.
3131
Close() error
32+
33+
// GetSessionId returns the session ID of the transport.
34+
GetSessionId() string
3235
}
3336

3437
type JSONRPCRequest struct {

client/transport/sse.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,12 @@ func (c *SSE) Close() error {
428428
return nil
429429
}
430430

431+
// GetSessionId returns the session ID of the transport.
432+
// Since SSE does not maintain a session ID, it returns an empty string.
433+
func (c *SSE) GetSessionId() string {
434+
return ""
435+
}
436+
431437
// SendNotification sends a JSON-RPC notification to the server without expecting a response.
432438
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
433439
if c.endpoint == nil {

client/transport/stdio.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ func (c *Stdio) Close() error {
148148
return nil
149149
}
150150

151+
// GetSessionId returns the session ID of the transport.
152+
// Since stdio does not maintain a session ID, it returns an empty string.
153+
func (c *Stdio) GetSessionId() string {
154+
return ""
155+
}
156+
151157
// SetNotificationHandler sets the handler function to be called when a notification is received.
152158
// Only one handler can be set at a time; setting a new one replaces the previous handler.
153159
func (c *Stdio) SetNotificationHandler(

client/transport/streamable_http.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ func WithLogger(logger util.Logger) StreamableHTTPCOption {
7474
}
7575
}
7676

77+
// WithSession creates a client with a pre-configured session
78+
func WithSession(sessionID string) StreamableHTTPCOption {
79+
return func(sc *StreamableHTTP) {
80+
sc.sessionID.Store(sessionID)
81+
}
82+
}
83+
7784
// StreamableHTTP implements Streamable HTTP transport.
7885
//
7986
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
@@ -236,7 +243,7 @@ func (c *StreamableHTTP) SendRequest(
236243

237244
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
238245
if err != nil {
239-
if errors.Is(err, errSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
246+
if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
240247
// If the request is initialize, should not return a SessionTerminated error
241248
// It should be a genuine endpoint-routing issue.
242249
// ( Fall through to return StatusCode checking. )
@@ -357,7 +364,7 @@ func (c *StreamableHTTP) sendHTTP(
357364
// universal handling for session terminated
358365
if resp.StatusCode == http.StatusNotFound {
359366
c.sessionID.CompareAndSwap(sessionID, "")
360-
return nil, errSessionTerminated
367+
return nil, ErrSessionTerminated
361368
}
362369

363370
return resp, nil
@@ -543,7 +550,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
543550
c.logger.Infof("listening to server forever")
544551
for {
545552
err := c.createGETConnectionToServer(ctx)
546-
if errors.Is(err, errGetMethodNotAllowed) {
553+
if errors.Is(err, ErrGetMethodNotAllowed) {
547554
// server does not support listening
548555
c.logger.Errorf("server does not support listening")
549556
return
@@ -563,8 +570,8 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
563570
}
564571

565572
var (
566-
errSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize")
567-
errGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
573+
ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize")
574+
ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
568575

569576
retryInterval = 1 * time.Second // a variable is convenient for testing
570577
)
@@ -579,7 +586,7 @@ func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error
579586

580587
// Check if we got an error response
581588
if resp.StatusCode == http.StatusMethodNotAllowed {
582-
return errGetMethodNotAllowed
589+
return ErrGetMethodNotAllowed
583590
}
584591

585592
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {

www/docs/pages/clients/transports.mdx

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,30 @@ func (pool *StreamableHTTPClientPool) CallTool(ctx context.Context, req mcp.Call
389389
}
390390
```
391391

392+
### StreamableHTTP With Preconfigured Session
393+
You can also create a StreamableHTTP client with a preconfigured session, which allows you to reuse the same session across multiple requests
394+
395+
```go
396+
func createStreamableHTTPClientWithSession() {
397+
// Create StreamableHTTP client with options
398+
sessionID := // fetch existing session ID
399+
c := client.NewStreamableHttpClient("https://api.example.com/mcp",
400+
transport.WithSession(sessionID),
401+
)
402+
defer c.Close()
403+
404+
ctx := context.Background()
405+
// Use client...
406+
_, err := c.ListTools(ctx)
407+
// If the session is terminated, you must reinitialize the client
408+
if errors.Is(err, transport.ErrSessionTerminated) {
409+
c.Initialize(ctx) // Reinitialize if session is terminated
410+
// The session ID should change after reinitialization
411+
sessionID = c.GetSessionId() // Update session ID
412+
}
413+
}
414+
```
415+
392416
## SSE Client
393417

394418
SSE (Server-Sent Events) clients provide real-time communication with servers.

0 commit comments

Comments
 (0)