Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions sse_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"sync/atomic"
"time"

icontext "trpc.group/trpc-go/trpc-mcp-go/internal/context"
"trpc.group/trpc-go/trpc-mcp-go/internal/retry"
)

Expand Down Expand Up @@ -115,10 +116,7 @@ func NewSSEClient(serverURL string, clientInfo Implementation, options ...Client
return nil, err
}

err = c.transport.start(context.Background())
if err != nil {
return nil, err
}
// Transport will auto-start on first request (e.g., Initialize) with the correct context.

return c, nil
}
Expand All @@ -141,7 +139,7 @@ func (t *sseClientTransport) start(ctx context.Context) error {
}

// Create a new context with cancellation for the SSE stream.
sseCtx, cancel := context.WithCancel(context.Background())
sseCtx, cancel := context.WithCancel(icontext.WithoutCancel(ctx))
t.sseConn.mutex.Lock()
t.sseConn.ctx = sseCtx
t.sseConn.cancel = cancel
Expand Down Expand Up @@ -533,7 +531,9 @@ func (t *sseClientTransport) sendRequest(ctx context.Context, req *JSONRPCReques
func (t *sseClientTransport) sendRequestInternal(ctx context.Context, req *JSONRPCRequest) (*json.RawMessage, error) {
// Auto-start the transport if not already started.
if !t.started.Load() {
return nil, errors.New("transport not started")
if err := t.start(ctx); err != nil {
return nil, fmt.Errorf("failed to start transport: %w", err)
}
}

if t.closed.Load() {
Expand Down
6 changes: 4 additions & 2 deletions sse_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"time"

"github.com/google/uuid"

icontext "trpc.group/trpc-go/trpc-mcp-go/internal/context"
)

// SessionIDGenerator defines an interface for generating custom session IDs.
Expand Down Expand Up @@ -597,7 +599,7 @@ func (s *SSEServer) handleNotificationMessage(ctx context.Context, rawMessage js
// Handle notification asynchronously.
go func() {
// Create a context that will not be canceled due to HTTP connection closure.
detachedCtx := context.WithoutCancel(ctx)
detachedCtx := icontext.WithoutCancel(ctx)

// Process notification (currently just log it, but can be extended).
if err := s.handleNotification(detachedCtx, &notification, session); err != nil {
Expand Down Expand Up @@ -773,7 +775,7 @@ func (s *SSEServer) createSessionContext(ctx context.Context, session *sseSessio
// processRequestAsync processes the request asynchronously.
func (s *SSEServer) processRequestAsync(ctx context.Context, request *JSONRPCRequest, session *sseSession) {
// Create a context that will not be canceled due to HTTP connection closure.
detachedCtx := context.WithoutCancel(ctx)
detachedCtx := icontext.WithoutCancel(ctx)

// Check if this is a response to our roots/list request.
if s.isRootsListResponse(request) {
Expand Down