Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 53 additions & 42 deletions internal/launcher/connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ type ConnectionKey struct {

// ConnectionMetadata tracks information about a pooled connection
type ConnectionMetadata struct {
Connection *mcp.Connection
CreatedAt time.Time
LastUsedAt time.Time
Connection *mcp.Connection
CreatedAt time.Time
LastUsedAt time.Time
RequestCount int
ErrorCount int
State ConnectionState
ErrorCount int
State ConnectionState
}

// ConnectionState represents the state of a pooled connection
Expand All @@ -39,21 +39,22 @@ const (

// Default configuration values
const (
DefaultIdleTimeout = 30 * time.Minute
DefaultIdleTimeout = 30 * time.Minute
DefaultCleanupInterval = 5 * time.Minute
DefaultMaxErrorCount = 10
)

// SessionConnectionPool manages connections keyed by (backend, session)
type SessionConnectionPool struct {
connections map[ConnectionKey]*ConnectionMetadata
mu sync.RWMutex
ctx context.Context
idleTimeout time.Duration
cleanupInterval time.Duration
maxErrorCount int
cleanupTicker *time.Ticker
cleanupDone chan bool
connections map[ConnectionKey]*ConnectionMetadata
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc // cancel function to stop cleanup goroutine
idleTimeout time.Duration
cleanupInterval time.Duration
maxErrorCount int
cleanupTicker *time.Ticker
cleanupDone chan bool
}

// PoolConfig configures the connection pool
Expand All @@ -76,26 +77,31 @@ func NewSessionConnectionPool(ctx context.Context) *SessionConnectionPool {
func NewSessionConnectionPoolWithConfig(ctx context.Context, config PoolConfig) *SessionConnectionPool {
logPool.Printf("Creating new session connection pool: idleTimeout=%v, cleanupInterval=%v, maxErrors=%d",
config.IdleTimeout, config.CleanupInterval, config.MaxErrorCount)


// Create a cancellable context derived from the parent context
// This allows Stop() to signal the cleanup goroutine to exit
poolCtx, cancel := context.WithCancel(ctx)

pool := &SessionConnectionPool{
connections: make(map[ConnectionKey]*ConnectionMetadata),
ctx: ctx,
ctx: poolCtx,
cancel: cancel,
idleTimeout: config.IdleTimeout,
cleanupInterval: config.CleanupInterval,
maxErrorCount: config.MaxErrorCount,
cleanupDone: make(chan bool),
}

// Start cleanup goroutine
pool.startCleanup()

return pool
}

// startCleanup starts the periodic cleanup goroutine
func (p *SessionConnectionPool) startCleanup() {
p.cleanupTicker = time.NewTicker(p.cleanupInterval)

go func() {
logPool.Print("Cleanup goroutine started")
for {
Expand All @@ -116,48 +122,48 @@ func (p *SessionConnectionPool) startCleanup() {
func (p *SessionConnectionPool) cleanupIdleConnections() {
p.mu.Lock()
defer p.mu.Unlock()

now := time.Now()
removed := 0

for key, metadata := range p.connections {
shouldRemove := false
reason := ""

// Check if idle for too long
if now.Sub(metadata.LastUsedAt) > p.idleTimeout {
shouldRemove = true
reason = "idle timeout"
}

// Check if too many errors
if metadata.ErrorCount >= p.maxErrorCount {
shouldRemove = true
reason = "too many errors"
}

// Check if already closed
if metadata.State == ConnectionStateClosed {
shouldRemove = true
reason = "already closed"
}

if shouldRemove {
logPool.Printf("Cleaning up connection: backend=%s, session=%s, reason=%s, idle=%v, errors=%d",
key.BackendID, key.SessionID, reason, now.Sub(metadata.LastUsedAt), metadata.ErrorCount)

// Close the connection if still active
if metadata.Connection != nil && metadata.State != ConnectionStateClosed {
// Note: mcp.Connection doesn't have a Close method in current implementation
// but we mark it as closed
metadata.State = ConnectionStateClosed
}

delete(p.connections, key)
removed++
}
}

if removed > 0 {
logPool.Printf("Cleanup complete: removed %d idle/failed connections, active=%d", removed, len(p.connections))
}
Expand All @@ -166,30 +172,35 @@ func (p *SessionConnectionPool) cleanupIdleConnections() {
// Stop gracefully shuts down the connection pool
func (p *SessionConnectionPool) Stop() {
logPool.Print("Stopping connection pool")

// Stop cleanup goroutine

// Stop cleanup goroutine by cancelling the context
if p.cancel != nil {
p.cancel()
}

// Stop cleanup ticker
if p.cleanupTicker != nil {
p.cleanupTicker.Stop()
}
// Wait for cleanup to finish

// Wait for cleanup goroutine to finish (should be immediate now that context is cancelled)
select {
case <-p.cleanupDone:
logPool.Print("Cleanup goroutine stopped")
case <-time.After(5 * time.Second):
case <-time.After(1 * time.Second):
logPool.Print("Cleanup goroutine stop timeout")
}

// Close all connections
p.mu.Lock()
defer p.mu.Unlock()

for key, metadata := range p.connections {
logPool.Printf("Closing connection: backend=%s, session=%s", key.BackendID, key.SessionID)
metadata.State = ConnectionStateClosed
delete(p.connections, key)
}

logPool.Print("Connection pool stopped")
}

Expand All @@ -200,7 +211,7 @@ func (p *SessionConnectionPool) Get(backendID, sessionID string) (*mcp.Connectio

key := ConnectionKey{BackendID: backendID, SessionID: sessionID}
metadata, exists := p.connections[key]

if !exists {
logPool.Printf("Connection not found: backend=%s, session=%s", backendID, sessionID)
return nil, false
Expand All @@ -211,9 +222,9 @@ func (p *SessionConnectionPool) Get(backendID, sessionID string) (*mcp.Connectio
return nil, false
}

logPool.Printf("Reusing connection: backend=%s, session=%s, requests=%d",
logPool.Printf("Reusing connection: backend=%s, session=%s, requests=%d",
backendID, sessionID, metadata.RequestCount)

// Update last used time and state (need write lock for this)
p.mu.RUnlock()
p.mu.Lock()
Expand All @@ -232,7 +243,7 @@ func (p *SessionConnectionPool) Set(backendID, sessionID string, conn *mcp.Conne
defer p.mu.Unlock()

key := ConnectionKey{BackendID: backendID, SessionID: sessionID}

// Check if connection already exists
if existing, exists := p.connections[key]; exists {
logPool.Printf("Updating existing connection: backend=%s, session=%s", backendID, sessionID)
Expand Down Expand Up @@ -262,7 +273,7 @@ func (p *SessionConnectionPool) Delete(backendID, sessionID string) {
defer p.mu.Unlock()

key := ConnectionKey{BackendID: backendID, SessionID: sessionID}

if metadata, exists := p.connections[key]; exists {
metadata.State = ConnectionStateClosed
delete(p.connections, key)
Expand Down Expand Up @@ -295,7 +306,7 @@ func (p *SessionConnectionPool) RecordError(backendID, sessionID string) {
key := ConnectionKey{BackendID: backendID, SessionID: sessionID}
if metadata, exists := p.connections[key]; exists {
metadata.ErrorCount++
logPool.Printf("Recorded error for connection: backend=%s, session=%s, errors=%d",
logPool.Printf("Recorded error for connection: backend=%s, session=%s, errors=%d",
backendID, sessionID, metadata.ErrorCount)
}
}
Expand Down
Loading