@@ -74,6 +74,13 @@ func WithLogger(logger util.Logger) StreamableHTTPCOption {
7474 }
7575}
7676
77+ // WithSession creates a client with a pre-configured session
78+ func WithSession (sessionID string ) StreamableHTTPCOption {
79+ return func (sc * StreamableHTTP ) {
80+ sc .sessionID .Store (sessionID )
81+ }
82+ }
83+
7784// StreamableHTTP implements Streamable HTTP transport.
7885//
7986// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
@@ -236,7 +243,7 @@ func (c *StreamableHTTP) SendRequest(
236243
237244 resp , err := c .sendHTTP (ctx , http .MethodPost , bytes .NewReader (requestBody ), "application/json, text/event-stream" )
238245 if err != nil {
239- if errors .Is (err , errSessionTerminated ) && request .Method == string (mcp .MethodInitialize ) {
246+ if errors .Is (err , ErrSessionTerminated ) && request .Method == string (mcp .MethodInitialize ) {
240247 // If the request is initialize, should not return a SessionTerminated error
241248 // It should be a genuine endpoint-routing issue.
242249 // ( Fall through to return StatusCode checking. )
@@ -357,7 +364,7 @@ func (c *StreamableHTTP) sendHTTP(
357364 // universal handling for session terminated
358365 if resp .StatusCode == http .StatusNotFound {
359366 c .sessionID .CompareAndSwap (sessionID , "" )
360- return nil , errSessionTerminated
367+ return nil , ErrSessionTerminated
361368 }
362369
363370 return resp , nil
@@ -543,7 +550,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
543550 c .logger .Infof ("listening to server forever" )
544551 for {
545552 err := c .createGETConnectionToServer (ctx )
546- if errors .Is (err , errGetMethodNotAllowed ) {
553+ if errors .Is (err , ErrGetMethodNotAllowed ) {
547554 // server does not support listening
548555 c .logger .Errorf ("server does not support listening" )
549556 return
@@ -563,8 +570,8 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
563570}
564571
565572var (
566- errSessionTerminated = fmt .Errorf ("session terminated (404). need to re-initialize" )
567- errGetMethodNotAllowed = fmt .Errorf ("GET method not allowed" )
573+ ErrSessionTerminated = fmt .Errorf ("session terminated (404). need to re-initialize" )
574+ ErrGetMethodNotAllowed = fmt .Errorf ("GET method not allowed" )
568575
569576 retryInterval = 1 * time .Second // a variable is convenient for testing
570577)
@@ -579,7 +586,7 @@ func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error
579586
580587 // Check if we got an error response
581588 if resp .StatusCode == http .StatusMethodNotAllowed {
582- return errGetMethodNotAllowed
589+ return ErrGetMethodNotAllowed
583590 }
584591
585592 if resp .StatusCode != http .StatusOK && resp .StatusCode != http .StatusAccepted {
0 commit comments