diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 16c1740..7319c84 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -27,28 +27,30 @@ const ( // DefaultListenIPv4 is the default interface used by the HTTP server. DefaultListenIPv4 = "127.0.0.1" // DefaultListenPort is the default port used by the HTTP server. - DefaultListenPort = "3000" - defaultListenAddr = DefaultListenIPv4 + ":" + DefaultListenPort - defaultRoutedMode = false - defaultUnifiedMode = false - defaultEnvFile = "" - defaultEnableDIFC = false - defaultLogDir = "/tmp/gh-aw/mcp-logs" + DefaultListenPort = "3000" + defaultListenAddr = DefaultListenIPv4 + ":" + DefaultListenPort + defaultRoutedMode = false + defaultUnifiedMode = false + defaultEnvFile = "" + defaultEnableDIFC = false + defaultLogDir = "/tmp/gh-aw/mcp-logs" + defaultParallelLaunch = true ) var ( - configFile string - configStdin bool - listenAddr string - routedMode bool - unifiedMode bool - envFile string - enableDIFC bool - logDir string - validateEnv bool - verbosity int // Verbosity level: 0 (default), 1 (-v info), 2 (-vv debug), 3 (-vvv trace) - debugLog = logger.New("cmd:root") - version = "dev" // Default version, overridden by SetVersion + configFile string + configStdin bool + listenAddr string + routedMode bool + unifiedMode bool + envFile string + enableDIFC bool + logDir string + validateEnv bool + parallelLaunch bool + verbosity int // Verbosity level: 0 (default), 1 (-v info), 2 (-vv debug), 3 (-vvv trace) + debugLog = logger.New("cmd:root") + version = "dev" // Default version, overridden by SetVersion ) var rootCmd = &cobra.Command{ @@ -75,6 +77,7 @@ func init() { rootCmd.Flags().BoolVar(&enableDIFC, "enable-difc", defaultEnableDIFC, "Enable DIFC enforcement and session requirement (requires sys___init call before tool access)") rootCmd.Flags().StringVar(&logDir, "log-dir", getDefaultLogDir(), "Directory for log files (falls back to stdout if directory cannot be created)") rootCmd.Flags().BoolVar(&validateEnv, "validate-env", false, "Validate execution environment (Docker, env vars) before starting") + rootCmd.Flags().BoolVar(¶llelLaunch, "parallel-launch", defaultParallelLaunch, "Launch MCP servers in parallel during startup (enabled by default)") rootCmd.Flags().CountVarP(&verbosity, "verbose", "v", "Increase verbosity level (use -v for info, -vv for debug, -vvv for trace)") // Mark mutually exclusive flags @@ -238,12 +241,19 @@ func run(cmd *cobra.Command, args []string) error { // Apply command-line flags to config cfg.EnableDIFC = enableDIFC + cfg.ParallelLaunch = parallelLaunch if enableDIFC { log.Println("DIFC enforcement and session requirement enabled") } else { log.Println("DIFC enforcement disabled (sessions auto-created for standard MCP client compatibility)") } + if parallelLaunch { + log.Println("Parallel server launching enabled") + } else { + log.Println("Sequential server launching enabled") + } + // Determine mode (default to unified if neither flag is set) mode := "unified" if routedMode { diff --git a/internal/config/config.go b/internal/config/config.go index f7ae0e5..b5a0bbb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -24,9 +24,10 @@ const ( // Config represents the MCPG configuration type Config struct { - Servers map[string]*ServerConfig `toml:"servers"` - EnableDIFC bool `toml:"enable_difc"` // When true, enables DIFC enforcement and requires sys___init call before tool access. Default is false for standard MCP client compatibility. - Gateway *GatewayConfig `toml:"gateway"` // Gateway configuration (port, API key, etc.) + Servers map[string]*ServerConfig `toml:"servers"` + EnableDIFC bool `toml:"enable_difc"` // When true, enables DIFC enforcement and requires sys___init call before tool access. Default is false for standard MCP client compatibility. + ParallelLaunch bool `toml:"parallel_launch"` // When true (default), launches MCP servers in parallel during startup. + Gateway *GatewayConfig `toml:"gateway"` // Gateway configuration (port, API key, etc.) } // GatewayConfig represents gateway-level configuration diff --git a/internal/server/unified.go b/internal/server/unified.go index cab380c..9002dce 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -74,14 +74,15 @@ type ToolInfo struct { // UnifiedServer implements a unified MCP server that aggregates multiple backend servers type UnifiedServer struct { - launcher *launcher.Launcher - sysServer *sys.SysServer - ctx context.Context - server *sdk.Server - sessions map[string]*Session // mcp-session-id -> Session - sessionMu sync.RWMutex - tools map[string]*ToolInfo // prefixed tool name -> tool info - toolsMu sync.RWMutex + launcher *launcher.Launcher + sysServer *sys.SysServer + ctx context.Context + server *sdk.Server + sessions map[string]*Session // mcp-session-id -> Session + sessionMu sync.RWMutex + tools map[string]*ToolInfo // prefixed tool name -> tool info + toolsMu sync.RWMutex + parallelLaunch bool // When true (default), launches MCP servers in parallel during startup // DIFC components guardRegistry *guard.Registry @@ -101,15 +102,16 @@ type UnifiedServer struct { // NewUnified creates a new unified MCP server func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error) { - logUnified.Printf("Creating new unified server: enableDIFC=%v, servers=%d", cfg.EnableDIFC, len(cfg.Servers)) + logUnified.Printf("Creating new unified server: enableDIFC=%v, parallelLaunch=%v, servers=%d", cfg.EnableDIFC, cfg.ParallelLaunch, len(cfg.Servers)) l := launcher.New(ctx, cfg) us := &UnifiedServer{ - launcher: l, - sysServer: sys.NewSysServer(l.ServerIDs()), - ctx: ctx, - sessions: make(map[string]*Session), - tools: make(map[string]*ToolInfo), + launcher: l, + sysServer: sys.NewSysServer(l.ServerIDs()), + ctx: ctx, + sessions: make(map[string]*Session), + tools: make(map[string]*ToolInfo), + parallelLaunch: cfg.ParallelLaunch, // Initialize DIFC components guardRegistry: guard.NewRegistry(), @@ -141,6 +143,13 @@ func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error) return us, nil } +// launchResult stores the result of a backend server launch +type launchResult struct { + serverID string + err error + duration time.Duration +} + // registerAllTools fetches and registers tools from all backend servers func (us *UnifiedServer) registerAllTools() error { log.Println("Registering tools from all backends...") @@ -157,8 +166,22 @@ func (us *UnifiedServer) registerAllTools() error { log.Println("DIFC disabled: skipping sys tools registration") } - // Register tools from each backend server - for _, serverID := range us.launcher.ServerIDs() { + serverIDs := us.launcher.ServerIDs() + + if us.parallelLaunch { + // Launch servers in parallel + return us.registerAllToolsParallel(serverIDs) + } else { + // Launch servers sequentially (original behavior) + return us.registerAllToolsSequential(serverIDs) + } +} + +// registerAllToolsSequential registers tools from backend servers sequentially +func (us *UnifiedServer) registerAllToolsSequential(serverIDs []string) error { + logUnified.Printf("Registering tools sequentially from %d backends", len(serverIDs)) + + for _, serverID := range serverIDs { logUnified.Printf("Registering tools from backend: %s", serverID) if err := us.registerToolsFromBackend(serverID); err != nil { log.Printf("Warning: failed to register tools from %s: %v", serverID, err) @@ -170,6 +193,55 @@ func (us *UnifiedServer) registerAllTools() error { return nil } +// registerAllToolsParallel registers tools from backend servers in parallel +func (us *UnifiedServer) registerAllToolsParallel(serverIDs []string) error { + logUnified.Printf("Registering tools in parallel from %d backends", len(serverIDs)) + + var wg sync.WaitGroup + results := make(chan launchResult, len(serverIDs)) + + // Launch each server in its own goroutine + for _, serverID := range serverIDs { + wg.Add(1) + go func(sid string) { + defer wg.Done() + + startTime := time.Now() + err := us.registerToolsFromBackend(sid) + duration := time.Since(startTime) + + results <- launchResult{ + serverID: sid, + err: err, + duration: duration, + } + }(serverID) + } + + // Wait for all goroutines to complete + wg.Wait() + close(results) + + // Collect and log results + successCount := 0 + failureCount := 0 + for result := range results { + if result.err != nil { + log.Printf("Warning: failed to register tools from %s (took %v): %v", result.serverID, result.duration, result.err) + logger.LogWarn("backend", "Failed to register tools from %s (took %v): %v", result.serverID, result.duration, result.err) + failureCount++ + } else { + logUnified.Printf("Successfully registered tools from %s (took %v)", result.serverID, result.duration) + logger.LogInfo("backend", "Successfully registered tools from %s (took %v)", result.serverID, result.duration) + successCount++ + } + } + + log.Printf("Parallel tool registration complete: %d succeeded, %d failed, total tools=%d", successCount, failureCount, len(us.tools)) + logUnified.Printf("Tool registration complete: %d succeeded, %d failed, total tools=%d", successCount, failureCount, len(us.tools)) + return nil +} + // registerToolsFromBackend registers tools from a specific backend with ___ naming func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { log.Printf("Registering tools from backend: %s", serverID) diff --git a/internal/server/unified_test.go b/internal/server/unified_test.go index a00299a..8124c74 100644 --- a/internal/server/unified_test.go +++ b/internal/server/unified_test.go @@ -492,3 +492,31 @@ func TestRequireSession_EdgeCases(t *testing.T) { }) } } + +func TestUnifiedServer_ParallelLaunch_Enabled(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + ParallelLaunch: true, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + require.NoError(t, err, "NewUnified() failed") + defer us.Close() + + assert.True(t, us.parallelLaunch, "ParallelLaunch should be enabled when configured") +} + +func TestUnifiedServer_ParallelLaunch_Disabled(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + ParallelLaunch: false, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + require.NoError(t, err, "NewUnified() failed") + defer us.Close() + + assert.False(t, us.parallelLaunch, "ParallelLaunch should be disabled when configured") +}