diff --git a/internal/config/config.go b/internal/config/config.go index e9e00222..f7ae0e5f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,6 +13,15 @@ import ( var logConfig = logger.New("config:config") +const ( + // DefaultPort is the default port for the gateway HTTP server + DefaultPort = 3000 + // DefaultStartupTimeout is the default timeout for backend server startup (seconds) + DefaultStartupTimeout = 60 + // DefaultToolTimeout is the default timeout for tool execution (seconds) + DefaultToolTimeout = 120 +) + // Config represents the MCPG configuration type Config struct { Servers map[string]*ServerConfig `toml:"servers"` @@ -94,6 +103,19 @@ func LoadFromFile(path string) (*Config, error) { // return nil, fmt.Errorf("unknown configuration keys: %v", meta.Undecoded()) } + // Set default gateway config values if gateway section exists but fields are unset + if cfg.Gateway != nil { + if cfg.Gateway.StartupTimeout == 0 { + cfg.Gateway.StartupTimeout = DefaultStartupTimeout + } + if cfg.Gateway.ToolTimeout == 0 { + cfg.Gateway.ToolTimeout = DefaultToolTimeout + } + if cfg.Gateway.Port == 0 { + cfg.Gateway.Port = DefaultPort + } + } + logConfig.Printf("Successfully loaded %d servers from TOML file", len(cfg.Servers)) return &cfg, nil } @@ -157,11 +179,11 @@ func LoadFromStdin() (*Config, error) { // Store gateway config with defaults if stdinCfg.Gateway != nil { cfg.Gateway = &GatewayConfig{ - Port: 3000, + Port: DefaultPort, APIKey: stdinCfg.Gateway.APIKey, Domain: stdinCfg.Gateway.Domain, - StartupTimeout: 60, - ToolTimeout: 120, + StartupTimeout: DefaultStartupTimeout, + ToolTimeout: DefaultToolTimeout, } if stdinCfg.Gateway.Port != nil { cfg.Gateway.Port = *stdinCfg.Gateway.Port diff --git a/internal/launcher/launcher.go b/internal/launcher/launcher.go index 2efae91c..14fac508 100644 --- a/internal/launcher/launcher.go +++ b/internal/launcher/launcher.go @@ -7,6 +7,7 @@ import ( "os" "strings" "sync" + "time" "github.com/githubnext/gh-aw-mcpg/internal/config" "github.com/githubnext/gh-aw-mcpg/internal/logger" @@ -17,6 +18,12 @@ import ( var logLauncher = logger.New("launcher:launcher") +// connectionResult is used to return the result of a connection attempt from a goroutine +type connectionResult struct { + conn *mcp.Connection + err error +} + // Launcher manages backend MCP server connections type Launcher struct { ctx context.Context @@ -25,6 +32,7 @@ type Launcher struct { sessionPool *SessionConnectionPool // Session-aware connections (stateful/stdio) mu sync.RWMutex runningInContainer bool + startupTimeout time.Duration // Timeout for backend server startup } // New creates a new Launcher @@ -36,12 +44,22 @@ func New(ctx context.Context, cfg *config.Config) *Launcher { log.Println("[LAUNCHER] Detected running inside a container") } + // Get startup timeout from config, default to config.DefaultStartupTimeout seconds + startupTimeout := time.Duration(config.DefaultStartupTimeout) * time.Second + if cfg.Gateway != nil && cfg.Gateway.StartupTimeout > 0 { + startupTimeout = time.Duration(cfg.Gateway.StartupTimeout) * time.Second + logLauncher.Printf("Using configured startup timeout: %v", startupTimeout) + } else { + logLauncher.Printf("Using default startup timeout: %v", startupTimeout) + } + return &Launcher{ ctx: ctx, config: cfg, connections: make(map[string]*mcp.Connection), sessionPool: NewSessionConnectionPool(ctx), runningInContainer: inContainer, + startupTimeout: startupTimeout, } } @@ -147,41 +165,66 @@ func GetOrLaunch(l *Launcher, serverID string) (*mcp.Connection, error) { log.Printf("[LAUNCHER] Additional env vars: %v", sanitize.TruncateSecretMap(serverCfg.Env)) } - // Create connection - conn, err := mcp.NewConnection(l.ctx, serverCfg.Command, serverCfg.Args, serverCfg.Env) - if err != nil { - // Enhanced error logging for command-based servers - logger.LogError("backend", "Failed to launch MCP backend server: %s, error=%v", serverID, err) - log.Printf("[LAUNCHER] ❌ FAILED to launch server '%s'", serverID) - log.Printf("[LAUNCHER] Error: %v", err) - log.Printf("[LAUNCHER] Debug Information:") - log.Printf("[LAUNCHER] - Command: %s", serverCfg.Command) - log.Printf("[LAUNCHER] - Args: %v", serverCfg.Args) - log.Printf("[LAUNCHER] - Env vars: %v", sanitize.TruncateSecretMap(serverCfg.Env)) - log.Printf("[LAUNCHER] - Running in container: %v", l.runningInContainer) - log.Printf("[LAUNCHER] - Is direct command: %v", isDirectCommand) - - if isDirectCommand && l.runningInContainer { - log.Printf("[LAUNCHER] ⚠️ Possible causes:") - log.Printf("[LAUNCHER] - Command '%s' may not be installed in the gateway container", serverCfg.Command) - log.Printf("[LAUNCHER] - Consider using 'container' config instead of 'command'") - log.Printf("[LAUNCHER] - Or add '%s' to the gateway's Dockerfile", serverCfg.Command) - } else if isDirectCommand { - log.Printf("[LAUNCHER] ⚠️ Possible causes:") - log.Printf("[LAUNCHER] - Command '%s' may not be in PATH", serverCfg.Command) - log.Printf("[LAUNCHER] - Check if '%s' is installed: which %s", serverCfg.Command, serverCfg.Command) - log.Printf("[LAUNCHER] - Verify file permissions and execute bit") + log.Printf("[LAUNCHER] Starting server with %v timeout", l.startupTimeout) + logLauncher.Printf("Starting server with timeout: serverID=%s, timeout=%v", serverID, l.startupTimeout) + + // Create a buffered channel to receive connection result + // Buffer size of 1 prevents goroutine leak if timeout occurs before connection completes + resultChan := make(chan connectionResult, 1) + + // Launch connection in a goroutine + go func() { + conn, err := mcp.NewConnection(l.ctx, serverCfg.Command, serverCfg.Args, serverCfg.Env) + resultChan <- connectionResult{conn, err} + }() + + // Wait for connection with timeout + select { + case result := <-resultChan: + conn, err := result.conn, result.err + if err != nil { + // Enhanced error logging for command-based servers + logger.LogError("backend", "Failed to launch MCP backend server: %s, error=%v", serverID, err) + log.Printf("[LAUNCHER] ❌ FAILED to launch server '%s'", serverID) + log.Printf("[LAUNCHER] Error: %v", err) + log.Printf("[LAUNCHER] Debug Information:") + log.Printf("[LAUNCHER] - Command: %s", serverCfg.Command) + log.Printf("[LAUNCHER] - Args: %v", serverCfg.Args) + log.Printf("[LAUNCHER] - Env vars: %v", sanitize.TruncateSecretMap(serverCfg.Env)) + log.Printf("[LAUNCHER] - Running in container: %v", l.runningInContainer) + log.Printf("[LAUNCHER] - Is direct command: %v", isDirectCommand) + log.Printf("[LAUNCHER] - Startup timeout: %v", l.startupTimeout) + + if isDirectCommand && l.runningInContainer { + log.Printf("[LAUNCHER] ⚠️ Possible causes:") + log.Printf("[LAUNCHER] - Command '%s' may not be installed in the gateway container", serverCfg.Command) + log.Printf("[LAUNCHER] - Consider using 'container' config instead of 'command'") + log.Printf("[LAUNCHER] - Or add '%s' to the gateway's Dockerfile", serverCfg.Command) + } else if isDirectCommand { + log.Printf("[LAUNCHER] ⚠️ Possible causes:") + log.Printf("[LAUNCHER] - Command '%s' may not be in PATH", serverCfg.Command) + log.Printf("[LAUNCHER] - Check if '%s' is installed: which %s", serverCfg.Command, serverCfg.Command) + log.Printf("[LAUNCHER] - Verify file permissions and execute bit") + } + + return nil, fmt.Errorf("failed to create connection: %w", err) } - return nil, fmt.Errorf("failed to create connection: %w", err) - } + logger.LogInfo("backend", "Successfully launched MCP backend server: %s", serverID) + log.Printf("[LAUNCHER] Successfully launched: %s", serverID) + logLauncher.Printf("Connection established: serverID=%s", serverID) - logger.LogInfo("backend", "Successfully launched MCP backend server: %s", serverID) - log.Printf("[LAUNCHER] Successfully launched: %s", serverID) - logLauncher.Printf("Connection established: serverID=%s", serverID) + l.connections[serverID] = conn + return conn, nil - l.connections[serverID] = conn - return conn, nil + case <-time.After(l.startupTimeout): + // Timeout occurred + logger.LogError("backend", "MCP backend server startup timeout: %s, timeout=%v", serverID, l.startupTimeout) + log.Printf("[LAUNCHER] ❌ Server startup timed out after %v", l.startupTimeout) + log.Printf("[LAUNCHER] ⚠️ The server may be hanging or taking too long to initialize") + log.Printf("[LAUNCHER] ⚠️ Consider increasing 'startupTimeout' in gateway config if server needs more time") + return nil, fmt.Errorf("server startup timeout after %v", l.startupTimeout) + } } // GetOrLaunchForSession returns a session-aware connection or launches a new one @@ -267,32 +310,59 @@ func GetOrLaunchForSession(l *Launcher, serverID, sessionID string) (*mcp.Connec log.Printf("[LAUNCHER] Additional env vars: %v", sanitize.TruncateSecretMap(serverCfg.Env)) } - // Create connection - conn, err := mcp.NewConnection(l.ctx, serverCfg.Command, serverCfg.Args, serverCfg.Env) - if err != nil { - logger.LogError("backend", "Failed to launch MCP backend server for session: server=%s, session=%s, error=%v", serverID, sessionID, err) - log.Printf("[LAUNCHER] ❌ FAILED to launch server '%s' for session '%s'", serverID, sessionID) - log.Printf("[LAUNCHER] Error: %v", err) - log.Printf("[LAUNCHER] Debug Information:") - log.Printf("[LAUNCHER] - Command: %s", serverCfg.Command) - log.Printf("[LAUNCHER] - Args: %v", serverCfg.Args) - log.Printf("[LAUNCHER] - Env vars: %v", sanitize.TruncateSecretMap(serverCfg.Env)) - log.Printf("[LAUNCHER] - Running in container: %v", l.runningInContainer) - log.Printf("[LAUNCHER] - Is direct command: %v", isDirectCommand) - - // Record error in session pool - l.sessionPool.RecordError(serverID, sessionID) + log.Printf("[LAUNCHER] Starting server for session with %v timeout", l.startupTimeout) + logLauncher.Printf("Starting session server with timeout: serverID=%s, sessionID=%s, timeout=%v", serverID, sessionID, l.startupTimeout) - return nil, fmt.Errorf("failed to create connection: %w", err) - } + // Create a buffered channel to receive connection result + // Buffer size of 1 prevents goroutine leak if timeout occurs before connection completes + resultChan := make(chan connectionResult, 1) - logger.LogInfo("backend", "Successfully launched MCP backend server for session: server=%s, session=%s", serverID, sessionID) - log.Printf("[LAUNCHER] Successfully launched: %s (session: %s)", serverID, sessionID) - logLauncher.Printf("Session connection established: serverID=%s, sessionID=%s", serverID, sessionID) + // Launch connection in a goroutine + go func() { + conn, err := mcp.NewConnection(l.ctx, serverCfg.Command, serverCfg.Args, serverCfg.Env) + resultChan <- connectionResult{conn, err} + }() - // Add to session pool - l.sessionPool.Set(serverID, sessionID, conn) - return conn, nil + // Wait for connection with timeout + select { + case result := <-resultChan: + conn, err := result.conn, result.err + if err != nil { + logger.LogError("backend", "Failed to launch MCP backend server for session: server=%s, session=%s, error=%v", serverID, sessionID, err) + log.Printf("[LAUNCHER] ❌ FAILED to launch server '%s' for session '%s'", serverID, sessionID) + log.Printf("[LAUNCHER] Error: %v", err) + log.Printf("[LAUNCHER] Debug Information:") + log.Printf("[LAUNCHER] - Command: %s", serverCfg.Command) + log.Printf("[LAUNCHER] - Args: %v", serverCfg.Args) + log.Printf("[LAUNCHER] - Env vars: %v", sanitize.TruncateSecretMap(serverCfg.Env)) + log.Printf("[LAUNCHER] - Running in container: %v", l.runningInContainer) + log.Printf("[LAUNCHER] - Is direct command: %v", isDirectCommand) + log.Printf("[LAUNCHER] - Startup timeout: %v", l.startupTimeout) + + // Record error in session pool + l.sessionPool.RecordError(serverID, sessionID) + + return nil, fmt.Errorf("failed to create connection: %w", err) + } + + logger.LogInfo("backend", "Successfully launched MCP backend server for session: server=%s, session=%s", serverID, sessionID) + log.Printf("[LAUNCHER] Successfully launched: %s (session: %s)", serverID, sessionID) + logLauncher.Printf("Session connection established: serverID=%s, sessionID=%s", serverID, sessionID) + + // Add to session pool + l.sessionPool.Set(serverID, sessionID, conn) + return conn, nil + + case <-time.After(l.startupTimeout): + // Timeout occurred + logger.LogError("backend", "MCP backend server startup timeout for session: server=%s, session=%s, timeout=%v", serverID, sessionID, l.startupTimeout) + log.Printf("[LAUNCHER] ❌ Server startup timed out after %v", l.startupTimeout) + log.Printf("[LAUNCHER] ⚠️ The server may be hanging or taking too long to initialize") + log.Printf("[LAUNCHER] ⚠️ Consider increasing 'startupTimeout' in gateway config if server needs more time") + // Record error in session pool before returning + l.sessionPool.RecordError(serverID, sessionID) + return nil, fmt.Errorf("server startup timeout after %v", l.startupTimeout) + } } // ServerIDs returns all configured server IDs diff --git a/internal/launcher/launcher_test.go b/internal/launcher/launcher_test.go index bccf8517..5ccba0f3 100644 --- a/internal/launcher/launcher_test.go +++ b/internal/launcher/launcher_test.go @@ -549,3 +549,78 @@ func TestGetOrLaunchForSession_InvalidServer(t *testing.T) { assert.Nil(t, conn) assert.Contains(t, err.Error(), "not found in config") } + +func TestLauncher_StartupTimeout(t *testing.T) { + // Test that launcher respects the startup timeout from config + tests := []struct { + name string + configTimeout int + expectedTimeout string + }{ + { + name: "default timeout (60 seconds)", + configTimeout: 0, // 0 means use default + expectedTimeout: "1m0s", + }, + { + name: "custom timeout (30 seconds)", + configTimeout: 30, + expectedTimeout: "30s", + }, + { + name: "custom timeout (120 seconds)", + configTimeout: 120, + expectedTimeout: "2m0s", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{ + "test-server": { + Type: "http", + URL: "http://example.com", + }, + }, + Gateway: &config.GatewayConfig{ + Port: 3000, + StartupTimeout: tt.configTimeout, + ToolTimeout: 120, + }, + } + + // If timeout is 0, set it to default to match LoadFromFile behavior + if cfg.Gateway.StartupTimeout == 0 { + cfg.Gateway.StartupTimeout = config.DefaultStartupTimeout + } + + l := New(ctx, cfg) + defer l.Close() + + // Verify the timeout was set correctly + assert.Equal(t, tt.expectedTimeout, l.startupTimeout.String()) + }) + } +} + +func TestLauncher_TimeoutWithNilGateway(t *testing.T) { + // Test that launcher uses default timeout when Gateway config is nil + ctx := context.Background() + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{ + "test-server": { + Type: "http", + URL: "http://example.com", + }, + }, + Gateway: nil, // No gateway config + } + + l := New(ctx, cfg) + defer l.Close() + + // Should use default timeout (60 seconds) + assert.Equal(t, "1m0s", l.startupTimeout.String()) +} diff --git a/internal/logger/common_test.go b/internal/logger/common_test.go index ae1a085b..8c9a44e1 100644 --- a/internal/logger/common_test.go +++ b/internal/logger/common_test.go @@ -4,7 +4,6 @@ import ( "fmt" "os" "path/filepath" - "strings" "sync" "testing" @@ -607,7 +606,7 @@ func TestInitLogger_MarkdownLoggerFallback(t *testing.T) { } func TestInitLogger_SetupError(t *testing.T) { - assert := assert.New(t) + a := assert.New(t) tmpDir := t.TempDir() logDir := filepath.Join(tmpDir, "logs") fileName := "test.log" @@ -625,14 +624,14 @@ func TestInitLogger_SetupError(t *testing.T) { }, ) - assert.Error(err, "initLogger should return error on setup failure") - assert.Equal(assert.AnError, err, "Error should match setup error") - assert.Nil(logger, "logger should be nil on setup error") + a.Error(err, "initLogger should return error on setup failure") + a.Equal(assert.AnError, err, "Error should match setup error") + a.Nil(logger, "logger should be nil on setup error") // Verify the log file was created but then closed logPath := filepath.Join(logDir, fileName) _, err = os.Stat(logPath) - assert.NoError(err, "Log file should exist even after setup error") + a.NoError(err, "Log file should exist even after setup error") } func TestInitLogger_FileFlags(t *testing.T) {