Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 153 additions & 72 deletions internal/config/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package config

import (
"os"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -53,33 +52,54 @@ func TestExpandVariables(t *testing.T) {
envVars: map[string]string{"DEFINED": "value"},
shouldErr: true,
},
{
name: "nested variables in path",
input: "/path/${VAR1}/subdir/${VAR2}",
envVars: map[string]string{"VAR1": "foo", "VAR2": "bar"},
expected: "/path/foo/subdir/bar",
},
{
name: "empty variable value",
input: "prefix-${EMPTY_VAR}-suffix",
envVars: map[string]string{"EMPTY_VAR": ""},
expected: "prefix--suffix",
},
{
name: "variable at start",
input: "${VAR}/path/to/file",
envVars: map[string]string{"VAR": "/root"},
expected: "/root/path/to/file",
},
{
name: "variable at end",
input: "/path/to/${VAR}",
envVars: map[string]string{"VAR": "file.txt"},
expected: "/path/to/file.txt",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up environment
for k, v := range tt.envVars {
os.Setenv(k, v)
defer os.Unsetenv(k)
t.Setenv(k, v)
}

result, err := expandVariables(tt.input, "test.path")

if tt.shouldErr {
assert.Error(t, err)
require.Error(t, err)
} else {
assert.NoError(t, err, "Unexpected error")
assert.Equal(t, tt.expected, result, "%q, got %q")
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}

func TestExpandEnvVariables(t *testing.T) {
os.Setenv("GITHUB_TOKEN", "ghp_test123")
os.Setenv("API_KEY", "secret")
defer os.Unsetenv("GITHUB_TOKEN")
defer os.Unsetenv("API_KEY")
t.Setenv("GITHUB_TOKEN", "ghp_test123")
t.Setenv("API_KEY", "secret")

tests := []struct {
name string
Expand Down Expand Up @@ -130,25 +150,36 @@ func TestExpandEnvVariables(t *testing.T) {
serverName: "test",
shouldErr: true,
},
{
name: "empty env map",
input: map[string]string{},
serverName: "test",
expected: map[string]string{},
},
{
name: "no variables to expand",
input: map[string]string{
"STATIC1": "value1",
"STATIC2": "value2",
},
serverName: "test",
expected: map[string]string{
"STATIC1": "value1",
"STATIC2": "value2",
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := expandEnvVariables(tt.input, tt.serverName)

if tt.shouldErr {
assert.Error(t, err)
// Check error message contains server name
if !strings.Contains(err.Error(), tt.serverName) {
t.Errorf("Error should mention server name %q", tt.serverName)
}
require.Error(t, err)
assert.ErrorContains(t, err, tt.serverName, "Error should mention server name")
} else {
assert.NoError(t, err, "Unexpected error")
for k, v := range tt.expected {
if result[k] != v {
t.Errorf("For key %q: expected %q, got %q", k, v, result[k])
}
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
Expand Down Expand Up @@ -342,12 +373,12 @@ func TestValidateStdioServer(t *testing.T) {
err := validateServerConfig("test-server", tt.server)

if tt.shouldErr {
assert.Error(t, err)
if tt.errorMsg != "" && err != nil && !strings.Contains(err.Error(), tt.errorMsg) {
t.Errorf("Expected error containing %q, got: %v", tt.errorMsg, err)
require.Error(t, err)
if tt.errorMsg != "" {
assert.ErrorContains(t, err, tt.errorMsg)
}
} else {
assert.NoError(t, err, "Unexpected error")
assert.NoError(t, err)
}
})
}
Expand Down Expand Up @@ -430,20 +461,43 @@ func TestValidateGatewayConfig(t *testing.T) {
err := validateGatewayConfig(tt.gateway)

if tt.shouldErr {
assert.Error(t, err)
if tt.errorMsg != "" && err != nil && !strings.Contains(err.Error(), tt.errorMsg) {
t.Errorf("Expected error containing %q, got: %v", tt.errorMsg, err)
require.Error(t, err)
if tt.errorMsg != "" {
assert.ErrorContains(t, err, tt.errorMsg)
}
} else {
assert.NoError(t, err, "Unexpected error")
assert.NoError(t, err)
}
})
}
}

// setupStdinTest is a helper that sets up stdin with the given JSON config
// Returns a cleanup function that should be deferred
func setupStdinTest(t *testing.T, jsonConfig string) func() {
t.Helper()
r, w, err := os.Pipe()
require.NoError(t, err, "Failed to create pipe")

oldStdin := os.Stdin
os.Stdin = r

go func() {
defer w.Close()
_, err := w.Write([]byte(jsonConfig))
if err != nil {
t.Logf("Failed to write to pipe: %v", err)
}
}()

return func() {
os.Stdin = oldStdin
r.Close()
}
}

func TestLoadFromStdin_WithVariableExpansion(t *testing.T) {
os.Setenv("GITHUB_TOKEN", "ghp_expanded")
defer os.Unsetenv("GITHUB_TOKEN")
t.Setenv("GITHUB_TOKEN", "ghp_expanded")

jsonConfig := `{
"mcpServers": {
Expand All @@ -463,24 +517,14 @@ func TestLoadFromStdin_WithVariableExpansion(t *testing.T) {
}
}`

r, w, _ := os.Pipe()
oldStdin := os.Stdin
os.Stdin = r
go func() {
w.Write([]byte(jsonConfig))
w.Close()
}()
cleanup := setupStdinTest(t, jsonConfig)
defer cleanup()

cfg, err := LoadFromStdin()
os.Stdin = oldStdin

require.NoError(t, err, "LoadFromStdin() failed")
require.NoError(t, err)

server := cfg.Servers["github"]
// Check docker command is set up correctly
if server.Command != "docker" {
t.Errorf("Expected Command to be 'docker', got %q", server.Command)
}
assert.Equal(t, "docker", server.Command, "Expected Command to be 'docker'")
}

func TestLoadFromStdin_UndefinedVariable(t *testing.T) {
Expand All @@ -501,25 +545,45 @@ func TestLoadFromStdin_UndefinedVariable(t *testing.T) {
}
}`

r, w, _ := os.Pipe()
oldStdin := os.Stdin
os.Stdin = r
go func() {
w.Write([]byte(jsonConfig))
w.Close()
}()
cleanup := setupStdinTest(t, jsonConfig)
defer cleanup()

_, err := LoadFromStdin()
os.Stdin = oldStdin

require.Error(t, err)
assert.ErrorContains(t, err, "UNDEFINED_GITHUB_TOKEN", "Error should mention the undefined variable")
assert.ErrorContains(t, err, "mcpServers.github.env", "Error should include JSON path")
}

if !strings.Contains(err.Error(), "UNDEFINED_GITHUB_TOKEN") {
t.Errorf("Error should mention the undefined variable, got: %v", err)
}
if !strings.Contains(err.Error(), "mcpServers.github.env") {
t.Errorf("Error should include JSON path, got: %v", err)
}
func TestLoadFromStdin_VariableExpansionInContainer(t *testing.T) {
t.Setenv("REGISTRY", "ghcr.io")
t.Setenv("IMAGE_NAME", "github/github-mcp-server")

jsonConfig := `{
"mcpServers": {
"github": {
"type": "stdio",
"container": "${REGISTRY}/${IMAGE_NAME}:latest",
"env": {
"TOKEN": "static-value"
}
}
},
"gateway": {
"port": 8080,
"domain": "localhost",
"apiKey": "test-key"
}
}`

cleanup := setupStdinTest(t, jsonConfig)
defer cleanup()

cfg, err := LoadFromStdin()
require.NoError(t, err)

server := cfg.Servers["github"]
// Container field should have variables expanded in docker args
assert.Contains(t, server.Args, "ghcr.io/github/github-mcp-server:latest")
}

func TestLoadFromStdin_ValidationErrors(t *testing.T) {
Expand Down Expand Up @@ -583,28 +647,45 @@ func TestLoadFromStdin_ValidationErrors(t *testing.T) {
shouldErr: true,
errorMsg: "validation error",
},
{
name: "malformed JSON",
config: `{
"mcpServers": {
"test": {
"type": "stdio",
"container": "test:latest"
}
// missing closing brace`,
shouldErr: true,
},
{
name: "empty mcpServers",
config: `{
"mcpServers": {},
"gateway": {
"port": 8080,
"domain": "localhost",
"apiKey": "test-key"
}
}`,
shouldErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, w, _ := os.Pipe()
oldStdin := os.Stdin
os.Stdin = r
go func() {
w.Write([]byte(tt.config))
w.Close()
}()
cleanup := setupStdinTest(t, tt.config)
defer cleanup()

_, err := LoadFromStdin()
os.Stdin = oldStdin

if tt.shouldErr {
assert.Error(t, err)
if err != nil && !strings.Contains(err.Error(), tt.errorMsg) {
t.Errorf("Expected error containing %q, got: %v", tt.errorMsg, err)
require.Error(t, err)
if tt.errorMsg != "" {
assert.ErrorContains(t, err, tt.errorMsg)
}
} else {
assert.NoError(t, err, "Unexpected error")
assert.NoError(t, err)
}
})
}
Expand Down