diff --git a/.github/workflows/container.yml b/.github/workflows/container.yml new file mode 100644 index 00000000..e3a5912a --- /dev/null +++ b/.github/workflows/container.yml @@ -0,0 +1,44 @@ +name: build-and-push-container + +on: + push: + branches: [ "main" ] + tags: [ "v*" ] + workflow_dispatch: {} + +permissions: + contents: read + packages: write + +jobs: + docker: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + # Enables emulation so the amd64 runner can build arm64 too + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 # Buildx builder [oai_citation:1‡GitHub](https://github.com/docker/setup-buildx-action?utm_source=chatgpt.com) + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} # packages:write [oai_citation:2‡GitHub](https://github.com/docker/login-action?utm_source=chatgpt.com) + + - name: Build and push (multi-arch) + uses: docker/build-push-action@v6 + with: + context: . + push: true + platforms: linux/amd64,linux/arm64 + tags: | + ghcr.io/${{ github.repository }}:latest + ghcr.io/${{ github.repository }}:${{ github.sha }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..26128a22 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,35 @@ +# Build stage +FROM golang:1.23-alpine AS builder + +WORKDIR /app + +# Copy go mod files +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source code +COPY . . + +# Build the binary +RUN CGO_ENABLED=0 GOOS=linux go build -o flowguard-go . + +# Runtime stage +FROM alpine:latest + +# Install Docker CLI and bash for launching backend MCP servers +RUN apk add --no-cache docker-cli bash + +WORKDIR /app + +# Copy binary from builder +COPY --from=builder /app/flowguard-go . + +# Copy run.sh script +COPY run.sh . +RUN chmod +x run.sh + +# Expose default HTTP port +EXPOSE 8000 + +# Use run.sh as entrypoint +ENTRYPOINT ["/app/run.sh"] diff --git a/README.md b/README.md index 53e3a91c..c3411339 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,293 @@ -# gh-aw-mcpg -Github Agentic Workflows MCP Gateway +# FlowGuard (Go Port) + +A simplified Go port of FlowGuard - a proxy server for Model Context Protocol (MCP) servers. + +## Features + +- **Configuration Modes**: Supports both TOML files and JSON stdin configuration +- **Routing Modes**: + - **Routed**: Each backend server accessible at `/mcp/{serverID}` + - **Unified**: Single endpoint `/mcp` that routes to configured servers +- **Docker Support**: Launch backend MCP servers as Docker containers +- **Stdio Transport**: JSON-RPC 2.0 over stdin/stdout for MCP communication + +## Quick Start + +### Prerequisites + +1. **Docker** installed and running +2. **Go 1.23+** for building from source + +### Setup Steps + +1. **Build the binary** + ```bash + go build -o flowguard-go + ``` + +2. **Create your environment file** + ```bash + cp example.env .env + ``` + +3. **Create a GitHub Personal Access Token** + - Go to https://github.com/settings/tokens + - Click "Generate new token (classic)" + - Select scopes as needed (e.g., `repo` for repository access) + - Copy the generated token + +4. **Add your token to `.env`** + + Replace the placeholder value with your actual token: + ```bash + sed -i '' 's/GITHUB_PERSONAL_ACCESS_TOKEN=.*/GITHUB_PERSONAL_ACCESS_TOKEN=your_token_here/' .env + ``` + + Or edit `.env` manually and replace the value of `GITHUB_PERSONAL_ACCESS_TOKEN`. + +5. **Pull required Docker images** + ```bash + docker pull ghcr.io/github/github-mcp-server:latest + docker pull mcp/fetch + docker pull mcp/memory + ``` + +6. **Start FlowGuard** + + In one terminal, run: + ```bash + ./run.sh + ``` + + This will start FlowGuard in routed mode on `http://127.0.0.1:8000`. + +7. **Run Codex (in another terminal)** + ```bash + cp ~/.codex/config.toml ~/.codex/config.toml.bak && cp agent-configs/codex.config.toml ~/.codex/config.toml + AGENT_ID=demo-agent codex + ``` + + You can use '/mcp' in codex to list the available tools. + + That's it! FlowGuard is now proxying MCP requests to your configured backend servers. + + When you're done you can restore your old codex config file: + + ```bash + cp ~/.codex/config.toml.bak ~/.codex/config.toml + ``` + +## Testing with curl + +You can test the MCP server directly using curl commands: + +### 1. Initialize a session and extract session ID + +```bash +MCP_URL="http://127.0.0.1:8000/mcp/github" + +SESSION_ID=$( + curl -isS -X POST $MCP_URL \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -H 'Authorization: Bearer demo-agent' \ + -d '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"1.0.0","capabilities":{},"clientInfo":{"name":"curl","version":"0.1"}}}' \ + | awk 'BEGIN{IGNORECASE=1} /^mcp-session-id:/{print $2}' | tr -d '\r' +) + +echo "Session ID: $SESSION_ID" +``` + +### 2. List available tools + +```bash +curl -s \ + -H "Content-Type: application/json" \ + -H "Mcp-Session-Id: $SESSION_ID" \ + -H 'Authorization: Bearer demo-agent' \ + -X POST \ + $MCP_URL \ + -d '{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {} + }' +``` + +### Manual Build & Run + +If you prefer to run manually without the `run.sh` script: + +```bash +# Run with TOML config +./flowguard-go --config config.toml + +# Run with JSON stdin config +echo '{"mcpServers": {...}}' | ./flowguard-go --config-stdin +``` + +## Configuration + +### TOML Format (`config.toml`) + +```toml +[servers] + +[servers.github] +command = "docker" +args = ["run", "--rm", "-e", "GITHUB_PERSONAL_ACCESS_TOKEN", "-i", "ghcr.io/github/github-mcp-server:latest"] + +[servers.filesystem] +command = "node" +args = ["/path/to/filesystem-server.js"] +``` + +### JSON Stdin Format + +```json +{ + "mcpServers": { + "github": { + "type": "local", + "container": "ghcr.io/github/github-mcp-server:latest", + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "" + }, + } + } +} +``` + +**Environment Variable Passthrough**: Set the value to an empty string (`""`) to pass through the variable from the host environment. + +## Usage + +``` +FlowGuard is a proxy server for Model Context Protocol (MCP) servers. +It provides routing, aggregation, and management of multiple MCP backend servers. + +Usage: + flowguard-go [flags] + +Flags: + -c, --config string Path to config file (default "config.toml") + --config-stdin Read MCP server configuration from stdin (JSON format). When enabled, overrides --config + --env string Path to .env file to load environment variables + -h, --help help for flowguard-go + -l, --listen string HTTP server listen address (default "127.0.0.1:3000") + --routed Run in routed mode (each backend at /mcp/) + --unified Run in unified mode (all backends at /mcp) +``` + +## Docker + +### Build Image + +```bash +docker build -t flowguard-go . +``` + +### Run Container + +```bash +docker run --rm -v $(pwd)/.env:/app/.env \ + -v /var/run/docker.sock:/var/run/docker.sock \ + -p 8000:8000 \ + flowguard-go +``` + +The container uses `run.sh` as the entrypoint, which automatically: +- Detects architecture and sets DOCKER_API_VERSION (1.43 for arm64, 1.44 for amd64) +- Loads environment variables from `.env` +- Starts FlowGuard in routed mode on port 8000 +- Reads configuration from stdin (via heredoc in run.sh) + +### Override with custom configuration + +To use a custom config file, set environment variables that `run.sh` reads: + +```bash +docker run --rm -v $(pwd)/config.toml:/app/config.toml \ + -v $(pwd)/.env:/app/.env \ + -v /var/run/docker.sock:/var/run/docker.sock \ + -e CONFIG=/app/config.toml \ + -e ENV_FILE=/app/.env \ + -e PORT=8000 \ + -e HOST=127.0.0.1 \ + -p 8000:8000 \ + flowguard-go +``` + +Available environment variables for `run.sh`: +- `CONFIG` - Path to config file (overrides stdin config) +- `ENV_FILE` - Path to .env file (default: `.env`) +- `PORT` - Server port (default: `8000`) +- `HOST` - Server host (default: `127.0.0.1`) +- `MODE` - Server mode flag (default: `--routed`, can be `--unified`) + +**Note:** Set `DOCKER_API_VERSION=1.43` for arm64 (Mac) or `1.44` for amd64 (Linux). + + +## API Endpoints + +### Routed Mode (default) + +- `POST /mcp/{serverID}` - Send JSON-RPC request to specific server + - Example: `POST /mcp/github` with body `{"jsonrpc": "2.0", "method": "tools/list", "id": 1}` + +### Unified Mode + +- `POST /mcp` - Send JSON-RPC request (routed to first configured server) + +### Health Check + +- `GET /health` - Returns `OK` + +## MCP Methods + +Supported JSON-RPC 2.0 methods: + +- `tools/list` - List available tools +- `tools/call` - Call a tool with parameters +- Any other MCP method (forwarded as-is) + +## Architecture Simplifications + +This Go port focuses on core MCP proxy functionality: + +- ✅ TOML and JSON stdin configuration +- ✅ Stdio transport for backend servers +- ✅ Docker container launching +- ✅ Routed and unified modes +- ✅ Basic request/response proxying +- ❌ DIFC enforcement (removed) +- ❌ Sub-agents (removed) +- ❌ Guards (removed) + +## Development + +### Project Structure + +``` +flowguard-go/ +├── main.go # Entry point +├── go.mod # Dependencies +├── Dockerfile # Container image +└── internal/ + ├── cmd/ # CLI commands (cobra) + ├── config/ # Configuration loading + ├── launcher/ # Backend server management + ├── mcp/ # MCP protocol types & connection + └── server/ # HTTP server +``` + +### Dependencies + +- `github.com/spf13/cobra` - CLI framework +- `github.com/BurntSushi/toml` - TOML parser +- Standard library for JSON, HTTP, exec + +## License + +Same as original FlowGuard project. diff --git a/agent-configs/codex.config.toml b/agent-configs/codex.config.toml new file mode 100644 index 00000000..43eb2556 --- /dev/null +++ b/agent-configs/codex.config.toml @@ -0,0 +1,34 @@ +model = "gpt-5.1-codex-max" +sandbox_mode = "workspace-write" +model_reasoning_effort = "high" + +[mcp_servers.flowguard] +url = "http://127.0.0.1:8000/mcp/sys" +transport = "streamablehttp" +tool_timeout_sec = 120000 +startup_timeout_ms = 180000 +bearer_token_env_var = "AGENT_ID" + +[mcp_servers.github] +url = "http://127.0.0.1:8000/mcp/github" +transport = "streamablehttp" +tool_timeout_sec = 120000 +startup_timeout_ms = 180000 +bearer_token_env_var = "AGENT_ID" + +[mcp_servers.fetch] +url = "http://127.0.0.1:8000/mcp/fetch" +transport = "streamablehttp" +tool_timeout_sec = 120000 +startup_timeout_ms = 180000 +bearer_token_env_var = "AGENT_ID" + +[mcp_servers.memory] +url = "http://127.0.0.1:8000/mcp/memory" +transport = "streamablehttp" +tool_timeout_sec = 120000 +startup_timeout_ms = 180000 +bearer_token_env_var = "AGENT_ID" + +[projects."/workspace/"] +trust_level="trusted" diff --git a/config.json b/config.json new file mode 100644 index 00000000..0289e31b --- /dev/null +++ b/config.json @@ -0,0 +1,19 @@ +{ + "mcpServers": { + "github": { + "type": "local", + "container": "ghcr.io/github/github-mcp-server:latest", + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "" + } + }, + "fetch": { + "type": "local", + "container": "mcp/fetch" + }, + "memory": { + "type": "local", + "container": "mcp/memory" + } + } +} diff --git a/config.toml b/config.toml new file mode 100644 index 00000000..f55dc923 --- /dev/null +++ b/config.toml @@ -0,0 +1,24 @@ +[servers] + +[servers.github] +command = "docker" +args = ["run", "--rm", "-i", + "--name", "flowguard-github-mcp", + "-e", "GITHUB_PERSONAL_ACCESS_TOKEN", + "-e", "NO_COLOR=1", "-e", "TERM=dumb", + "ghcr.io/github/github-mcp-server:latest"] + +[servers.fetch] +command = "docker" +args = ["run", "--rm", "-i", + "-e", "NO_COLOR=1", "-e", "TERM=dumb", + "-e", "PYTHONUNBUFFERED=1", + "mcp/fetch"] + +[servers.memory] +command = "docker" +args = ["run", "--rm", "-i", "-e", "NO_COLOR=1", "-e", "TERM=dumb", "-e", "PYTHONUNBUFFERED=1", "mcp/memory"] + +# Note: DOCKER_API_VERSION is automatically set based on architecture +# - ARM64 (M1/M2/M3 Macs): 1.43 +# - x86_64 (Intel, GitHub Actions): 1.44 diff --git a/example.env b/example.env new file mode 100644 index 00000000..d5323d6f --- /dev/null +++ b/example.env @@ -0,0 +1,3 @@ +GH_TOKEN= +GITHUB_PERSONAL_ACCESS_TOKEN=$GH_TOKEN +GITHUB_TOKEN=$GH_TOKEN diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..7abeb185 --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module github.com/githubnext/gh-aw-mcpg + +go 1.23.0 + +require ( + github.com/BurntSushi/toml v1.5.0 + github.com/modelcontextprotocol/go-sdk v1.1.0 + github.com/spf13/cobra v1.8.0 +) + +require ( + github.com/google/jsonschema-go v0.3.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.30.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..13b3d49b --- /dev/null +++ b/go.sum @@ -0,0 +1,24 @@ +github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= +github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/modelcontextprotocol/go-sdk v1.1.0 h1:Qjayg53dnKC4UZ+792W21e4BpwEZBzwgRW6LrjLWSwA= +github.com/modelcontextprotocol/go-sdk v1.1.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cmd/root.go b/internal/cmd/root.go new file mode 100644 index 00000000..f3c19fa6 --- /dev/null +++ b/internal/cmd/root.go @@ -0,0 +1,177 @@ +package cmd + +import ( + "bufio" + "context" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/githubnext/gh-aw-mcpg/internal/config" + "github.com/githubnext/gh-aw-mcpg/internal/server" + "github.com/spf13/cobra" +) + +var ( + configFile string + configStdin bool + listenAddr string + routedMode bool + unifiedMode bool + envFile string +) + +var rootCmd = &cobra.Command{ + Use: "flowguard", + Short: "FlowGuard MCP proxy server", + Long: `FlowGuard is a proxy server for Model Context Protocol (MCP) servers. +It provides routing, aggregation, and management of multiple MCP backend servers.`, + RunE: run, +} + +func init() { + rootCmd.Flags().StringVarP(&configFile, "config", "c", "config.toml", "Path to config file") + rootCmd.Flags().BoolVar(&configStdin, "config-stdin", false, "Read MCP server configuration from stdin (JSON format). When enabled, overrides --config") + rootCmd.Flags().StringVarP(&listenAddr, "listen", "l", "127.0.0.1:3000", "HTTP server listen address") + rootCmd.Flags().BoolVar(&routedMode, "routed", false, "Run in routed mode (each backend at /mcp/)") + rootCmd.Flags().BoolVar(&unifiedMode, "unified", false, "Run in unified mode (all backends at /mcp)") + rootCmd.Flags().StringVar(&envFile, "env", "", "Path to .env file to load environment variables") +} + +func run(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Load .env file if specified + if envFile != "" { + if err := loadEnvFile(envFile); err != nil { + return fmt.Errorf("failed to load .env file: %w", err) + } + } + + // Load configuration + var cfg *config.Config + var err error + + if configStdin { + log.Println("Reading configuration from stdin...") + cfg, err = config.LoadFromStdin() + } else { + log.Printf("Reading configuration from %s...", configFile) + cfg, err = config.LoadFromFile(configFile) + } + + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + log.Printf("Loaded %d MCP server(s)", len(cfg.Servers)) + + // Determine mode (default to unified if neither flag is set) + mode := "unified" + if routedMode && unifiedMode { + return fmt.Errorf("cannot specify both --routed and --unified") + } + if routedMode { + mode = "routed" + } + + // Create unified MCP server (backend for both modes) + unifiedServer, err := server.NewUnified(ctx, cfg) + if err != nil { + return fmt.Errorf("failed to create unified server: %w", err) + } + defer unifiedServer.Close() + + // Handle graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + go func() { + <-sigChan + log.Println("Shutting down...") + cancel() + unifiedServer.Close() + os.Exit(0) + }() + + // Create HTTP server based on mode + var httpServer *http.Server + if mode == "routed" { + log.Printf("Starting FlowGuard in ROUTED mode on %s", listenAddr) + log.Printf("Routes: /mcp/ where is one of: %v", unifiedServer.GetServerIDs()) + httpServer = server.CreateHTTPServerForRoutedMode(listenAddr, unifiedServer) + } else { + log.Printf("Starting FlowGuard in UNIFIED mode on %s", listenAddr) + log.Printf("Endpoint: /mcp") + httpServer = server.CreateHTTPServerForMCP(listenAddr, unifiedServer) + } + + // Start HTTP server + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return fmt.Errorf("server error: %w", err) + } + + return nil +} + +// loadEnvFile reads a .env file and sets environment variables +func loadEnvFile(path string) error { + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + log.Printf("Loading environment from %s...", path) + scanner := bufio.NewScanner(file) + loadedVars := 0 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // Parse KEY=VALUE + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + // Expand $VAR references in value + value = os.ExpandEnv(value) + + if err := os.Setenv(key, value); err != nil { + return fmt.Errorf("failed to set %s: %w", key, err) + } + + // Log loaded variable (hide sensitive values) + displayValue := value + if len(value) > 0 { + displayValue = value[:min(10, len(value))] + "..." + } + log.Printf(" Loaded: %s=%s", key, displayValue) + loadedVars++ + } + + log.Printf("Loaded %d environment variables from %s", loadedVars, path) + + return scanner.Err() +} + +// Execute runs the root command +func Execute() { + if err := rootCmd.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 00000000..0df3bc5f --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,136 @@ +package config + +import ( + "encoding/json" + "fmt" + "io" + "log" + "os" + + "github.com/BurntSushi/toml" +) + +// Config represents the FlowGuard configuration +type Config struct { + Servers map[string]*ServerConfig `toml:"servers"` +} + +// ServerConfig represents a single MCP server configuration +type ServerConfig struct { + Command string `toml:"command"` + Args []string `toml:"args"` + Env map[string]string `toml:"env"` + WorkingDirectory string `toml:"working_directory"` +} + +// StdinConfig represents JSON configuration from stdin +type StdinConfig struct { + MCPServers map[string]*StdinServerConfig `json:"mcpServers"` + Gateway *StdinGatewayConfig `json:"gateway,omitempty"` +} + +// StdinServerConfig represents a single server from stdin JSON +type StdinServerConfig struct { + Type string `json:"type"` + Command string `json:"command,omitempty"` + Args []string `json:"args,omitempty"` + Env map[string]string `json:"env,omitempty"` + Container string `json:"container,omitempty"` + EntrypointArgs []string `json:"entrypointArgs,omitempty"` +} + +// StdinGatewayConfig represents gateway configuration from stdin JSON +type StdinGatewayConfig struct { + Port *int `json:"port,omitempty"` + APIKey string `json:"apiKey,omitempty"` +} + +// LoadFromFile loads configuration from a TOML file +func LoadFromFile(path string) (*Config, error) { + var cfg Config + if _, err := toml.DecodeFile(path, &cfg); err != nil { + return nil, fmt.Errorf("failed to decode TOML: %w", err) + } + return &cfg, nil +} + +// LoadFromStdin loads configuration from stdin JSON +func LoadFromStdin() (*Config, error) { + data, err := io.ReadAll(os.Stdin) + if err != nil { + return nil, fmt.Errorf("failed to read stdin: %w", err) + } + + var stdinCfg StdinConfig + if err := json.Unmarshal(data, &stdinCfg); err != nil { + return nil, fmt.Errorf("failed to parse JSON: %w", err) + } + + // Log gateway configuration if present (reserved for future use) + if stdinCfg.Gateway != nil { + if stdinCfg.Gateway.Port != nil || stdinCfg.Gateway.APIKey != "" { + log.Println("Gateway configuration present but not yet implemented (reserved for future use)") + } + } + + // Convert stdin config to internal format + cfg := &Config{ + Servers: make(map[string]*ServerConfig), + } + + for name, server := range stdinCfg.MCPServers { + // Only support "local" type for now + if server.Type != "local" { + log.Printf("Warning: skipping server '%s' with unsupported type '%s'", name, server.Type) + continue + } + + // For Docker containers + if server.Container != "" { + args := []string{ + "run", + "--rm", + "-i", + // Standard environment variables for better Docker compatibility + "-e", "NO_COLOR=1", + "-e", "TERM=dumb", + "-e", "PYTHONUNBUFFERED=1", + } + + // Add user-specified environment variables + // Empty string "" means passthrough from host (just -e KEY) + // Non-empty string means explicit value (-e KEY=value) + for k, v := range server.Env { + args = append(args, "-e") + if v == "" { + // Passthrough from host environment + args = append(args, k) + } else { + // Explicit value + args = append(args, fmt.Sprintf("%s=%s", k, v)) + } + } + + // Add container name + args = append(args, server.Container) + + // Add entrypoint args + args = append(args, server.EntrypointArgs...) + + cfg.Servers[name] = &ServerConfig{ + Command: "docker", + Args: args, + Env: make(map[string]string), + } + } else { + // Direct command execution + cfg.Servers[name] = &ServerConfig{ + Command: server.Command, + Args: server.Args, + Env: server.Env, + } + } + } + + return cfg, nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 00000000..cd8f851b --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,276 @@ +package config + +import ( + "encoding/json" + "os" + "strings" + "testing" +) + +func TestLoadFromStdin_ValidJSON(t *testing.T) { + jsonConfig := `{ + "mcpServers": { + "test": { + "type": "local", + "container": "test/container:latest", + "entrypointArgs": ["arg1", "arg2"], + "env": { + "TEST_VAR": "value", + "PASSTHROUGH_VAR": "" + } + } + } + }` + + // Mock stdin + r, w, _ := os.Pipe() + oldStdin := os.Stdin + os.Stdin = r + go func() { + w.Write([]byte(jsonConfig)) + w.Close() + }() + + cfg, err := LoadFromStdin() + os.Stdin = oldStdin + + if err != nil { + t.Fatalf("LoadFromStdin() failed: %v", err) + } + + if cfg == nil { + t.Fatal("LoadFromStdin() returned nil config") + } + + if len(cfg.Servers) != 1 { + t.Errorf("Expected 1 server, got %d", len(cfg.Servers)) + } + + server, ok := cfg.Servers["test"] + if !ok { + t.Fatal("Server 'test' not found in config") + } + + if server.Command != "docker" { + t.Errorf("Expected command 'docker', got '%s'", server.Command) + } + + // Check that standard Docker env vars are included + hasNoColor := false + hasTerm := false + hasPythonUnbuffered := false + hasTestVar := false + hasPassthrough := false + + for i := 0; i < len(server.Args); i++ { + arg := server.Args[i] + if arg == "-e" && i+1 < len(server.Args) { + nextArg := server.Args[i+1] + if nextArg == "NO_COLOR=1" { + hasNoColor = true + } else if nextArg == "TERM=dumb" { + hasTerm = true + } else if nextArg == "PYTHONUNBUFFERED=1" { + hasPythonUnbuffered = true + } else if nextArg == "TEST_VAR=value" { + hasTestVar = true + } else if nextArg == "PASSTHROUGH_VAR" { + hasPassthrough = true + } + } + } + + if !hasNoColor { + t.Error("Standard env var NO_COLOR=1 not found") + } + if !hasTerm { + t.Error("Standard env var TERM=dumb not found") + } + if !hasPythonUnbuffered { + t.Error("Standard env var PYTHONUNBUFFERED=1 not found") + } + if !hasTestVar { + t.Error("Custom env var TEST_VAR=value not found") + } + if !hasPassthrough { + t.Error("Passthrough env var PASSTHROUGH_VAR not found") + } + + // Check that container name is in args + if !contains(server.Args, "test/container:latest") { + t.Error("Container name not found in args") + } + + // Check that entrypoint args are included + if !contains(server.Args, "arg1") || !contains(server.Args, "arg2") { + t.Error("Entrypoint args not found") + } +} + +func TestLoadFromStdin_WithGateway(t *testing.T) { + port := 8080 + jsonConfig := `{ + "mcpServers": { + "test": { + "type": "local", + "container": "test/container:latest" + } + }, + "gateway": { + "port": 8080, + "apiKey": "test-key" + } + }` + + r, w, _ := os.Pipe() + oldStdin := os.Stdin + os.Stdin = r + go func() { + w.Write([]byte(jsonConfig)) + w.Close() + }() + + _, err := LoadFromStdin() + os.Stdin = oldStdin + + if err != nil { + t.Fatalf("LoadFromStdin() failed: %v", err) + } + + // Gateway should be parsed but not affect server config + var stdinCfg StdinConfig + json.Unmarshal([]byte(jsonConfig), &stdinCfg) + + if stdinCfg.Gateway == nil { + t.Error("Gateway not parsed") + } + if stdinCfg.Gateway.Port == nil || *stdinCfg.Gateway.Port != port { + t.Error("Gateway port not correct") + } + if stdinCfg.Gateway.APIKey != "test-key" { + t.Error("Gateway API key not correct") + } +} + +func TestLoadFromStdin_UnsupportedType(t *testing.T) { + jsonConfig := `{ + "mcpServers": { + "unsupported": { + "type": "remote", + "container": "test/container:latest" + }, + "supported": { + "type": "local", + "container": "test/container:latest" + } + } + }` + + r, w, _ := os.Pipe() + oldStdin := os.Stdin + os.Stdin = r + go func() { + w.Write([]byte(jsonConfig)) + w.Close() + }() + + cfg, err := LoadFromStdin() + os.Stdin = oldStdin + + if err != nil { + t.Fatalf("LoadFromStdin() failed: %v", err) + } + + // Only 'local' type should be loaded + if len(cfg.Servers) != 1 { + t.Errorf("Expected 1 server (local type only), got %d", len(cfg.Servers)) + } + + if _, ok := cfg.Servers["unsupported"]; ok { + t.Error("Unsupported server type was loaded") + } + + if _, ok := cfg.Servers["supported"]; !ok { + t.Error("Supported server type was not loaded") + } +} + +func TestLoadFromStdin_DirectCommand(t *testing.T) { + jsonConfig := `{ + "mcpServers": { + "direct": { + "type": "local", + "command": "node", + "args": ["index.js"], + "env": { + "NODE_ENV": "production" + } + } + } + }` + + r, w, _ := os.Pipe() + oldStdin := os.Stdin + os.Stdin = r + go func() { + w.Write([]byte(jsonConfig)) + w.Close() + }() + + cfg, err := LoadFromStdin() + os.Stdin = oldStdin + + if err != nil { + t.Fatalf("LoadFromStdin() failed: %v", err) + } + + server, ok := cfg.Servers["direct"] + if !ok { + t.Fatal("Server 'direct' not found") + } + + if server.Command != "node" { + t.Errorf("Expected command 'node', got '%s'", server.Command) + } + + if !contains(server.Args, "index.js") { + t.Error("Args not preserved for direct command") + } + + if server.Env["NODE_ENV"] != "production" { + t.Error("Env vars not preserved for direct command") + } +} + +func TestLoadFromStdin_InvalidJSON(t *testing.T) { + jsonConfig := `{invalid json}` + + r, w, _ := os.Pipe() + oldStdin := os.Stdin + os.Stdin = r + go func() { + w.Write([]byte(jsonConfig)) + w.Close() + }() + + _, err := LoadFromStdin() + os.Stdin = oldStdin + + if err == nil { + t.Error("Expected error for invalid JSON, got nil") + } + + if !strings.Contains(err.Error(), "parse JSON") { + t.Errorf("Expected 'parse JSON' error, got: %v", err) + } +} + +// Helper function to check if slice contains item +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} diff --git a/internal/launcher/launcher.go b/internal/launcher/launcher.go new file mode 100644 index 00000000..92664ea4 --- /dev/null +++ b/internal/launcher/launcher.go @@ -0,0 +1,122 @@ +package launcher + +import ( + "context" + "fmt" + "log" + "os" + "strings" + "sync" + + "github.com/githubnext/gh-aw-mcpg/internal/config" + "github.com/githubnext/gh-aw-mcpg/internal/mcp" +) + +// Launcher manages backend MCP server connections +type Launcher struct { + ctx context.Context + config *config.Config + connections map[string]*mcp.Connection + mu sync.RWMutex +} + +// New creates a new Launcher +func New(ctx context.Context, cfg *config.Config) *Launcher { + return &Launcher{ + ctx: ctx, + config: cfg, + connections: make(map[string]*mcp.Connection), + } +} + +// GetOrLaunch returns an existing connection or launches a new one +func GetOrLaunch(l *Launcher, serverID string) (*mcp.Connection, error) { + // Check if already exists + l.mu.RLock() + if conn, ok := l.connections[serverID]; ok { + l.mu.RUnlock() + return conn, nil + } + l.mu.RUnlock() + + // Launch new connection + l.mu.Lock() + defer l.mu.Unlock() + + // Double-check after acquiring write lock + if conn, ok := l.connections[serverID]; ok { + return conn, nil + } + + // Get server config + serverCfg, ok := l.config.Servers[serverID] + if !ok { + return nil, fmt.Errorf("server '%s' not found in config", serverID) + } + + // Log the command being executed + log.Printf("[LAUNCHER] Starting MCP server: %s", serverID) + log.Printf("[LAUNCHER] Command: %s", serverCfg.Command) + log.Printf("[LAUNCHER] Args: %v", serverCfg.Args) + + // Check for environment variable passthrough (only check args after -e flags) + for i := 0; i < len(serverCfg.Args); i++ { + arg := serverCfg.Args[i] + // If this arg is "-e", check the next argument + if arg == "-e" && i+1 < len(serverCfg.Args) { + nextArg := serverCfg.Args[i+1] + // Check if it's a passthrough (no = sign) vs explicit value (has = sign) + if !strings.Contains(nextArg, "=") { + // This is a passthrough variable, check if it exists in our environment + if val := os.Getenv(nextArg); val != "" { + displayVal := val + if len(val) > 10 { + displayVal = val[:10] + "..." + } + log.Printf("[LAUNCHER] ✓ Env passthrough: %s=%s (from FlowGuard process)", nextArg, displayVal) + } else { + log.Printf("[LAUNCHER] ✗ WARNING: Env passthrough for %s requested but NOT FOUND in FlowGuard process", nextArg) + } + } + i++ // Skip the next arg since we just processed it + } + } + + if len(serverCfg.Env) > 0 { + log.Printf("[LAUNCHER] Additional env vars: %v", serverCfg.Env) + } + + // Create connection + conn, err := mcp.NewConnection(l.ctx, serverCfg.Command, serverCfg.Args, serverCfg.Env) + if err != nil { + return nil, fmt.Errorf("failed to create connection: %w", err) + } + + log.Printf("[LAUNCHER] Successfully launched: %s", serverID) + + l.connections[serverID] = conn + return conn, nil +} + +// ServerIDs returns all configured server IDs +func (l *Launcher) ServerIDs() []string { + l.mu.RLock() + defer l.mu.RUnlock() + + ids := make([]string, 0, len(l.config.Servers)) + for id := range l.config.Servers { + ids = append(ids, id) + } + return ids +} + +// Close closes all connections +func (l *Launcher) Close() { + l.mu.Lock() + defer l.mu.Unlock() + + for _, conn := range l.connections { + conn.Close() + } + l.connections = make(map[string]*mcp.Connection) +} diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go new file mode 100644 index 00000000..6662c59f --- /dev/null +++ b/internal/mcp/connection.go @@ -0,0 +1,272 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "os/exec" + + sdk "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Connection represents a connection to an MCP server using the official SDK +type Connection struct { + client *sdk.Client + session *sdk.ClientSession + ctx context.Context + cancel context.CancelFunc +} + +// NewConnection creates a new MCP connection using the official SDK +func NewConnection(ctx context.Context, command string, args []string, env map[string]string) (*Connection, error) { + ctx, cancel := context.WithCancel(ctx) + + // Create MCP client + client := sdk.NewClient(&sdk.Implementation{ + Name: "flowguard", + Version: "1.0.0", + }, nil) + + // Expand Docker -e flags that reference environment variables + // Docker's `-e VAR_NAME` expects VAR_NAME to be in the environment + expandedArgs := expandDockerEnvArgs(args) + expandedArgs = args // --- IGNORE --- + // Create command transport + cmd := exec.CommandContext(ctx, command, expandedArgs...) + + // Start with parent's environment to inherit shell variables + cmd.Env = append([]string{}, cmd.Environ()...) + + // Add/override with config-specified environment variables + if len(env) > 0 { + for k, v := range env { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) + } + } + + log.Printf("Starting MCP server command: %s %v", command, expandedArgs) + transport := &sdk.CommandTransport{Command: cmd} + + // Connect to the server (this handles the initialization handshake automatically) + log.Printf("Connecting to MCP server...") + session, err := client.Connect(ctx, transport, nil) + if err != nil { + cancel() + return nil, fmt.Errorf("failed to connect: %w", err) + } + + conn := &Connection{ + client: client, + session: session, + ctx: ctx, + cancel: cancel, + } + + log.Printf("Started MCP server: %s %v", command, args) + return conn, nil +} + +// SendRequest sends a JSON-RPC request and waits for the response +func (c *Connection) SendRequest(method string, params interface{}) (*Response, error) { + switch method { + case "tools/list": + return c.listTools() + case "tools/call": + return c.callTool(params) + case "resources/list": + return c.listResources() + case "resources/read": + return c.readResource(params) + case "prompts/list": + return c.listPrompts() + case "prompts/get": + return c.getPrompt(params) + default: + return nil, fmt.Errorf("unsupported method: %s", method) + } +} + +func (c *Connection) listTools() (*Response, error) { + result, err := c.session.ListTools(c.ctx, &sdk.ListToolsParams{}) + if err != nil { + return nil, err + } + + resultJSON, err := json.Marshal(result) + if err != nil { + return nil, err + } + + return &Response{ + JSONRPC: "2.0", + ID: 1, // Placeholder ID + Result: resultJSON, + }, nil +} + +func (c *Connection) callTool(params interface{}) (*Response, error) { + var callParams CallToolParams + paramsJSON, _ := json.Marshal(params) + if err := json.Unmarshal(paramsJSON, &callParams); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + result, err := c.session.CallTool(c.ctx, &sdk.CallToolParams{ + Name: callParams.Name, + Arguments: callParams.Arguments, + }) + if err != nil { + return nil, err + } + + resultJSON, err := json.Marshal(result) + if err != nil { + return nil, err + } + + return &Response{ + JSONRPC: "2.0", + ID: 1, + Result: resultJSON, + }, nil +} + +func (c *Connection) listResources() (*Response, error) { + result, err := c.session.ListResources(c.ctx, &sdk.ListResourcesParams{}) + if err != nil { + return nil, err + } + + resultJSON, err := json.Marshal(result) + if err != nil { + return nil, err + } + + return &Response{ + JSONRPC: "2.0", + ID: 1, + Result: resultJSON, + }, nil +} + +func (c *Connection) readResource(params interface{}) (*Response, error) { + var readParams struct { + URI string `json:"uri"` + } + paramsJSON, _ := json.Marshal(params) + if err := json.Unmarshal(paramsJSON, &readParams); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + result, err := c.session.ReadResource(c.ctx, &sdk.ReadResourceParams{ + URI: readParams.URI, + }) + if err != nil { + return nil, err + } + + resultJSON, err := json.Marshal(result) + if err != nil { + return nil, err + } + + return &Response{ + JSONRPC: "2.0", + ID: 1, + Result: resultJSON, + }, nil +} + +func (c *Connection) listPrompts() (*Response, error) { + result, err := c.session.ListPrompts(c.ctx, &sdk.ListPromptsParams{}) + if err != nil { + return nil, err + } + + resultJSON, err := json.Marshal(result) + if err != nil { + return nil, err + } + + return &Response{ + JSONRPC: "2.0", + ID: 1, + Result: resultJSON, + }, nil +} + +func (c *Connection) getPrompt(params interface{}) (*Response, error) { + var getParams struct { + Name string `json:"name"` + Arguments map[string]string `json:"arguments"` + } + paramsJSON, _ := json.Marshal(params) + if err := json.Unmarshal(paramsJSON, &getParams); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + result, err := c.session.GetPrompt(c.ctx, &sdk.GetPromptParams{ + Name: getParams.Name, + Arguments: getParams.Arguments, + }) + if err != nil { + return nil, err + } + + resultJSON, err := json.Marshal(result) + if err != nil { + return nil, err + } + + return &Response{ + JSONRPC: "2.0", + ID: 1, + Result: resultJSON, + }, nil +} + +// expandDockerEnvArgs expands Docker -e flags that reference environment variables +// Converts "-e VAR_NAME" to "-e VAR_NAME=value" by reading from the process environment +func expandDockerEnvArgs(args []string) []string { + result := make([]string, 0, len(args)) + for i := 0; i < len(args); i++ { + arg := args[i] + + // Check if this is a -e flag + if arg == "-e" && i+1 < len(args) { + nextArg := args[i+1] + // If next arg doesn't contain '=', it's a variable reference + if len(nextArg) > 0 && !containsEqual(nextArg) { + // Look up the variable in the environment + if value, exists := os.LookupEnv(nextArg); exists { + result = append(result, "-e") + result = append(result, fmt.Sprintf("%s=%s", nextArg, value)) + i++ // Skip the next arg since we processed it + continue + } + } + } + result = append(result, arg) + } + return result +} + +func containsEqual(s string) bool { + for _, c := range s { + if c == '=' { + return true + } + } + return false +} + +// Close closes the connection +func (c *Connection) Close() error { + c.cancel() + if c.session != nil { + return c.session.Close() + } + return nil +} diff --git a/internal/mcp/types.go b/internal/mcp/types.go new file mode 100644 index 00000000..b283e0c9 --- /dev/null +++ b/internal/mcp/types.go @@ -0,0 +1,45 @@ +package mcp + +import "encoding/json" + +// Request represents a JSON-RPC 2.0 request +type Request struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +// Response represents a JSON-RPC 2.0 response +type Response struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *ResponseError `json:"error,omitempty"` +} + +// ResponseError represents a JSON-RPC 2.0 error +type ResponseError struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data,omitempty"` +} + +// Tool represents an MCP tool definition +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]interface{} `json:"inputSchema"` +} + +// CallToolParams represents parameters for calling a tool +type CallToolParams struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +// ContentItem represents a content item in tool responses +type ContentItem struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} diff --git a/internal/server/routed.go b/internal/server/routed.go new file mode 100644 index 00000000..eecfa293 --- /dev/null +++ b/internal/server/routed.go @@ -0,0 +1,137 @@ +package server + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "net/http" + "strings" + + sdk "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// CreateHTTPServerForRoutedMode creates an HTTP server for routed mode +// In routed mode, each backend is accessible at /mcp/ +// Multiple routes from the same Bearer token share a session +func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer) *http.Server { + mux := http.NewServeMux() + + // OAuth discovery endpoint - return 404 since we don't use OAuth + mux.HandleFunc("/mcp/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { + log.Printf("[%s] %s %s - OAuth discovery (not supported)", r.RemoteAddr, r.Method, r.URL.Path) + http.NotFound(w, r) + }) + + // Create routes for all backends plus sys + allBackends := append([]string{"sys"}, unifiedServer.GetServerIDs()...) + + // Create a proxy for each backend server (including sys) + for _, serverID := range allBackends { + // Capture serverID for the closure + backendID := serverID + route := fmt.Sprintf("/mcp/%s", backendID) + + // Create StreamableHTTP handler for this route + routeHandler := sdk.NewStreamableHTTPHandler(func(r *http.Request) *sdk.Server { + // Extract Bearer token from Authorization header + authHeader := r.Header.Get("Authorization") + var sessionID string + + if strings.HasPrefix(authHeader, "Bearer ") { + sessionID = strings.TrimPrefix(authHeader, "Bearer ") + sessionID = strings.TrimSpace(sessionID) + } + + // Reject requests without valid Bearer token + if sessionID == "" { + log.Printf("[%s] %s %s - REJECTED: No Bearer token", r.RemoteAddr, r.Method, r.URL.Path) + return nil + } + + log.Printf("=== NEW SSE CONNECTION (ROUTED) ===") + log.Printf("[%s] %s %s", r.RemoteAddr, r.Method, r.URL.Path) + log.Printf("Backend: %s", backendID) + log.Printf("Bearer Token (Session ID): %s", sessionID) + + // Log request body for debugging + if r.Method == "POST" && r.Body != nil { + bodyBytes, err := io.ReadAll(r.Body) + if err == nil && len(bodyBytes) > 0 { + log.Printf("Request body: %s", string(bodyBytes)) + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + } + + // Store session ID and backend ID in request context + ctx := context.WithValue(r.Context(), SessionIDContextKey, sessionID) + ctx = context.WithValue(ctx, ContextKey("backend-id"), backendID) + *r = *r.WithContext(ctx) + log.Printf("✓ Injected session ID and backend ID into context") + log.Printf("===================================\n") + + // Return a filtered proxy server that only exposes this backend's tools + return createFilteredServer(unifiedServer, backendID) + }, &sdk.StreamableHTTPOptions{ + Stateless: false, + }) + + // Mount the handler at both /mcp/ and /mcp// + mux.Handle(route+"/", routeHandler) + mux.Handle(route, routeHandler) + log.Printf("Registered route: %s", route) + } + + // Health check + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "OK\n") + }) + + return &http.Server{ + Addr: addr, + Handler: mux, + } +} + +// createFilteredServer creates an MCP server that only exposes tools for a specific backend +// This reuses the unified server's tool handlers, ensuring all calls go through the same session +func createFilteredServer(unifiedServer *UnifiedServer, backendID string) *sdk.Server { + // Create a new SDK server for this route + server := sdk.NewServer(&sdk.Implementation{ + Name: fmt.Sprintf("flowguard-%s", backendID), + Version: "1.0.0", + }, nil) + + // Get tools for this backend from the unified server + tools := unifiedServer.GetToolsForBackend(backendID) + + log.Printf("Creating filtered server for %s with %d tools", backendID, len(tools)) + + // Register each tool (without prefix) using the unified server's handlers + for _, toolInfo := range tools { + // Capture for closure + toolNameCopy := toolInfo.Name + + // Get the unified server's handler for this tool + handler := unifiedServer.GetToolHandler(backendID, toolInfo.Name) + if handler == nil { + log.Printf("WARNING: No handler found for %s___%s", backendID, toolInfo.Name) + continue + } + + sdk.AddTool(server, &sdk.Tool{ + Name: toolInfo.Name, // Without prefix for the client + Description: toolInfo.Description, + InputSchema: toolInfo.InputSchema, + }, func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) { + // Call the unified server's handler directly + // This ensures we go through the same session and connection pool + log.Printf("[ROUTED] Calling unified handler for: %s", toolNameCopy) + return handler(ctx, req, args) + }) + } + + return server +} diff --git a/internal/server/routed_test.go b/internal/server/routed_test.go new file mode 100644 index 00000000..82da0878 --- /dev/null +++ b/internal/server/routed_test.go @@ -0,0 +1,195 @@ +package server + +import ( + "context" + "testing" + + sdk "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/githubnext/gh-aw-mcpg/internal/config" +) + +func TestCreateFilteredServer_ToolFiltering(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Add test tools - Handler is not tested directly, just use nil + us.toolsMu.Lock() + us.tools["github___issue_read"] = &ToolInfo{ + Name: "github___issue_read", + Description: "Read an issue", + BackendID: "github", + Handler: nil, + } + us.tools["github___repo_list"] = &ToolInfo{ + Name: "github___repo_list", + Description: "List repos", + BackendID: "github", + Handler: nil, + } + us.tools["fetch___get"] = &ToolInfo{ + Name: "fetch___get", + Description: "Fetch URL", + BackendID: "fetch", + Handler: nil, + } + us.toolsMu.Unlock() + + // Create filtered server for github backend + filteredServer := createFilteredServer(us, "github") + + // We can't easily inspect the filtered server's tools without SDK internals, + // but we can verify GetToolsForBackend returns correct filtered list + tools := us.GetToolsForBackend("github") + if len(tools) != 2 { + t.Errorf("Expected 2 tools for github backend, got %d", len(tools)) + } + + // Verify tool names have prefix stripped + toolNames := make(map[string]bool) + for _, tool := range tools { + toolNames[tool.Name] = true + } + + if !toolNames["issue_read"] { + t.Error("Expected tool 'issue_read' not found") + } + if !toolNames["repo_list"] { + t.Error("Expected tool 'repo_list' not found") + } + if toolNames["get"] { + t.Error("Tool 'get' from fetch backend should not be in github filtered server") + } + + _ = filteredServer // Use variable to avoid unused error +} + +func TestGetToolHandler(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Create a mock handler with correct signature + mockHandler := func(ctx context.Context, req *sdk.CallToolRequest, state interface{}) (*sdk.CallToolResult, interface{}, error) { + return &sdk.CallToolResult{IsError: false}, state, nil + } + + // Add test tool with handler + us.toolsMu.Lock() + us.tools["github___test_tool"] = &ToolInfo{ + Name: "github___test_tool", + Description: "Test tool", + BackendID: "github", + Handler: mockHandler, + } + us.toolsMu.Unlock() + + // Test retrieval with non-prefixed name (routed mode format) + handler := us.GetToolHandler("github", "test_tool") + if handler == nil { + t.Fatal("GetToolHandler() returned nil for non-prefixed tool name") + } + + // Test non-existent tool + handler = us.GetToolHandler("github", "nonexistent_tool") + if handler != nil { + t.Error("GetToolHandler() should return nil for non-existent tool") + } + + // Test wrong backend (test_tool belongs to github, not fetch) + handler = us.GetToolHandler("fetch", "test_tool") + if handler != nil { + t.Error("GetToolHandler() should return nil when backend doesn't match") + } +} + +func TestCreateHTTPServerForRoutedMode_ServerIDs(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{ + "github": {Command: "docker", Args: []string{}}, + "fetch": {Command: "docker", Args: []string{}}, + }, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Create routed mode server + httpServer := CreateHTTPServerForRoutedMode("127.0.0.1:8000", us) + if httpServer == nil { + t.Fatal("CreateHTTPServerForRoutedMode() returned nil") + } + + // Verify server IDs are correctly set up + serverIDs := us.GetServerIDs() + if len(serverIDs) != 2 { + t.Errorf("Expected 2 server IDs, got %d", len(serverIDs)) + } + + expectedIDs := map[string]bool{"github": true, "fetch": true} + for _, id := range serverIDs { + if !expectedIDs[id] { + t.Errorf("Unexpected server ID: %s", id) + } + } +} + +func TestRoutedMode_SysToolsBackend(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{ + "github": {Command: "docker", Args: []string{}}, + }, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Verify sys tools exist + sysTools := us.GetToolsForBackend("sys") + if len(sysTools) == 0 { + t.Error("Expected sys tools to be registered, got none") + } + + // Check for expected sys tools + toolNames := make(map[string]bool) + for _, tool := range sysTools { + toolNames[tool.Name] = true + } + + expectedSysTools := []string{"init", "list_servers"} + for _, expectedTool := range expectedSysTools { + if !toolNames[expectedTool] { + t.Errorf("Expected sys tool '%s' not found", expectedTool) + } + } + + // Verify sys tools have correct backend ID + for _, tool := range sysTools { + if tool.BackendID != "sys" { + t.Errorf("Expected BackendID 'sys', got '%s'", tool.BackendID) + } + } +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 00000000..0089603f --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,262 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "strings" + + "github.com/githubnext/gh-aw-mcpg/internal/launcher" + "github.com/githubnext/gh-aw-mcpg/internal/mcp" + "github.com/githubnext/gh-aw-mcpg/internal/sys" +) + +// Server represents the FlowGuard HTTP server +type Server struct { + launcher *launcher.Launcher + sysServer *sys.SysServer + mux *http.ServeMux + mode string // "unified" or "routed" +} + +// New creates a new Server +func New(ctx context.Context, l *launcher.Launcher, mode string) *Server { + s := &Server{ + launcher: l, + sysServer: sys.NewSysServer(l.ServerIDs()), + mux: http.NewServeMux(), + mode: mode, + } + + s.setupRoutes() + return s +} + +func (s *Server) setupRoutes() { + if s.mode == "routed" { + // Routed mode: /mcp/{server}/{method} + s.mux.HandleFunc("/mcp/", s.handleRoutedMCP) + } else { + // Unified mode: /mcp (single endpoint for all servers) + s.mux.HandleFunc("/mcp", s.handleUnifiedMCP) + } + + // Health check + s.mux.HandleFunc("/health", s.handleHealth) +} + +func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { + log.Printf("[%s] %s %s", r.RemoteAddr, r.Method, r.URL.Path) + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "OK\n") +} + +func (s *Server) handleUnifiedMCP(w http.ResponseWriter, r *http.Request) { + log.Printf("[%s] %s %s", r.RemoteAddr, r.Method, r.URL.Path) + if r.Method != http.MethodPost { + log.Printf("Method not allowed: %s", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Read request + var req mcp.Request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Printf("Failed to decode request body: %v", err) + s.sendError(w, -32700, "Parse error", nil) + return + } + log.Printf("Unified MCP request - method: %s", req.Method) + + // In unified mode, we need to determine which server to route to + // For now, default to the first configured server + // TODO: Implement proper routing logic based on tool name or other criteria + serverIDs := s.launcher.ServerIDs() + if len(serverIDs) == 0 { + s.sendError(w, -32603, "No MCP servers configured", nil) + return + } + + serverID := serverIDs[0] // Simple: use first server + s.proxyToServer(w, r, serverID, &req) +} + +func (s *Server) handleRoutedMCP(w http.ResponseWriter, r *http.Request) { + log.Printf("[%s] %s %s", r.RemoteAddr, r.Method, r.URL.Path) + if r.Method != http.MethodPost { + log.Printf("Method not allowed: %s", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse path: /mcp/{serverID} + path := strings.TrimPrefix(r.URL.Path, "/mcp/") + serverID := strings.Split(path, "/")[0] + + if serverID == "" { + log.Printf("No server ID in path") + http.Error(w, "Server ID required in path", http.StatusBadRequest) + return + } + + // Read request + var req mcp.Request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Printf("Failed to decode request body: %v", err) + s.sendError(w, -32700, "Parse error", nil) + return + } + log.Printf("Routed MCP request - server: %s, method: %s", serverID, req.Method) + + // Handle initialize requests directly (MCP handshake) + if req.Method == "initialize" { + s.handleInitialize(w, &req, serverID) + return + } + + // Handle notifications/initialized (sent after initialize response) + if req.Method == "notifications/initialized" { + log.Printf("Received initialized notification for server: %s", serverID) + // No response needed for notifications + w.WriteHeader(http.StatusOK) + return + } + + s.proxyToServer(w, r, serverID, &req) +} + +func (s *Server) handleInitialize(w http.ResponseWriter, req *mcp.Request, serverID string) { + log.Printf("Handling initialize request for server: %s", serverID) + + // Return a proper MCP initialize response + result := map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{ + "tools": map[string]interface{}{}, + "resources": map[string]interface{}{}, + "prompts": map[string]interface{}{}, + }, + "serverInfo": map[string]interface{}{ + "name": "flowguard-" + serverID, + "version": "1.0.0", + }, + } + + resultJSON, _ := json.Marshal(result) + resp := &mcp.Response{ + JSONRPC: "2.0", + ID: req.ID, + Result: resultJSON, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +func (s *Server) proxyToServer(w http.ResponseWriter, r *http.Request, serverID string, req *mcp.Request) { + // Handle built-in sys server + if serverID == "sys" { + s.handleSysRequest(w, req) + return + } + + // Get or launch connection + conn, err := launcher.GetOrLaunch(s.launcher, serverID) + if err != nil { + log.Printf("Failed to get connection to '%s': %v", serverID, err) + s.sendError(w, -32603, fmt.Sprintf("Failed to connect to server '%s'", serverID), nil) + return + } + + // Forward request based on method + var resp *mcp.Response + + switch req.Method { + case "tools/list": + resp, err = conn.SendRequest("tools/list", nil) + case "tools/call": + var params mcp.CallToolParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + s.sendError(w, -32602, "Invalid params", nil) + return + } + resp, err = conn.SendRequest("tools/call", params) + default: + // Forward as-is + var params interface{} + if len(req.Params) > 0 { + json.Unmarshal(req.Params, ¶ms) + } + resp, err = conn.SendRequest(req.Method, params) + } + + if err != nil { + log.Printf("Error proxying request to '%s': %v", serverID, err) + s.sendError(w, -32603, "Internal error", nil) + return + } + + // Send response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +func (s *Server) handleSysRequest(w http.ResponseWriter, req *mcp.Request) { + // Handle sys server requests locally + result, err := s.sysServer.HandleRequest(req.Method, req.Params) + if err != nil { + log.Printf("Sys server error: %v", err) + s.sendError(w, -32603, err.Error(), nil) + return + } + + // Marshal result + resultJSON, err := json.Marshal(result) + if err != nil { + s.sendError(w, -32603, "Failed to marshal result", nil) + return + } + + // Create response + resp := &mcp.Response{ + JSONRPC: "2.0", + ID: req.ID, + Result: resultJSON, + } + + // Send response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +func (s *Server) sendError(w http.ResponseWriter, code int, message string, data interface{}) { + resp := &mcp.Response{ + JSONRPC: "2.0", + Error: &mcp.ResponseError{ + Code: code, + Message: message, + }, + } + + if data != nil { + dataBytes, _ := json.Marshal(data) + resp.Error.Data = dataBytes + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) // MCP errors still return 200 + json.NewEncoder(w).Encode(resp) +} + +// ServeHTTP implements http.Handler +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mux.ServeHTTP(w, r) +} + +// ListenAndServe starts the HTTP server +func (s *Server) ListenAndServe(addr string) error { + log.Printf("Starting FlowGuard HTTP server on %s (mode: %s)", addr, s.mode) + return http.ListenAndServe(addr, s) +} diff --git a/internal/server/transport.go b/internal/server/transport.go new file mode 100644 index 00000000..dbcdb29f --- /dev/null +++ b/internal/server/transport.go @@ -0,0 +1,133 @@ +package server + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "net/http" + "strings" + + sdk "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// HTTPTransport wraps the SDK's HTTP transport +type HTTPTransport struct { + Addr string +} + +// Start implements sdk.Transport interface +func (t *HTTPTransport) Start(ctx context.Context) error { + // The SDK will handle the actual HTTP server setup + // We just need to provide the address + log.Printf("HTTP transport ready on %s", t.Addr) + return nil +} + +// Send implements sdk.Transport interface +func (t *HTTPTransport) Send(msg interface{}) error { + // Messages are sent via HTTP responses, handled by SDK + return nil +} + +// Recv implements sdk.Transport interface +func (t *HTTPTransport) Recv() (interface{}, error) { + // Messages are received via HTTP requests, handled by SDK + return nil, nil +} + +// Close implements sdk.Transport interface +func (t *HTTPTransport) Close() error { + return nil +} + +// loggingResponseWriter wraps http.ResponseWriter to capture response body +type loggingResponseWriter struct { + http.ResponseWriter + body []byte +} + +func (w *loggingResponseWriter) Write(b []byte) (int, error) { + w.body = append(w.body, b...) + return w.ResponseWriter.Write(b) +} + +// CreateHTTPServerForMCP creates an HTTP server that handles MCP over SSE +func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer) *http.Server { + mux := http.NewServeMux() + + // OAuth discovery endpoint - return 404 since we don't use OAuth + mux.HandleFunc("/mcp/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { + log.Printf("[%s] %s %s - OAuth discovery (not supported)", r.RemoteAddr, r.Method, r.URL.Path) + http.NotFound(w, r) + }) + + // Create StreamableHTTP handler for MCP protocol (supports POST requests) + // This is what Codex uses with transport = "streamablehttp" + streamableHandler := sdk.NewStreamableHTTPHandler(func(r *http.Request) *sdk.Server { + // With SSE, this callback fires ONCE per HTTP connection establishment + // All subsequent JSON-RPC messages come over the same persistent connection + // We use the Bearer token from Authorization header as the session ID + // This groups all routes from the same agent (same token) into one session + + // Extract Bearer token from Authorization header + authHeader := r.Header.Get("Authorization") + var sessionID string + + if strings.HasPrefix(authHeader, "Bearer ") { + sessionID = strings.TrimPrefix(authHeader, "Bearer ") + sessionID = strings.TrimSpace(sessionID) + } + + // Reject requests without valid Bearer token + if sessionID == "" { + log.Printf("[%s] %s %s - REJECTED: No Bearer token", r.RemoteAddr, r.Method, r.URL.Path) + // Return nil to reject the connection + // The SDK will handle sending an appropriate error response + return nil + } + + log.Printf("=== NEW SSE CONNECTION ===") + log.Printf("[%s] %s %s", r.RemoteAddr, r.Method, r.URL.Path) + log.Printf("Bearer Token (Session ID): %s", sessionID) + + log.Printf("DEBUG: About to check request body, Method=%s, Body!=nil: %v", r.Method, r.Body != nil) + + // Log request body for debugging (typically the 'initialize' request) + if r.Method == "POST" && r.Body != nil { + bodyBytes, err := io.ReadAll(r.Body) + if err == nil && len(bodyBytes) > 0 { + log.Printf("Request body: %s", string(bodyBytes)) + // Restore body + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + } + + // Store session ID in request context + // This context will be passed to all tool handlers for this connection + ctx := context.WithValue(r.Context(), SessionIDContextKey, sessionID) + *r = *r.WithContext(ctx) + log.Printf("✓ Injected session ID into context") + log.Printf("==========================\n") + + return unifiedServer.server + }, &sdk.StreamableHTTPOptions{ + Stateless: false, // Support stateful sessions + }) + + // Mount streamableHandler directly at /mcp endpoint (logging is done in the callback above) + mux.Handle("/mcp/", streamableHandler) + mux.Handle("/mcp", streamableHandler) + + // Health check + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "OK\n") + }) + + return &http.Server{ + Addr: addr, + Handler: mux, + } +} diff --git a/internal/server/unified.go b/internal/server/unified.go new file mode 100644 index 00000000..d8907721 --- /dev/null +++ b/internal/server/unified.go @@ -0,0 +1,419 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "log" + "sync" + + "github.com/githubnext/gh-aw-mcpg/internal/config" + "github.com/githubnext/gh-aw-mcpg/internal/launcher" + "github.com/githubnext/gh-aw-mcpg/internal/sys" + sdk "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Session represents a FlowGuard session +type Session struct { + Token string + SessionID string +} + +// ContextKey for session ID (exported so transport can use it) +type ContextKey string + +// SessionIDContextKey is used to store MCP session ID in context +const SessionIDContextKey ContextKey = "flowguard-session-id" + +// ToolInfo stores metadata about a registered tool +type ToolInfo struct { + Name string + Description string + InputSchema map[string]interface{} + BackendID string // Which backend this tool belongs to + Handler func(context.Context, *sdk.CallToolRequest, interface{}) (*sdk.CallToolResult, interface{}, error) +} + +// UnifiedServer implements a unified MCP server that aggregates multiple backend servers +type UnifiedServer struct { + launcher *launcher.Launcher + sysServer *sys.SysServer + ctx context.Context + server *sdk.Server + sessions map[string]*Session // mcp-session-id -> Session + sessionMu sync.RWMutex + tools map[string]*ToolInfo // prefixed tool name -> tool info + toolsMu sync.RWMutex +} + +// NewUnified creates a new unified MCP server +func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error) { + l := launcher.New(ctx, cfg) + + us := &UnifiedServer{ + launcher: l, + sysServer: sys.NewSysServer(l.ServerIDs()), + ctx: ctx, + sessions: make(map[string]*Session), + tools: make(map[string]*ToolInfo), + } + + // Create MCP server + server := sdk.NewServer(&sdk.Implementation{ + Name: "flowguard-unified", + Version: "1.0.0", + }, nil) + + us.server = server + + // Register aggregated tools from all backends + if err := us.registerAllTools(); err != nil { + return nil, fmt.Errorf("failed to register tools: %w", err) + } + + return us, nil +} + +// registerAllTools fetches and registers tools from all backend servers +func (us *UnifiedServer) registerAllTools() error { + log.Println("Registering tools from all backends...") + + // Register sys tools first + log.Println("Registering sys tools...") + if err := us.registerSysTools(); err != nil { + log.Printf("Warning: failed to register sys tools: %v", err) + } + + // Register tools from each backend server + for _, serverID := range us.launcher.ServerIDs() { + if err := us.registerToolsFromBackend(serverID); err != nil { + log.Printf("Warning: failed to register tools from %s: %v", serverID, err) + // Continue with other backends + } + } + + return nil +} + +// registerToolsFromBackend registers tools from a specific backend with ___ naming +func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { + log.Printf("Registering tools from backend: %s", serverID) + + // Get connection to backend + conn, err := launcher.GetOrLaunch(us.launcher, serverID) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + + // List tools from backend + result, err := conn.SendRequest("tools/list", nil) + if err != nil { + return fmt.Errorf("failed to list tools: %w", err) + } + + // Parse the result + var listResult struct { + Tools []struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]interface{} `json:"inputSchema"` + } `json:"tools"` + } + + if err := json.Unmarshal(result.Result, &listResult); err != nil { + return fmt.Errorf("failed to parse tools: %w", err) + } + + // Register each tool with prefixed name + toolNames := []string{} + for _, tool := range listResult.Tools { + prefixedName := fmt.Sprintf("%s___%s", serverID, tool.Name) + toolDesc := fmt.Sprintf("[%s] %s", serverID, tool.Description) + toolNames = append(toolNames, prefixedName) + + // Store tool info for routed mode + us.toolsMu.Lock() + us.tools[prefixedName] = &ToolInfo{ + Name: prefixedName, + Description: toolDesc, + InputSchema: tool.InputSchema, + BackendID: serverID, + } + us.toolsMu.Unlock() + + // Create a closure to capture serverID and toolName + serverIDCopy := serverID + toolNameCopy := tool.Name + + // Create the handler function + handler := func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) { + // Check session is initialized + if err := us.requireSession(ctx); err != nil { + return &sdk.CallToolResult{IsError: true}, nil, err + } + return us.callBackendTool(ctx, serverIDCopy, toolNameCopy, args) + } + + // Store handler for routed mode to reuse + us.toolsMu.Lock() + us.tools[prefixedName].Handler = handler + us.toolsMu.Unlock() + + // Register the tool with the SDK + sdk.AddTool(us.server, &sdk.Tool{ + Name: prefixedName, + Description: toolDesc, + InputSchema: tool.InputSchema, + }, handler) + + log.Printf("Registered tool: %s", prefixedName) + } + + log.Printf("Registered %d tools from %s: %v", len(listResult.Tools), serverID, toolNames) + return nil +} + +// registerSysTools registers built-in sys tools +func (us *UnifiedServer) registerSysTools() error { + // Create sys_init handler + sysInitHandler := func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) { + // Extract token from args + token := "" + if argsMap, ok := args.(map[string]interface{}); ok { + if t, ok := argsMap["token"].(string); ok { + token = t + } + } + + // TODO: Security check on token will be implemented later + + // Get session ID from context + sessionID := us.getSessionID(ctx) + if sessionID == "" { + return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("no session ID provided") + } + + // Create session + us.sessionMu.Lock() + us.sessions[sessionID] = &Session{ + Token: token, + SessionID: sessionID, + } + us.sessionMu.Unlock() + + log.Printf("Initialized session: %s", sessionID) + + // Call sys_init + params, _ := json.Marshal(map[string]interface{}{ + "name": "sys_init", + "arguments": map[string]interface{}{}, + }) + result, err := us.sysServer.HandleRequest("tools/call", json.RawMessage(params)) + if err != nil { + return &sdk.CallToolResult{IsError: true}, nil, err + } + return nil, result, nil + } + + // Store sys_init tool info + us.toolsMu.Lock() + us.tools["sys___init"] = &ToolInfo{ + Name: "sys___init", + Description: "Initialize the FlowGuard system and get available MCP servers", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "token": map[string]interface{}{ + "type": "string", + "description": "Authentication token for session initialization (can be empty for first call)", + }, + }, + }, + BackendID: "sys", + Handler: sysInitHandler, + } + us.toolsMu.Unlock() + + // Register with SDK + sdk.AddTool(us.server, &sdk.Tool{ + Name: "sys___init", + Description: "Initialize the FlowGuard system and get available MCP servers", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "token": map[string]interface{}{ + "type": "string", + "description": "Authentication token for session initialization (can be empty for first call)", + }, + }, + }, + }, sysInitHandler) + + // Create sys_list_servers handler + sysListHandler := func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) { + // Check session is initialized + if err := us.requireSession(ctx); err != nil { + return &sdk.CallToolResult{IsError: true}, nil, err + } + + params, _ := json.Marshal(map[string]interface{}{ + "name": "sys_list_servers", + "arguments": map[string]interface{}{}, + }) + result, err := us.sysServer.HandleRequest("tools/call", json.RawMessage(params)) + if err != nil { + return &sdk.CallToolResult{IsError: true}, nil, err + } + return nil, result, nil + } + + // Store sys_list_servers tool info + us.toolsMu.Lock() + us.tools["sys___list_servers"] = &ToolInfo{ + Name: "sys___list_servers", + Description: "List all configured MCP backend servers", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + BackendID: "sys", + Handler: sysListHandler, + } + us.toolsMu.Unlock() + + // Register with SDK + sdk.AddTool(us.server, &sdk.Tool{ + Name: "sys___list_servers", + Description: "List all configured MCP backend servers", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, sysListHandler) + + log.Println("Registered 2 sys tools") + return nil +} + +// callBackendTool calls a tool on a backend server +func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName string, args interface{}) (*sdk.CallToolResult, interface{}, error) { + // Note: Session validation happens at the tool registration level via closures + // The closure captures the request and validates before calling this method + log.Printf("Calling tool on %s: %s", serverID, toolName) + + // Get connection to backend + conn, err := launcher.GetOrLaunch(us.launcher, serverID) + if err != nil { + return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to connect: %w", err) + } + + // Call the tool + response, err := conn.SendRequest("tools/call", map[string]interface{}{ + "name": toolName, + "arguments": args, + }) + if err != nil { + return &sdk.CallToolResult{IsError: true}, nil, err + } + + // Parse the result + var result interface{} + if err := json.Unmarshal(response.Result, &result); err != nil { + return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to parse result: %w", err) + } + + return nil, result, nil +} + +// Run starts the unified MCP server on the specified transport +func (us *UnifiedServer) Run(transport sdk.Transport) error { + log.Println("Starting unified MCP server...") + return us.server.Run(us.ctx, transport) +} + +// getSessionID extracts the MCP session ID from the context +func (us *UnifiedServer) getSessionID(ctx context.Context) string { + if sessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && sessionID != "" { + log.Printf("Extracted session ID from context: %s", sessionID) + return sessionID + } + // No session ID in context - this happens before the SDK assigns one + // For now, use "default" as a placeholder for single-client scenarios + // In production multi-agent scenarios, the SDK will provide session IDs after initialize + log.Printf("No session ID in context, using 'default' (this is normal before SDK session is established)") + return "default" +} + +// requireSession checks that a session has been initialized for this request +func (us *UnifiedServer) requireSession(ctx context.Context) error { + sessionID := us.getSessionID(ctx) + log.Printf("Checking session for ID: %s", sessionID) + + us.sessionMu.RLock() + session := us.sessions[sessionID] + us.sessionMu.RUnlock() + + if session == nil { + log.Printf("Session not found for ID: %s. Available sessions: %v", sessionID, us.getSessionKeys()) + return fmt.Errorf("sys___init must be called before any other tool calls") + } + + log.Printf("Session validated for ID: %s", sessionID) + return nil +} + +// getSessionKeys returns a list of active session IDs for debugging +func (us *UnifiedServer) getSessionKeys() []string { + us.sessionMu.RLock() + defer us.sessionMu.RUnlock() + + keys := make([]string, 0, len(us.sessions)) + for k := range us.sessions { + keys = append(keys, k) + } + return keys +} + +// GetServerIDs returns the list of backend server IDs +func (us *UnifiedServer) GetServerIDs() []string { + return us.launcher.ServerIDs() +} + +// GetToolsForBackend returns tools for a specific backend with prefix stripped +func (us *UnifiedServer) GetToolsForBackend(backendID string) []ToolInfo { + us.toolsMu.RLock() + defer us.toolsMu.RUnlock() + + prefix := backendID + "___" + filtered := make([]ToolInfo, 0) + + for _, tool := range us.tools { + if tool.BackendID == backendID { + // Create a copy with the prefix stripped from the name + filteredTool := *tool + filteredTool.Name = tool.Name[len(prefix):] // Strip prefix + filtered = append(filtered, filteredTool) + } + } + + return filtered +} + +// GetToolHandler returns the handler for a specific backend tool +// This allows routed mode to reuse the unified server's tool handlers +func (us *UnifiedServer) GetToolHandler(backendID string, toolName string) func(context.Context, *sdk.CallToolRequest, interface{}) (*sdk.CallToolResult, interface{}, error) { + us.toolsMu.RLock() + defer us.toolsMu.RUnlock() + + prefixedName := backendID + "___" + toolName + if toolInfo, ok := us.tools[prefixedName]; ok { + return toolInfo.Handler + } + return nil +} + +// Close cleans up resources +func (us *UnifiedServer) Close() error { + us.launcher.Close() + return nil +} diff --git a/internal/server/unified_test.go b/internal/server/unified_test.go new file mode 100644 index 00000000..3719b4f0 --- /dev/null +++ b/internal/server/unified_test.go @@ -0,0 +1,242 @@ +package server + +import ( + "context" + "testing" + + "github.com/githubnext/gh-aw-mcpg/internal/config" +) + +func TestUnifiedServer_GetServerIDs(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{ + "github": {Command: "docker", Args: []string{}}, + "fetch": {Command: "docker", Args: []string{}}, + }, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + serverIDs := us.GetServerIDs() + if len(serverIDs) != 2 { + t.Errorf("Expected 2 server IDs, got %d", len(serverIDs)) + } + + expectedIDs := map[string]bool{"github": true, "fetch": true} + for _, id := range serverIDs { + if !expectedIDs[id] { + t.Errorf("Unexpected server ID: %s", id) + } + } +} + +func TestUnifiedServer_SessionManagement(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Test session creation + sessionID := "test-session-123" + token := "test-token" + + us.sessionMu.Lock() + us.sessions[sessionID] = &Session{ + Token: token, + SessionID: sessionID, + } + us.sessionMu.Unlock() + + // Test session retrieval + us.sessionMu.RLock() + session, exists := us.sessions[sessionID] + us.sessionMu.RUnlock() + + if !exists { + t.Error("Session not found after creation") + } + + if session.Token != token { + t.Errorf("Expected token '%s', got '%s'", token, session.Token) + } + + if session.SessionID != sessionID { + t.Errorf("Expected session ID '%s', got '%s'", sessionID, session.SessionID) + } +} + +func TestUnifiedServer_GetSessionKeys(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Add multiple sessions + sessions := []string{"session-1", "session-2", "session-3"} + for _, sid := range sessions { + us.sessionMu.Lock() + us.sessions[sid] = &Session{SessionID: sid, Token: "token"} + us.sessionMu.Unlock() + } + + keys := us.getSessionKeys() + if len(keys) != len(sessions) { + t.Errorf("Expected %d session keys, got %d", len(sessions), len(keys)) + } + + keyMap := make(map[string]bool) + for _, key := range keys { + keyMap[key] = true + } + + for _, expected := range sessions { + if !keyMap[expected] { + t.Errorf("Session key '%s' not found", expected) + } + } +} + +func TestUnifiedServer_GetToolsForBackend(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Manually add some tool info + us.toolsMu.Lock() + us.tools["github___issue_read"] = &ToolInfo{ + Name: "github___issue_read", + Description: "Read an issue", + BackendID: "github", + } + us.tools["github___repo_list"] = &ToolInfo{ + Name: "github___repo_list", + Description: "List repositories", + BackendID: "github", + } + us.tools["fetch___get"] = &ToolInfo{ + Name: "fetch___get", + Description: "Fetch a URL", + BackendID: "fetch", + } + us.toolsMu.Unlock() + + // Test filtering for github backend + githubTools := us.GetToolsForBackend("github") + if len(githubTools) != 2 { + t.Errorf("Expected 2 GitHub tools, got %d", len(githubTools)) + } + + for _, tool := range githubTools { + if tool.BackendID != "github" { + t.Errorf("Expected BackendID 'github', got '%s'", tool.BackendID) + } + // Check that prefix is stripped + if tool.Name == "github___issue_read" || tool.Name == "github___repo_list" { + t.Errorf("Tool name '%s' still has prefix", tool.Name) + } + if tool.Name != "issue_read" && tool.Name != "repo_list" { + t.Errorf("Unexpected tool name after prefix strip: '%s'", tool.Name) + } + } + + // Test filtering for fetch backend + fetchTools := us.GetToolsForBackend("fetch") + if len(fetchTools) != 1 { + t.Errorf("Expected 1 fetch tool, got %d", len(fetchTools)) + } + + if fetchTools[0].Name != "get" { + t.Errorf("Expected tool name 'get', got '%s'", fetchTools[0].Name) + } + + // Test filtering for non-existent backend + noTools := us.GetToolsForBackend("nonexistent") + if len(noTools) != 0 { + t.Errorf("Expected 0 tools for nonexistent backend, got %d", len(noTools)) + } +} + +func TestGetSessionID_FromContext(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Test with session ID in context + sessionID := "test-bearer-token-123" + ctxWithSession := context.WithValue(ctx, SessionIDContextKey, sessionID) + + extractedID := us.getSessionID(ctxWithSession) + if extractedID != sessionID { + t.Errorf("Expected session ID '%s', got '%s'", sessionID, extractedID) + } + + // Test without session ID in context + extractedID = us.getSessionID(ctx) + if extractedID != "default" { + t.Errorf("Expected default session ID, got '%s'", extractedID) + } +} + +func TestRequireSession(t *testing.T) { + cfg := &config.Config{ + Servers: map[string]*config.ServerConfig{}, + } + + ctx := context.Background() + us, err := NewUnified(ctx, cfg) + if err != nil { + t.Fatalf("NewUnified() failed: %v", err) + } + defer us.Close() + + // Create a session + sessionID := "valid-session" + us.sessionMu.Lock() + us.sessions[sessionID] = &Session{SessionID: sessionID, Token: "token"} + us.sessionMu.Unlock() + + // Test with valid session + ctxWithSession := context.WithValue(ctx, SessionIDContextKey, sessionID) + err = us.requireSession(ctxWithSession) + if err != nil { + t.Errorf("requireSession() failed for valid session: %v", err) + } + + // Test with invalid session + ctxWithInvalidSession := context.WithValue(ctx, SessionIDContextKey, "invalid-session") + err = us.requireSession(ctxWithInvalidSession) + if err == nil { + t.Error("requireSession() should fail for invalid session") + } +} diff --git a/internal/sys/sys.go b/internal/sys/sys.go new file mode 100644 index 00000000..b9315974 --- /dev/null +++ b/internal/sys/sys.go @@ -0,0 +1,98 @@ +package sys + +import ( + "encoding/json" + "fmt" +) + +// SysServer implements the FlowGuard system tools +type SysServer struct { + serverIDs []string +} + +// NewSysServer creates a new system server +func NewSysServer(serverIDs []string) *SysServer { + return &SysServer{ + serverIDs: serverIDs, + } +} + +// HandleRequest processes MCP requests for system tools +func (s *SysServer) HandleRequest(method string, params json.RawMessage) (interface{}, error) { + switch method { + case "tools/list": + return s.listTools() + case "tools/call": + var callParams struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` + } + if err := json.Unmarshal(params, &callParams); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + return s.callTool(callParams.Name, callParams.Arguments) + default: + return nil, fmt.Errorf("unsupported method: %s", method) + } +} + +func (s *SysServer) listTools() (interface{}, error) { + return map[string]interface{}{ + "tools": []map[string]interface{}{ + { + "name": "sys_init", + "description": "Initialize the FlowGuard system and get available MCP servers", + "inputSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + { + "name": "sys_list_servers", + "description": "List all configured MCP backend servers", + "inputSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + }, + }, nil +} + +func (s *SysServer) callTool(name string, args map[string]interface{}) (interface{}, error) { + switch name { + case "sys_init": + return s.sysInit() + case "sys_list_servers": + return s.listServers() + default: + return nil, fmt.Errorf("unknown tool: %s", name) + } +} + +func (s *SysServer) sysInit() (interface{}, error) { + return map[string]interface{}{ + "content": []map[string]interface{}{ + { + "type": "text", + "text": fmt.Sprintf("FlowGuard initialized. Available servers: %v", s.serverIDs), + }, + }, + }, nil +} + +func (s *SysServer) listServers() (interface{}, error) { + serverList := "" + for i, id := range s.serverIDs { + serverList += fmt.Sprintf("%d. %s\n", i+1, id) + } + + return map[string]interface{}{ + "content": []map[string]interface{}{ + { + "type": "text", + "text": fmt.Sprintf("Configured MCP Servers:\n%s", serverList), + }, + }, + }, nil +} diff --git a/main.go b/main.go new file mode 100644 index 00000000..7bb6f6a3 --- /dev/null +++ b/main.go @@ -0,0 +1,7 @@ +package main + +import "github.com/githubnext/gh-aw-mcpg/internal/cmd" + +func main() { + cmd.Execute() +} diff --git a/run.sh b/run.sh new file mode 100755 index 00000000..448a453b --- /dev/null +++ b/run.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# Set DOCKER_API_VERSION based on architecture +ARCH=$(uname -m) +if [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then + export DOCKER_API_VERSION=1.43 +else + export DOCKER_API_VERSION=1.44 +fi + +# Default values +HOST="${HOST:- 0.0.0.0}" +PORT="${PORT:-8000}" +CONFIG="${CONFIG}" +ENV_FILE="${ENV_FILE:-.env}" +MODE="${MODE:---routed}" + +# Build the command +CMD="./flowguard-go" +FLAGS="$MODE --listen ${HOST}:${PORT}" + +if [ -n "$ENV_FILE" ]; then + FLAGS="$FLAGS --env $ENV_FILE" +fi + +if [ -n "$CONFIG" ]; then + FLAGS="$FLAGS --config $CONFIG" +else + FLAGS="$FLAGS --config-stdin" + CONFIG_JSON=$(cat <