From 269097f234d01e0c1743fa1a7226e6e5dd5686bb Mon Sep 17 00:00:00 2001 From: bytethm Date: Thu, 18 Sep 2025 14:14:41 +0800 Subject: [PATCH] feat: add connection-level reconnection mechanism --- client.go | 15 ++++ internal/reconnect/reconnect.go | 151 ++++++++++++++++++++++++++++++++ reconnect.go | 73 +++++++++++++++ streamable_client.go | 121 ++++++++++++++++++++++++- 4 files changed, 357 insertions(+), 3 deletions(-) create mode 100644 internal/reconnect/reconnect.go create mode 100644 reconnect.go diff --git a/client.go b/client.go index 0f3d485..e9234a1 100644 --- a/client.go +++ b/client.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "trpc.group/trpc-go/trpc-mcp-go/internal/errors" + "trpc.group/trpc-go/trpc-mcp-go/internal/reconnect" "trpc.group/trpc-go/trpc-mcp-go/internal/retry" ) @@ -125,6 +126,9 @@ type Client struct { // Retry configuration. retryConfig *retry.Config // Configuration for retry behavior (optional). + // Reconnect configuration. + reconnectConfig *reconnect.Config // Configuration for reconnection behavior (optional). + // Roots support. rootsProvider RootsProvider // Provider for roots information. rootsMu sync.RWMutex // Mutex for protecting the rootsProvider. @@ -664,6 +668,17 @@ func (c *Client) SendRootsListChangedNotification(ctx context.Context) error { return c.transport.sendNotification(ctx, notification) } +// setReconnectConfig sets the reconnection configuration for the client and its transport. 
+func (c *Client) setReconnectConfig(config *reconnect.Config) {
+	c.reconnectConfig = config
+	// Forward the config only when a transport is attached and implements setReconnectConfig.
+	if c.transport != nil {
+		if reconnectableTransport, ok := c.transport.(interface{ setReconnectConfig(*reconnect.Config) }); ok {
+			reconnectableTransport.setReconnectConfig(c.reconnectConfig)
+		}
+	}
+}
+
 func isZeroStruct(x interface{}) bool {
 	return reflect.ValueOf(x).IsZero()
 }
diff --git a/internal/reconnect/reconnect.go b/internal/reconnect/reconnect.go
new file mode 100644
index 0000000..a6aff1f
--- /dev/null
+++ b/internal/reconnect/reconnect.go
@@ -0,0 +1,151 @@
+// Tencent is pleased to support the open source community by making trpc-mcp-go available.
+//
+// Copyright (C) 2025 Tencent. All rights reserved.
+//
+// trpc-mcp-go is licensed under the Apache License Version 2.0.
+
+// Package reconnect provides connection-level reconnection functionality for MCP transports.
+// This package handles stream disconnections and connection recovery, distinct from request-level retries.
+package reconnect
+
+import (
+	"math"
+	"strings"
+	"time"
+)
+
+// Validation range constants used by Config.Validate to clamp each field.
+const (
+	// Bounds for Config.MaxReconnectAttempts.
+	MinMaxReconnectAttempts = 0
+	MaxMaxReconnectAttempts = 5
+
+	// Bounds for Config.ReconnectDelay (the constant MaxReconnectDelay is NOT the bound of the field of that name).
+	MinReconnectDelay = 100 * time.Millisecond
+	MaxReconnectDelay = 30 * time.Second
+
+	// Bounds for Config.ReconnectBackoffFactor.
+	MinReconnectBackoffFactor = 1.0
+	MaxReconnectBackoffFactor = 3.0
+
+	// Upper bound for Config.MaxReconnectDelay; Validate uses the configured ReconnectDelay as its lower bound.
+	MaxMaxReconnectDelay = 5 * time.Minute
+)
+
+// Config represents the configuration for connection-level reconnection.
+// Reconnection differs from retry: retry handles temporary request failures,
+// while reconnection handles stream/connection breaks that require re-establishment.
+type Config struct {
+	MaxReconnectAttempts   int           `json:"max_reconnect_attempts"`   // Maximum number of reconnection attempts (valid range: 0-5)
+	ReconnectDelay         time.Duration `json:"reconnect_delay"`          // Base delay, returned by CalculateDelay for attempt 2 (valid range: 100ms-30s)
+	ReconnectBackoffFactor float64       `json:"reconnect_backoff_factor"` // Exponential backoff factor (valid range: 1.0-3.0)
+	MaxReconnectDelay      time.Duration `json:"max_reconnect_delay"`      // Cap on the delay between attempts (valid range: ReconnectDelay up to 5min)
+}
+
+// Validate clamps every field of the reconnect configuration into its accepted range, in place.
+// Out-of-range values are replaced by the nearest bound; nothing is reported to the caller.
+func (c *Config) Validate() {
+	// Clamp MaxReconnectAttempts to valid range
+	if c.MaxReconnectAttempts < MinMaxReconnectAttempts {
+		c.MaxReconnectAttempts = MinMaxReconnectAttempts
+	} else if c.MaxReconnectAttempts > MaxMaxReconnectAttempts {
+		c.MaxReconnectAttempts = MaxMaxReconnectAttempts
+	}
+
+	// Clamp ReconnectDelay to valid range
+	if c.ReconnectDelay < MinReconnectDelay {
+		c.ReconnectDelay = MinReconnectDelay
+	} else if c.ReconnectDelay > MaxReconnectDelay {
+		c.ReconnectDelay = MaxReconnectDelay
+	}
+
+	// Clamp ReconnectBackoffFactor to valid range
+	if c.ReconnectBackoffFactor < MinReconnectBackoffFactor {
+		c.ReconnectBackoffFactor = MinReconnectBackoffFactor
+	} else if c.ReconnectBackoffFactor > MaxReconnectBackoffFactor {
+		c.ReconnectBackoffFactor = MaxReconnectBackoffFactor
+	}
+
+	// Clamp MaxReconnectDelay to its upper bound (the lower bound is handled below)
+	if c.MaxReconnectDelay > MaxMaxReconnectDelay {
+		c.MaxReconnectDelay = MaxMaxReconnectDelay
+	}
+
+	// Ensure MaxReconnectDelay is at least equal to ReconnectDelay
+	if c.MaxReconnectDelay < c.ReconnectDelay {
+		c.MaxReconnectDelay = c.ReconnectDelay
+	}
+}
+
+// CalculateDelay returns the wait before the given 1-based reconnection attempt: 0 for attempt <= 1,
+// otherwise ReconnectDelay * ReconnectBackoffFactor^(attempt-2), capped at MaxReconnectDelay.
+func (c *Config) CalculateDelay(attempt int) time.Duration {
+	if attempt <= 1 {
+		return 0 // No delay for first attempt
+	}
+
+	// Exponential backoff, computed in float64: delay * factor^(attempt-2)
+	delay := float64(c.ReconnectDelay) * math.Pow(c.ReconnectBackoffFactor, float64(attempt-2))
+
+	// Cap before converting: an out-of-range or NaN float has no well-defined time.Duration value.
+	if !(delay < float64(c.MaxReconnectDelay)) {
+		return c.MaxReconnectDelay
+	}
+
+	return time.Duration(delay)
+}
+
+// IsStreamDisconnectedError checks if an error indicates a stream disconnection that can be reconnected.
+// The check is a case-insensitive substring match on err.Error(); a nil error yields false.
+func IsStreamDisconnectedError(err error) bool {
+	if err == nil {
+		return false
+	}
+
+	errStr := strings.ToLower(err.Error())
+
+	// Stream-specific disconnection patterns
+	streamPatterns := []string{
+		"stream closed",
+		"stream disconnected",
+		"connection lost",
+		"sse connection",
+		"broken pipe",
+		"connection reset",
+	}
+
+	for _, pattern := range streamPatterns {
+		if strings.Contains(errStr, pattern) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// IsSessionExpiredError checks if an error indicates session expiration that requires Agent-level handling.
+// These errors should be wrapped and propagated to the Agent layer for session recreation.
+func IsSessionExpiredError(err error) bool {
+	if err == nil {
+		return false
+	}
+	errStr := strings.ToLower(err.Error())
+
+	// Textual session-expiration markers ("unauthorized" is treated as expired authentication).
+	for _, pattern := range []string{"unauthorized", "session not found", "invalid session", "session expired"} {
+		if strings.Contains(errStr, pattern) {
+			return true
+		}
+	}
+
+	// HTTP 404 Not Found (session expired). Match "404" only as a standalone alphanumeric
+	// token so that ports or identifiers such as ":4040" are not misclassified.
+	isSep := func(r rune) bool { return !((r >= '0' && r <= '9') || (r >= 'a' && r <= 'z')) }
+	for _, token := range strings.FieldsFunc(errStr, isSep) {
+		if token == "404" {
+			return true
+		}
+	}
+
+	return false
+}
diff --git a/reconnect.go b/reconnect.go
new file mode 100644
index 0000000..8c68680
--- /dev/null
+++ b/reconnect.go
@@ -0,0 +1,73 @@
+// Tencent is pleased to support the open source community by making trpc-mcp-go available.
+//
+// Copyright (C) 2025 Tencent. All rights reserved.
+//
+// trpc-mcp-go is licensed under the Apache License Version 2.0.
+
+package mcp
+
+import (
+	"time"
+
+	"trpc.group/trpc-go/trpc-mcp-go/internal/reconnect"
+)
+
+// ReconnectConfig defines configuration for MCP client reconnection behavior.
+// Reconnection handles connection-level failures such as stream disconnections,
+// which are different from request-level retry failures.
+//
+// The defaults quoted on the fields are the values WithSimpleReconnect starts from.
+// WithReconnect does not fill in defaults for zero-valued fields; it clamps each
+// field into its valid range instead.
+type ReconnectConfig struct {
+	// MaxReconnectAttempts specifies the maximum number of reconnection attempts.
+	// A value of 0 means no reconnection attempt is made.
+	// Valid range: 0-5, default: 2
+	MaxReconnectAttempts int `json:"max_reconnect_attempts"`
+	// ReconnectDelay specifies the base delay. The first reconnection attempt is made
+	// immediately; this delay is applied before the second attempt.
+	// Valid range: 100ms-30s, default: 1s
+	ReconnectDelay time.Duration `json:"reconnect_delay"`
+	// ReconnectBackoffFactor specifies the factor to multiply the delay for each reconnection attempt.
+	// For example, with factor 1.5: 1s -> 1.5s -> 2.25s -> 3.375s
+	// Valid range: 1.0-3.0, default: 1.5
+	ReconnectBackoffFactor float64 `json:"reconnect_backoff_factor"`
+	// MaxReconnectDelay specifies the maximum delay between reconnection attempts.
+	// Valid range: minimum is ReconnectDelay, maximum: 5 minutes, default: 30s
+	MaxReconnectDelay time.Duration `json:"max_reconnect_delay"`
+}
+
+// defaultReconnectConfig provides sensible defaults for reconnection configuration.
+// Uses conservative values optimized for connection stability over speed.
+var defaultReconnectConfig = ReconnectConfig{
+	MaxReconnectAttempts:   2,                // Conservative reconnection count
+	ReconnectDelay:         1 * time.Second,  // 1s initial delay
+	ReconnectBackoffFactor: 1.5,              // Gentle exponential backoff
+	MaxReconnectDelay:      30 * time.Second, // Maximum delay cap
+}
+
+// WithSimpleReconnect enables reconnection with the specified maximum number of attempts.
+// Uses default backoff configuration (1s initial, 1.5 factor, 30s max).
+// It is shorthand for WithReconnect with only MaxReconnectAttempts overridden,
+// so maxAttempts is clamped to the 0-5 range.
+func WithSimpleReconnect(maxAttempts int) ClientOption {
+	config := defaultReconnectConfig
+	config.MaxReconnectAttempts = maxAttempts
+	// Delegate to WithReconnect so conversion and validation live in one place.
+	return WithReconnect(config)
+}
+
+// WithReconnect enables reconnection with custom configuration.
+// All configuration parameters are validated and clamped to acceptable ranges.
+func WithReconnect(config ReconnectConfig) ClientOption {
+	return func(c *Client) {
+		// ReconnectConfig mirrors reconnect.Config field for field (names, types
+		// and order), so a plain conversion copies every setting. A fresh value is
+		// built on each application so clients never share one *reconnect.Config.
+		internalConfig := reconnect.Config(config)
+
+		// Clamp out-of-range values before handing the config to the client.
+		internalConfig.Validate()
+		c.setReconnectConfig(&internalConfig)
+	}
+}
diff --git a/streamable_client.go b/streamable_client.go
index 9aba314..c8cfcec 100644
--- a/streamable_client.go
+++ b/streamable_client.go
@@ -20,6 +20,7 @@ import (
 	"time"
 
 	"trpc.group/trpc-go/trpc-mcp-go/internal/httputil"
+	"trpc.group/trpc-go/trpc-mcp-go/internal/reconnect"
 	"trpc.group/trpc-go/trpc-mcp-go/internal/retry"
 )
 
@@ -54,6 +55,9 @@ type streamableHTTPClientTransport struct {
 	// Retry configuration
 	retryConfig *retry.Config
 
+	// Reconnect configuration
+	reconnectConfig *reconnect.Config
+
 	// GET SSE connection
 	getSSEConn struct {
 		active bool
@@ -201,21 +205,21 @@ func (t *streamableHTTPClientTransport) setRetryConfig(config *retry.Config) {
 	t.retryConfig = config
 }
 
-// SendRequest sends a request and waits for a response with retry support
+// sendRequest sends a request and waits for a response with retry and reconnect support
 func (t *streamableHTTPClientTransport) sendRequest(
 	ctx context.Context,
 	req *JSONRPCRequest,
 ) (*json.RawMessage, error) {
 	// If no retry config, use original implementation
 	if t.retryConfig == nil {
-		return t.send(ctx, req, nil)
+		return t.sendWithReconnect(ctx, req)
 	}
 
 	// Define the operation to be retried
 	var result *json.RawMessage
 	operation := func() error {
 		var err error
-		result, err = t.send(ctx, req, nil)
+		result, err = t.sendWithReconnect(ctx, req)
 		return err
 	}
 
@@ -229,6 +233,32 @@ func (t *streamableHTTPClientTransport) sendRequest(
 	return result, nil
 }
 
+// sendWithReconnect sends req and, if reconnection is configured, reacts to a failed send.
+func (t *streamableHTTPClientTransport) sendWithReconnect(
+	ctx context.Context,
+	req *JSONRPCRequest,
+) (*json.RawMessage, error) {
+	result, err := t.send(ctx, req, nil)
+	if err == nil || t.reconnectConfig == nil {
+		// Success, or reconnection is not configured: pass the outcome through untouched.
+		return result, err
+	}
+
+	switch {
+	case reconnect.IsStreamDisconnectedError(err):
+		// Stream-level break: reconnect the stream, then resend the request exactly once.
+		if reconnectErr := t.executeReconnect(ctx); reconnectErr != nil {
+			t.logger.Debug("Stream reconnection failed", "error", reconnectErr)
+			return nil, err // Return original error
+		}
+		return t.send(ctx, req, nil)
+	case reconnect.IsSessionExpiredError(err):
+		// Session expired - wrap error for Agent layer handling
+		return nil, fmt.Errorf("session_expired: %w", err)
+	}
+	return result, err
+}
+
 // send sends a request and handles the response
 func (t *streamableHTTPClientTransport) send(
 	ctx context.Context,
@@ -972,3 +1002,88 @@ func (t *streamableHTTPClientTransport) establishGetSSEConnection() {
 
 	t.establishGetSSE()
 }
+
+// setReconnectConfig sets the reconnection configuration for this transport.
+func (t *streamableHTTPClientTransport) setReconnectConfig(config *reconnect.Config) {
+	t.reconnectConfig = config
+}
+
+// executeReconnect re-establishes the GET SSE stream, making up to MaxReconnectAttempts tries.
+// The first try is immediate; later tries wait for CalculateDelay(attempt) or until ctx is done.
+// It returns nil on the first success, ctx.Err() on cancellation, or the last failure wrapped.
+func (t *streamableHTTPClientTransport) executeReconnect(ctx context.Context) error {
+	if t.isStateless {
+		// Stateless mode doesn't support stream reconnection
+		return fmt.Errorf("reconnection not supported in stateless mode")
+	}
+	if !t.enableGetSSE {
+		// GET SSE is not enabled, nothing to reconnect
+		return fmt.Errorf("GET SSE not enabled, cannot reconnect stream")
+	}
+
+	config := t.reconnectConfig
+	for attempt := 1; attempt <= config.MaxReconnectAttempts; attempt++ {
+		// Stop promptly on cancellation, including before the undelayed first attempt.
+		if err := ctx.Err(); err != nil {
+			return err
+		}
+
+		if delay := config.CalculateDelay(attempt); delay > 0 {
+			t.logger.Debug("Waiting before reconnection attempt",
+				"attempt", attempt,
+				"delay", delay)
+			// Use an explicit timer so it can be released as soon as ctx is done.
+			timer := time.NewTimer(delay)
+			select {
+			case <-ctx.Done():
+				timer.Stop()
+				return ctx.Err()
+			case <-timer.C:
+			}
+		}
+
+		t.logger.Debug("Attempting stream reconnection",
+			"attempt", attempt,
+			"last_event_id", t.lastEventID)
+
+		err := t.reestablishGetSSE(ctx)
+		if err == nil {
+			t.logger.Debug("Stream reconnection successful", "attempt", attempt)
+			return nil
+		}
+		t.logger.Debug("Reconnection attempt failed",
+			"attempt", attempt,
+			"error", err)
+		if attempt == config.MaxReconnectAttempts {
+			return fmt.Errorf("stream reconnection failed after %d attempts: %w", attempt, err)
+		}
+	}
+
+	// Only reached when MaxReconnectAttempts is 0, i.e. the loop body never ran.
+	return fmt.Errorf("stream reconnection failed: max attempts exceeded")
+}
+
+// reestablishGetSSE tears down the current GET SSE connection, if any, and opens a new one.
+// It returns an error when the connection is not marked active afterwards.
+// NOTE(review): ctx is currently unused; establishGetSSE is called without a context.
+func (t *streamableHTTPClientTransport) reestablishGetSSE(ctx context.Context) error {
+	// Cancel the existing GET SSE connection, if any, and mark it inactive.
+	t.getSSEConn.mutex.Lock()
+	if t.getSSEConn.cancel != nil {
+		t.getSSEConn.cancel()
+	}
+	t.getSSEConn.active = false
+	t.getSSEConn.mutex.Unlock()
+
+	// Reuse the existing establishGetSSE method to open the new connection.
+	t.establishGetSSE()
+
+	// NOTE(review): assumes establishGetSSE has set getSSEConn.active by the time it returns - confirm.
+	t.getSSEConn.mutex.Lock()
+	isActive := t.getSSEConn.active
+	t.getSSEConn.mutex.Unlock()
+	if !isActive {
+		return fmt.Errorf("failed to re-establish GET SSE connection")
+	}
+	return nil
+}