diff --git a/mcp/streamable.go b/mcp/streamable.go index 4886f806..20eb13d5 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -236,6 +236,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque SessionID: sessionID, Stateless: h.opts.Stateless, jsonResponse: h.opts.JSONResponse, + logger: h.opts.Logger, } // To support stateless mode, we initialize the session with a default @@ -377,6 +378,10 @@ type StreamableServerTransport struct { // StreamableHTTPOptions.JSONResponse is exported. jsonResponse bool + // optional logger provided through the [StreamableHTTPOptions.Logger]. + // + // TODO(rfindley): logger should be exported, since we want to allow people + // to write their own streamable HTTP handler. logger *slog.Logger // connection is non-nil if and only if the transport has been connected. @@ -393,7 +398,7 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er stateless: t.Stateless, eventStore: t.EventStore, jsonResponse: t.jsonResponse, - logger: t.logger, + logger: ensureLogger(t.logger), // see #556: must be non-nil incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), streams: make(map[string]*stream), diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index f7cda72c..6ccaebf7 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -18,6 +18,8 @@ import ( "net/http/httptest" "net/http/httputil" "net/url" + "os" + "runtime" "sort" "strings" "sync" @@ -91,7 +93,7 @@ func TestStreamableTransports(t *testing.T) { headerMu sync.Mutex lastHeader http.Header ) - httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpServer := httptest.NewServer(mustNotPanic(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { headerMu.Lock() lastHeader = r.Header headerMu.Unlock() @@ -102,7 +104,7 @@ func TestStreamableTransports(t *testing.T) { t.Errorf("got cookie %q, want %q", cookie.Value, "test-value") } handler.ServeHTTP(w, r) - })) + }))) defer httpServer.Close() // Create a client and connect it to the server using our StreamableClientTransport. @@ -315,7 +317,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) { return new(CallToolResult), nil, nil }) - realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) + realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))) defer realServer.Close() realServerURL, err := url.Parse(realServer.URL) if err != nil { @@ -324,6 +326,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) { // Configure a proxy that sits between the client and the real server. proxyHandler := httputil.NewSingleHostReverseProxy(realServerURL) + // note: don't use mustNotPanic here as the proxy WILL panic when killed. proxy := httptest.NewServer(proxyHandler) proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later. @@ -434,7 +437,7 @@ func TestServerTransportCleanup(t *testing.T) { chans[sessionID] <- struct{}{} } - httpServer := httptest.NewServer(handler) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -484,7 +487,7 @@ func TestServerInitiatedSSE(t *testing.T) { notifications := make(chan string) server := NewServer(testImpl, nil) - httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) + httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))) defer httpServer.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -857,7 +860,7 @@ func TestStreamableServerTransport(t *testing.T) { } func testStreamableHandler(t *testing.T, handler http.Handler, requests []streamableRequest) { - httpServer := httptest.NewServer(handler) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() // blocks records request blocks by jsonrpc. ID. @@ -1247,7 +1250,7 @@ func TestStreamableStateless(t *testing.T) { testClientCompatibility := func(t *testing.T, handler http.Handler) { ctx := context.Background() - httpServer := httptest.NewServer(handler) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) if err != nil { @@ -1332,7 +1335,7 @@ func TestTokenInfo(t *testing.T) { }, nil } handler := auth.RequireBearerToken(verifier, nil)(streamHandler) - httpServer := httptest.NewServer(handler) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() transport := &StreamableClientTransport{Endpoint: httpServer.URL} @@ -1366,7 +1369,7 @@ func TestStreamableGET(t *testing.T) { server := NewServer(testImpl, nil) handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - httpServer := httptest.NewServer(handler) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -1442,7 +1445,7 @@ func TestStreamableClientContextPropagation(t *testing.T) { defer cancel() ctx2 := context.WithValue(ctx, testKey, testValue) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + server := httptest.NewServer(mustNotPanic(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { switch req.Method { case "POST": w.Header().Set("Content-Type", "application/json") @@ -1455,7 +1458,7 @@ func TestStreamableClientContextPropagation(t *testing.T) { case "DELETE": w.WriteHeader(http.StatusNoContent) } - })) + }))) defer server.Close() transport := &StreamableClientTransport{Endpoint: server.URL} @@ -1486,3 +1489,19 @@ func TestStreamableClientContextPropagation(t *testing.T) { } } + +// mustNotPanic is a helper to enforce that test handlers do not panic (see +// issue #556). +func mustNotPanic(t *testing.T, h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + defer func() { + if r := recover(); r != nil { + buf := make([]byte, 1<<20) + n := runtime.Stack(buf, false) + fmt.Fprintf(os.Stderr, "handler panic: %v\n\n%s", r, buf[:n]) + t.Errorf("handler panicked: %v", r) + } + }() + h.ServeHTTP(w, req) + }) +}