Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion internal/cmd/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// newCompletionCmd creates a completion command for generating shell completion scripts
func newCompletionCmd() *cobra.Command {
return &cobra.Command{
cmd := &cobra.Command{
Use: "completion [bash|zsh|fish|powershell]",
Short: "Generate completion script",
Long: `To load completions:
Expand Down Expand Up @@ -66,4 +66,11 @@ PowerShell:
}
},
}

// Override the parent's PersistentPreRunE to skip validation for completion command
cmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
return nil
}

return cmd
}
70 changes: 65 additions & 5 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ var (
enableDIFC bool
logDir string
validateEnv bool
verbosity int // Verbosity level: 0 (default), 1 (-v info), 2 (-vv debug), 3 (-vvv trace)
debugLog = logger.New("cmd:root")
version = "dev" // Default version, overridden by SetVersion
)
Expand All @@ -56,11 +57,15 @@ var rootCmd = &cobra.Command{
Version: version,
Long: `MCPG is a proxy server for Model Context Protocol (MCP) servers.
It provides routing, aggregation, and management of multiple MCP backend servers.`,
SilenceUsage: true, // Don't show help on runtime errors
RunE: run,
SilenceUsage: true, // Don't show help on runtime errors
PersistentPreRunE: preRun,
RunE: run,
}

func init() {
// Set custom error prefix for better branding
rootCmd.SetErrPrefix("MCPG Error:")

rootCmd.Flags().StringVarP(&configFile, "config", "c", defaultConfigFile, "Path to config file")
rootCmd.Flags().BoolVar(&configStdin, "config-stdin", defaultConfigStdin, "Read MCP server configuration from stdin (JSON format). When enabled, overrides --config")
rootCmd.Flags().StringVarP(&listenAddr, "listen", "l", defaultListenAddr, "HTTP server listen address")
Expand All @@ -70,10 +75,14 @@ func init() {
rootCmd.Flags().BoolVar(&enableDIFC, "enable-difc", defaultEnableDIFC, "Enable DIFC enforcement and session requirement (requires sys___init call before tool access)")
rootCmd.Flags().StringVar(&logDir, "log-dir", getDefaultLogDir(), "Directory for log files (falls back to stdout if directory cannot be created)")
rootCmd.Flags().BoolVar(&validateEnv, "validate-env", false, "Validate execution environment (Docker, env vars) before starting")
rootCmd.Flags().CountVarP(&verbosity, "verbose", "v", "Increase verbosity level (use -v for info, -vv for debug, -vvv for trace)")

// Mark mutually exclusive flags
rootCmd.MarkFlagsMutuallyExclusive("routed", "unified")

// Register custom flag completions
registerFlagCompletions(rootCmd)

// Add completion command
rootCmd.AddCommand(newCompletionCmd())
}
Expand All @@ -87,15 +96,66 @@ func getDefaultLogDir() string {
return defaultLogDir
}

func run(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
const (
// Debug log patterns for different verbosity levels
debugMainPackages = "cmd:*,server:*,launcher:*"
debugAllPackages = "*"
)

// registerFlagCompletions registers custom completion functions for flags
func registerFlagCompletions(cmd *cobra.Command) {
// Custom completion for --config flag (complete with .toml files)
cmd.RegisterFlagCompletionFunc("config", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return []string{"toml"}, cobra.ShellCompDirectiveFilterFileExt
})

// Custom completion for --log-dir flag (complete with directories)
cmd.RegisterFlagCompletionFunc("log-dir", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return nil, cobra.ShellCompDirectiveFilterDirs
})

// Custom completion for --env flag (complete with .env files)
cmd.RegisterFlagCompletionFunc("env", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return []string{"env"}, cobra.ShellCompDirectiveFilterFileExt
})
}

// preRun performs validation before command execution
func preRun(cmd *cobra.Command, args []string) error {
// Validate that either --config or --config-stdin is provided
if !configStdin && configFile == "" {
return fmt.Errorf("configuration source required: specify either --config <file> or --config-stdin")
}

// Apply verbosity level to logging (if DEBUG is not already set)
// -v (1): info level, -vv (2): debug level, -vvv (3): trace level
if verbosity > 0 && os.Getenv("DEBUG") == "" {
// Set DEBUG env var based on verbosity level
// Level 1: basic info (no special DEBUG setting needed, handled by logger)
// Level 2: enable debug logs for cmd and server packages
// Level 3: enable all debug logs
switch verbosity {
case 1:
// Info level - no special DEBUG setting (standard log output)
debugLog.Printf("Verbosity level: info")
case 2:
// Debug level - enable debug logs for main packages
os.Setenv("DEBUG", debugMainPackages)
debugLog.Printf("Verbosity level: debug (DEBUG=%s)", debugMainPackages)
default:
// Trace level (3+) - enable all debug logs
os.Setenv("DEBUG", debugAllPackages)
debugLog.Printf("Verbosity level: trace (DEBUG=%s)", debugAllPackages)
}
}

return nil
}

func run(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Initialize file logger early
if err := logger.InitFileLogger(logDir, "mcp-gateway.log"); err != nil {
log.Printf("Warning: Failed to initialize file logger: %v", err)
Expand Down
152 changes: 130 additions & 22 deletions internal/cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,45 +79,153 @@ func TestRunRequiresConfigSource(t *testing.T) {
t.Run("no config source provided", func(t *testing.T) {
configFile = ""
configStdin = false
err := run(nil, nil)
err := preRun(nil, nil)
require.Error(t, err, "Expected error when neither --config nor --config-stdin is provided")
assert.Contains(t, err.Error(), "configuration source required", "Error should mention configuration source required")
})

t.Run("config file provided", func(t *testing.T) {
configFile = "test.toml"
configStdin = false
err := run(nil, nil)
// Should not be the "configuration source required" error
// (will fail later due to missing file, but should pass validation)
if err != nil {
assert.NotContains(t, err.Error(), "configuration source required",
"Should not require config source when --config is provided")
}
err := preRun(nil, nil)
// Should pass validation when --config is provided
assert.NoError(t, err, "Should not error when --config is provided")
})

t.Run("config stdin provided", func(t *testing.T) {
configFile = ""
configStdin = true
err := run(nil, nil)
// Should not be the "configuration source required" error
// (will fail later due to stdin not being readable, but should pass validation)
if err != nil {
assert.NotContains(t, err.Error(), "configuration source required",
"Should not require config source when --config-stdin is provided")
}
err := preRun(nil, nil)
// Should pass validation when --config-stdin is provided
assert.NoError(t, err, "Should not error when --config-stdin is provided")
})

t.Run("both config file and stdin provided", func(t *testing.T) {
configFile = "test.toml"
configStdin = true
err := run(nil, nil)
// When both are provided, stdin takes precedence per flag description
// Should not be the "configuration source required" error
if err != nil {
assert.NotContains(t, err.Error(), "configuration source required",
"Should not require config source when both are provided")
}
err := preRun(nil, nil)
// When both are provided, should pass validation
assert.NoError(t, err, "Should not error when both are provided")
})
}

// TestPreRunValidation tests the preRun validation function
func TestPreRunValidation(t *testing.T) {
// Save original values
origConfigFile := configFile
origConfigStdin := configStdin
origVerbosity := verbosity
t.Cleanup(func() {
configFile = origConfigFile
configStdin = origConfigStdin
verbosity = origVerbosity
})

t.Run("validation passes with config file", func(t *testing.T) {
configFile = "test.toml"
configStdin = false
verbosity = 0
err := preRun(nil, nil)
assert.NoError(t, err)
})

t.Run("validation passes with config stdin", func(t *testing.T) {
configFile = ""
configStdin = true
verbosity = 0
err := preRun(nil, nil)
assert.NoError(t, err)
})

t.Run("validation fails without config source", func(t *testing.T) {
configFile = ""
configStdin = false
verbosity = 0
err := preRun(nil, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "configuration source required")
})

t.Run("verbosity level 1 does not set DEBUG", func(t *testing.T) {
// Save and clear DEBUG env var
origDebug, wasSet := os.LookupEnv("DEBUG")
t.Cleanup(func() {
if wasSet {
os.Setenv("DEBUG", origDebug)
} else {
os.Unsetenv("DEBUG")
}
})
os.Unsetenv("DEBUG")

configFile = "test.toml"
configStdin = false
verbosity = 1
err := preRun(nil, nil)
assert.NoError(t, err)
// Level 1 doesn't set DEBUG env var
assert.Empty(t, os.Getenv("DEBUG"))
})

t.Run("verbosity level 2 sets DEBUG for main packages", func(t *testing.T) {
// Save and clear DEBUG env var
origDebug, wasSet := os.LookupEnv("DEBUG")
t.Cleanup(func() {
if wasSet {
os.Setenv("DEBUG", origDebug)
} else {
os.Unsetenv("DEBUG")
}
})
os.Unsetenv("DEBUG")

configFile = "test.toml"
configStdin = false
verbosity = 2
err := preRun(nil, nil)
assert.NoError(t, err)
assert.Equal(t, "cmd:*,server:*,launcher:*", os.Getenv("DEBUG"))
})

t.Run("verbosity level 3 sets DEBUG to all", func(t *testing.T) {
// Save and clear DEBUG env var
origDebug, wasSet := os.LookupEnv("DEBUG")
t.Cleanup(func() {
if wasSet {
os.Setenv("DEBUG", origDebug)
} else {
os.Unsetenv("DEBUG")
}
})
os.Unsetenv("DEBUG")

configFile = "test.toml"
configStdin = false
verbosity = 3
err := preRun(nil, nil)
assert.NoError(t, err)
assert.Equal(t, "*", os.Getenv("DEBUG"))
})

t.Run("does not override existing DEBUG env var", func(t *testing.T) {
// Save DEBUG env var
origDebug, wasSet := os.LookupEnv("DEBUG")
t.Cleanup(func() {
if wasSet {
os.Setenv("DEBUG", origDebug)
} else {
os.Unsetenv("DEBUG")
}
})
os.Setenv("DEBUG", "custom:*")

configFile = "test.toml"
configStdin = false
verbosity = 2
err := preRun(nil, nil)
assert.NoError(t, err)
// Should not override existing DEBUG
assert.Equal(t, "custom:*", os.Getenv("DEBUG"))
})
}

Expand Down
64 changes: 62 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,68 @@
package main

import "github.com/githubnext/gh-aw-mcpg/internal/cmd"
import (
"fmt"
"runtime/debug"
"strings"

"github.com/githubnext/gh-aw-mcpg/internal/cmd"
)

func main() {
cmd.SetVersion(Version)
// Build version string with metadata
versionStr := buildVersionString()

// Set the version for the CLI
cmd.SetVersion(versionStr)

// Execute the root command
cmd.Execute()
}

const (
shortHashLength = 7 // Length for short git commit hash
)

// buildVersionString constructs a detailed version string with build metadata
func buildVersionString() string {
var parts []string

// Add main version
if Version != "" {
parts = append(parts, Version)
} else {
parts = append(parts, "dev")
}

// Add git commit if available
if GitCommit != "" {
parts = append(parts, fmt.Sprintf("commit: %s", GitCommit))
} else if buildInfo, ok := debug.ReadBuildInfo(); ok {
// Try to extract commit from build info if not set via ldflags
for _, setting := range buildInfo.Settings {
if setting.Key == "vcs.revision" {
commitHash := setting.Value
if len(commitHash) > shortHashLength {
commitHash = commitHash[:shortHashLength] // Short hash
}
parts = append(parts, fmt.Sprintf("commit: %s", commitHash))
break
}
}
}

// Add build date if available
if BuildDate != "" {
parts = append(parts, fmt.Sprintf("built: %s", BuildDate))
} else if buildInfo, ok := debug.ReadBuildInfo(); ok {
// Try to extract build time from build info if not set via ldflags
for _, setting := range buildInfo.Settings {
if setting.Key == "vcs.time" {
parts = append(parts, fmt.Sprintf("built: %s", setting.Value))
break
}
}
}

return strings.Join(parts, ", ")
}
Loading