Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
}
}

// WithSession assumes a MCP Session has already been initialized
func WithSession() ClientOption {
return func(c *Client) {
c.initialized = true
}
}

// NewClient creates a new MCP client with the given transport.
// Usage:
//
Expand Down Expand Up @@ -432,3 +439,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
return c.clientCapabilities
}

// GetSessionId returns the session ID of the transport.
// If the transport does not support sessions, it returns an empty string.
func (c *Client) GetSessionId() string {
if c.transport == nil {
return ""
}
return c.transport.GetSessionId()
}

// IsInitialized returns true if the client has been initialized.
func (c *Client) IsInitialized() bool {
return c.initialized
}
7 changes: 6 additions & 1 deletion client/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP
if err != nil {
return nil, fmt.Errorf("failed to create SSE transport: %w", err)
}
return NewClient(trans), nil
clientOptions := make([]ClientOption, 0)
sessionID := trans.GetSessionId()
if sessionID != "" {
clientOptions = append(clientOptions, WithSession())
}
return NewClient(trans, clientOptions...), nil
}
13 changes: 12 additions & 1 deletion client/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import (
"testing"
"time"

"github.com/google/uuid"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)


func TestHTTPClient(t *testing.T) {
hooks := &server.Hooks{}
hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
Expand Down Expand Up @@ -81,6 +81,17 @@ func TestHTTPClient(t *testing.T) {
},
}

t.Run("Can Configure a server with a pre-existing session", func(t *testing.T) {
sessionID := uuid.NewString()
client, err := NewStreamableHttpClient(testServer.URL, transport.WithSession(sessionID))
if err != nil {
t.Fatalf("create client failed %v", err)
}
if client.IsInitialized() != true {
t.Fatalf("Client is not initialized")
}
})

t.Run("Can receive notification from server", func(t *testing.T) {
client, err := NewStreamableHttpClient(testServer.URL)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions client/transport/inprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,7 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc
func (*InProcessTransport) Close() error {
return nil
}

func (c *InProcessTransport) GetSessionId() string {
return ""
}
3 changes: 3 additions & 0 deletions client/transport/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type Interface interface {

// Close the connection.
Close() error

// GetSessionId returns the session ID of the transport.
GetSessionId() string
}

type JSONRPCRequest struct {
Expand Down
6 changes: 6 additions & 0 deletions client/transport/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,12 @@ func (c *SSE) Close() error {
return nil
}

// GetSessionId returns the session ID of the transport.
// Since SSE does not maintain a session ID, it returns an empty string.
func (c *SSE) GetSessionId() string {
return ""
}

// SendNotification sends a JSON-RPC notification to the server without expecting a response.
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
if c.endpoint == nil {
Expand Down
6 changes: 6 additions & 0 deletions client/transport/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ func (c *Stdio) Close() error {
return nil
}

// GetSessionId returns the session ID of the transport.
// Since stdio does not maintain a session ID, it returns an empty string.
func (c *Stdio) GetSessionId() string {
return ""
}

// SetNotificationHandler sets the handler function to be called when a notification is received.
// Only one handler can be set at a time; setting a new one replaces the previous handler.
func (c *Stdio) SetNotificationHandler(
Expand Down
19 changes: 13 additions & 6 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ func WithLogger(logger util.Logger) StreamableHTTPCOption {
}
}

// WithSession creates a client with a pre-configured session
func WithSession(sessionID string) StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
sc.sessionID.Store(sessionID)
}
}
Comment on lines +78 to +82
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New constructor option to provide an existing session


// StreamableHTTP implements Streamable HTTP transport.
//
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
Expand Down Expand Up @@ -236,7 +243,7 @@ func (c *StreamableHTTP) SendRequest(

resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
if err != nil {
if errors.Is(err, errSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
// If the request is initialize, should not return a SessionTerminated error
// It should be a genuine endpoint-routing issue.
// ( Fall through to return StatusCode checking. )
Expand Down Expand Up @@ -357,7 +364,7 @@ func (c *StreamableHTTP) sendHTTP(
// universal handling for session terminated
if resp.StatusCode == http.StatusNotFound {
c.sessionID.CompareAndSwap(sessionID, "")
return nil, errSessionTerminated
return nil, ErrSessionTerminated
}

return resp, nil
Expand Down Expand Up @@ -543,7 +550,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
c.logger.Infof("listening to server forever")
for {
err := c.createGETConnectionToServer(ctx)
if errors.Is(err, errGetMethodNotAllowed) {
if errors.Is(err, ErrGetMethodNotAllowed) {
// server does not support listening
c.logger.Errorf("server does not support listening")
return
Expand All @@ -563,8 +570,8 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
}

var (
errSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize")
errGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize")
ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
Comment on lines +573 to +574
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These need to be public to allow MCP hosts to know when the errors are returned, that way a MCP host can re-initialize on a 404


retryInterval = 1 * time.Second // a variable is convenient for testing
)
Expand All @@ -579,7 +586,7 @@ func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error

// Check if we got an error response
if resp.StatusCode == http.StatusMethodNotAllowed {
return errGetMethodNotAllowed
return ErrGetMethodNotAllowed
}

if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
Expand Down
24 changes: 24 additions & 0 deletions www/docs/pages/clients/transports.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,30 @@ func (pool *StreamableHTTPClientPool) CallTool(ctx context.Context, req mcp.Call
}
```

### StreamableHTTP With Preconfigured Session
You can also create a StreamableHTTP client with a preconfigured session, which allows you to reuse the same session across multiple requests

```go
func createStreamableHTTPClientWithSession() {
// Create StreamableHTTP client with options
sessionID := // fetch existing session ID
c := client.NewStreamableHttpClient("https://api.example.com/mcp",
transport.WithSession(sessionID),
)
defer c.Close()

ctx := context.Background()
// Use client...
_, err := c.ListTools(ctx)
// If the session is terminated, you must reinitialize the client
if errors.Is(err, transport.ErrSessionTerminated) {
c.Initialize(ctx) // Reinitialize if session is terminated
// The session ID should change after reinitialization
sessionID = c.GetSessionId() // Update session ID
}
}
```

## SSE Client

SSE (Server-Sent Events) clients provide real-time communication with servers.
Expand Down