diff --git a/sse_client.go b/sse_client.go index 1f779b4..76f5891 100644 --- a/sse_client.go +++ b/sse_client.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "time" + icontext "trpc.group/trpc-go/trpc-mcp-go/internal/context" "trpc.group/trpc-go/trpc-mcp-go/internal/retry" ) @@ -115,10 +116,7 @@ func NewSSEClient(serverURL string, clientInfo Implementation, options ...Client return nil, err } - err = c.transport.start(context.Background()) - if err != nil { - return nil, err - } + // Transport will auto-start on first request (e.g., Initialize) with the correct context. return c, nil } @@ -141,7 +139,7 @@ func (t *sseClientTransport) start(ctx context.Context) error { } // Create a new context with cancellation for the SSE stream. - sseCtx, cancel := context.WithCancel(context.Background()) + sseCtx, cancel := context.WithCancel(icontext.WithoutCancel(ctx)) t.sseConn.mutex.Lock() t.sseConn.ctx = sseCtx t.sseConn.cancel = cancel @@ -533,7 +531,9 @@ func (t *sseClientTransport) sendRequest(ctx context.Context, req *JSONRPCReques func (t *sseClientTransport) sendRequestInternal(ctx context.Context, req *JSONRPCRequest) (*json.RawMessage, error) { // Auto-start the transport if not already started. if !t.started.Load() { - return nil, errors.New("transport not started") + if err := t.start(ctx); err != nil { + return nil, fmt.Errorf("failed to start transport: %w", err) + } } if t.closed.Load() { diff --git a/sse_server.go b/sse_server.go index c20bd0a..447ec24 100644 --- a/sse_server.go +++ b/sse_server.go @@ -19,6 +19,8 @@ import ( "time" "github.com/google/uuid" + + icontext "trpc.group/trpc-go/trpc-mcp-go/internal/context" ) // SessionIDGenerator defines an interface for generating custom session IDs. @@ -597,7 +599,7 @@ func (s *SSEServer) handleNotificationMessage(ctx context.Context, rawMessage js // Handle notification asynchronously. go func() { // Create a context that will not be canceled due to HTTP connection closure. - detachedCtx := context.WithoutCancel(ctx) + detachedCtx := icontext.WithoutCancel(ctx) // Process notification (currently just log it, but can be extended). if err := s.handleNotification(detachedCtx, ¬ification, session); err != nil { @@ -773,7 +775,7 @@ func (s *SSEServer) createSessionContext(ctx context.Context, session *sseSessio // processRequestAsync processes the request asynchronously. func (s *SSEServer) processRequestAsync(ctx context.Context, request *JSONRPCRequest, session *sseSession) { // Create a context that will not be canceled due to HTTP connection closure. - detachedCtx := context.WithoutCancel(ctx) + detachedCtx := icontext.WithoutCancel(ctx) // Check if this is a response to our roots/list request. if s.isRootsListResponse(request) {