Skip to content

Commit 881e095

Browse files
andigclaude
andcommitted
feat: implement sampling support for Streamable HTTP transport
Implements sampling capability for HTTP transport, resolving issue #419. Enables servers to send sampling requests to HTTP clients via SSE and receive LLM-generated responses. ## Key Changes ### Core Implementation - Add `BidirectionalInterface` support to `StreamableHTTP` - Implement `SetRequestHandler` for server-to-client requests - Enhance SSE parsing to handle requests alongside responses/notifications - Add `handleIncomingRequest` and `sendResponseToServer` methods ### HTTP-Specific Features - Leverage existing MCP headers (`Mcp-Session-Id`, `Mcp-Protocol-Version`) - Bidirectional communication via HTTP POST for responses - Proper JSON-RPC request/response handling over HTTP ### Error Handling - Add specific JSON-RPC error codes for different failure scenarios: - `-32601` (Method not found) when no handler configured - `-32603` (Internal error) for sampling failures - `-32800` (Request cancelled/timeout) for context errors - Enhanced error messages with sampling-specific context ### Testing & Examples - Comprehensive test suite in `streamable_http_sampling_test.go` - Complete working example in `examples/sampling_http_client/` - Tests cover success flows, error scenarios, and interface compliance ## Technical Details The implementation maintains full backward compatibility while adding bidirectional communication support. Server requests are processed asynchronously to avoid blocking the SSE stream reader. HTTP transport now supports the complete sampling flow that was previously only available in stdio and inprocess transports. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent a43b104 commit 881e095

File tree

4 files changed

+572
-1
lines changed

4 files changed

+572
-1
lines changed

client/transport/streamable_http.go

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ func WithSession(sessionID string) StreamableHTTPCOption {
9292
// The current implementation does not support the following features:
9393
// - resuming stream
9494
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
95-
// - server -> client request
9695
type StreamableHTTP struct {
9796
serverURL *url.URL
9897
httpClient *http.Client
@@ -110,6 +109,10 @@ type StreamableHTTP struct {
110109
notificationHandler func(mcp.JSONRPCNotification)
111110
notifyMu sync.RWMutex
112111

112+
// Request handler for incoming server-to-client requests (like sampling)
113+
requestHandler RequestHandler
114+
requestMu sync.RWMutex
115+
113116
closed chan struct{}
114117

115118
// OAuth support
@@ -406,6 +409,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
406409
defer close(responseChan)
407410

408411
c.readSSE(ctx, reader, func(event, data string) {
412+
// Try to unmarshal as a response first
409413
var message JSONRPCResponse
410414
if err := json.Unmarshal([]byte(data), &message); err != nil {
411415
c.logger.Errorf("failed to unmarshal message: %v", err)
@@ -427,6 +431,17 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
427431
return
428432
}
429433

434+
// Check if this is actually a request from the server
435+
// If Result and Error are nil, it might be a request
436+
if message.Result == nil && message.Error == nil {
437+
var request JSONRPCRequest
438+
if err := json.Unmarshal([]byte(data), &request); err == nil {
439+
// This is a request from the server
440+
c.handleIncomingRequest(ctx, request)
441+
return
442+
}
443+
}
444+
430445
if !ignoreResponse {
431446
responseChan <- &message
432447
}
@@ -547,6 +562,13 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica
547562
c.notificationHandler = handler
548563
}
549564

565+
// SetRequestHandler sets the handler for incoming requests from the server.
566+
func (c *StreamableHTTP) SetRequestHandler(handler RequestHandler) {
567+
c.requestMu.Lock()
568+
defer c.requestMu.Unlock()
569+
c.requestHandler = handler
570+
}
571+
550572
func (c *StreamableHTTP) GetSessionId() string {
551573
return c.sessionID.Load().(string)
552574
}
@@ -627,6 +649,107 @@ func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error
627649
return nil
628650
}
629651

652+
// handleIncomingRequest processes requests from the server (like sampling requests)
653+
func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSONRPCRequest) {
654+
c.requestMu.RLock()
655+
handler := c.requestHandler
656+
c.requestMu.RUnlock()
657+
658+
if handler == nil {
659+
c.logger.Errorf("received request from server but no handler set: %s", request.Method)
660+
// Send method not found error
661+
errorResponse := &JSONRPCResponse{
662+
JSONRPC: "2.0",
663+
ID: request.ID,
664+
Error: &struct {
665+
Code int `json:"code"`
666+
Message string `json:"message"`
667+
Data json.RawMessage `json:"data"`
668+
}{
669+
Code: -32601, // Method not found
670+
Message: fmt.Sprintf("no handler configured for method: %s", request.Method),
671+
},
672+
}
673+
c.sendResponseToServer(ctx, errorResponse)
674+
return
675+
}
676+
677+
// Handle the request in a goroutine to avoid blocking the SSE reader
678+
go func() {
679+
response, err := handler(ctx, request)
680+
if err != nil {
681+
c.logger.Errorf("error handling request %s: %v", request.Method, err)
682+
683+
// Determine appropriate JSON-RPC error code based on error type
684+
var errorCode int
685+
var errorMessage string
686+
687+
// Check for specific sampling-related errors
688+
if errors.Is(err, context.Canceled) {
689+
errorCode = -32800 // Request cancelled
690+
errorMessage = "request was cancelled"
691+
} else if errors.Is(err, context.DeadlineExceeded) {
692+
errorCode = -32800 // Request timeout
693+
errorMessage = "request timed out"
694+
} else {
695+
// Generic error cases
696+
switch request.Method {
697+
case string(mcp.MethodSamplingCreateMessage):
698+
errorCode = -32603 // Internal error
699+
errorMessage = fmt.Sprintf("sampling request failed: %v", err)
700+
default:
701+
errorCode = -32603 // Internal error
702+
errorMessage = err.Error()
703+
}
704+
}
705+
706+
// Send error response
707+
errorResponse := &JSONRPCResponse{
708+
JSONRPC: "2.0",
709+
ID: request.ID,
710+
Error: &struct {
711+
Code int `json:"code"`
712+
Message string `json:"message"`
713+
Data json.RawMessage `json:"data"`
714+
}{
715+
Code: errorCode,
716+
Message: errorMessage,
717+
},
718+
}
719+
c.sendResponseToServer(ctx, errorResponse)
720+
return
721+
}
722+
723+
if response != nil {
724+
c.sendResponseToServer(ctx, response)
725+
}
726+
}()
727+
}
728+
729+
// sendResponseToServer sends a response back to the server via HTTP POST
730+
func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSONRPCResponse) {
731+
responseBody, err := json.Marshal(response)
732+
if err != nil {
733+
c.logger.Errorf("failed to marshal response: %v", err)
734+
return
735+
}
736+
737+
ctx, cancel := c.contextAwareOfClientClose(ctx)
738+
defer cancel()
739+
740+
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json")
741+
if err != nil {
742+
c.logger.Errorf("failed to send response to server: %v", err)
743+
return
744+
}
745+
defer resp.Body.Close()
746+
747+
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
748+
body, _ := io.ReadAll(resp.Body)
749+
c.logger.Errorf("server rejected response with status %d: %s", resp.StatusCode, body)
750+
}
751+
}
752+
630753
func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
631754
newCtx, cancel := context.WithCancel(ctx)
632755
go func() {

0 commit comments

Comments
 (0)