diff --git a/.changeset/patch-refactor-duplicate-engine-code.md b/.changeset/patch-refactor-duplicate-engine-code.md new file mode 100644 index 0000000000..ed496c21c3 --- /dev/null +++ b/.changeset/patch-refactor-duplicate-engine-code.md @@ -0,0 +1,5 @@ +--- +"gh-aw": patch +--- + +Extract duplicate custom engine step handling into shared helper functions diff --git a/pkg/workflow/claude_engine.go b/pkg/workflow/claude_engine.go index 316342c89c..54e2805020 100644 --- a/pkg/workflow/claude_engine.go +++ b/pkg/workflow/claude_engine.go @@ -86,19 +86,8 @@ func (e *ClaudeEngine) GetVersionCommand() string { // GetExecutionSteps returns the GitHub Actions steps for executing Claude func (e *ClaudeEngine) GetExecutionSteps(workflowData *WorkflowData, logFile string) []GitHubActionStep { - var steps []GitHubActionStep - // Handle custom steps if they exist in engine config - if workflowData.EngineConfig != nil && len(workflowData.EngineConfig.Steps) > 0 { - for _, step := range workflowData.EngineConfig.Steps { - stepYAML, err := e.convertStepToYAML(step) - if err != nil { - // Log error but continue with other steps - continue - } - steps = append(steps, GitHubActionStep{stepYAML}) - } - } + steps := InjectCustomEngineSteps(workflowData, e.convertStepToYAML) // Build claude CLI arguments based on configuration var claudeArgs []string @@ -616,14 +605,8 @@ func (e *ClaudeEngine) RenderMCPConfig(yaml *strings.Builder, tools map[string]a case "web-fetch": renderMCPFetchServerConfig(yaml, "json", " ", isLast, false) default: - // Handle custom MCP tools (those with MCP-compatible type) - if toolConfig, ok := tools[toolName].(map[string]any); ok { - if hasMcp, _ := hasMCPConfig(toolConfig); hasMcp { - if err := e.renderClaudeMCPConfig(yaml, toolName, toolConfig, isLast); err != nil { - fmt.Printf("Error generating custom MCP configuration for %s: %v\n", toolName, err) - } - } - } + // Handle custom MCP tools using shared helper + HandleCustomMCPToolInSwitch(yaml, toolName, tools, isLast, e.renderClaudeMCPConfig) } } diff --git a/pkg/workflow/codex_engine.go b/pkg/workflow/codex_engine.go index 19e3ba7017..2ec7f2bb78 100644 --- a/pkg/workflow/codex_engine.go +++ b/pkg/workflow/codex_engine.go @@ -95,19 +95,8 @@ func (e *CodexEngine) GetDeclaredOutputFiles() []string { // GetExecutionSteps returns the GitHub Actions steps for executing Codex func (e *CodexEngine) GetExecutionSteps(workflowData *WorkflowData, logFile string) []GitHubActionStep { - var steps []GitHubActionStep - // Handle custom steps if they exist in engine config - if workflowData.EngineConfig != nil && len(workflowData.EngineConfig.Steps) > 0 { - for _, step := range workflowData.EngineConfig.Steps { - stepYAML, err := e.convertStepToYAML(step) - if err != nil { - // Log error but continue with other steps - continue - } - steps = append(steps, GitHubActionStep{stepYAML}) - } - } + steps := InjectCustomEngineSteps(workflowData, e.convertStepToYAML) // Build model parameter only if specified in engineConfig var modelParam string @@ -253,14 +242,10 @@ func (e *CodexEngine) RenderMCPConfig(yaml *strings.Builder, tools map[string]an case "web-fetch": renderMCPFetchServerConfig(yaml, "toml", " ", false, false) default: - // Handle custom MCP tools (those with MCP-compatible type) - if toolConfig, ok := expandedTools[toolName].(map[string]any); ok { - if hasMcp, _ := hasMCPConfig(toolConfig); hasMcp { - if err := e.renderCodexMCPConfig(yaml, toolName, toolConfig); err != nil { - fmt.Printf("Error generating custom MCP configuration for %s: %v\n", toolName, err) - } - } - } + // Handle custom MCP tools using shared helper (with adapter for isLast parameter) + HandleCustomMCPToolInSwitch(yaml, toolName, expandedTools, false, func(yaml *strings.Builder, toolName string, toolConfig map[string]any, isLast bool) error { + return e.renderCodexMCPConfig(yaml, toolName, toolConfig) + }) } } diff --git a/pkg/workflow/copilot_engine.go b/pkg/workflow/copilot_engine.go index 8a5ef263b6..f1a7d22d82 100644 --- a/pkg/workflow/copilot_engine.go +++ b/pkg/workflow/copilot_engine.go @@ -65,19 +65,8 @@ func (e *CopilotEngine) GetVersionCommand() string { // GetExecutionSteps returns the GitHub Actions steps for executing GitHub Copilot CLI func (e *CopilotEngine) GetExecutionSteps(workflowData *WorkflowData, logFile string) []GitHubActionStep { - var steps []GitHubActionStep - // Handle custom steps if they exist in engine config - if workflowData.EngineConfig != nil && len(workflowData.EngineConfig.Steps) > 0 { - for _, step := range workflowData.EngineConfig.Steps { - stepYAML, err := e.convertStepToYAML(step) - if err != nil { - // Log error but continue with other steps - continue - } - steps = append(steps, GitHubActionStep{stepYAML}) - } - } + steps := InjectCustomEngineSteps(workflowData, e.convertStepToYAML) // Build copilot CLI arguments based on configuration var copilotArgs = []string{"--add-dir", "/tmp/gh-aw/", "--log-level", "all", "--log-dir", logsFolder} @@ -252,14 +241,8 @@ func (e *CopilotEngine) RenderMCPConfig(yaml *strings.Builder, tools map[string] case "web-fetch": renderMCPFetchServerConfig(yaml, "json", " ", isLast, true) default: - // Handle custom MCP tools (those with MCP-compatible type) - if toolConfig, ok := tools[toolName].(map[string]any); ok { - if hasMcp, _ := hasMCPConfig(toolConfig); hasMcp { - if err := e.renderCopilotMCPConfig(yaml, toolName, toolConfig, isLast); err != nil { - fmt.Printf("Error generating custom MCP configuration for %s: %v\n", toolName, err) - } - } - } + // Handle custom MCP tools using shared helper + HandleCustomMCPToolInSwitch(yaml, toolName, tools, isLast, e.renderCopilotMCPConfig) } } diff --git a/pkg/workflow/custom_engine.go b/pkg/workflow/custom_engine.go index d714065c9f..c81654c60a 100644 --- a/pkg/workflow/custom_engine.go +++ b/pkg/workflow/custom_engine.go @@ -172,14 +172,8 @@ func (e *CustomEngine) RenderMCPConfig(yaml *strings.Builder, tools map[string]a case "web-fetch": renderMCPFetchServerConfig(yaml, "json", " ", isLast, false) default: - // Handle custom MCP tools (those with MCP-compatible type) - if toolConfig, ok := tools[toolName].(map[string]any); ok { - if hasMcp, _ := hasMCPConfig(toolConfig); hasMcp { - if err := e.renderCustomMCPConfig(yaml, toolName, toolConfig, isLast); err != nil { - fmt.Printf("Error generating custom MCP configuration for %s: %v\n", toolName, err) - } - } - } + // Handle custom MCP tools using shared helper + HandleCustomMCPToolInSwitch(yaml, toolName, tools, isLast, e.renderCustomMCPConfig) } } diff --git a/pkg/workflow/engine_shared_helpers.go b/pkg/workflow/engine_shared_helpers.go new file mode 100644 index 0000000000..9f415528b9 --- /dev/null +++ b/pkg/workflow/engine_shared_helpers.go @@ -0,0 +1,70 @@ +package workflow + +import ( + "fmt" + "strings" +) + +// InjectCustomEngineSteps processes custom steps from engine config and converts them to GitHubActionSteps. +// This shared function extracts the common pattern used by Copilot, Codex, and Claude engines. +// +// Parameters: +// - workflowData: The workflow data containing engine configuration +// - convertStepFunc: A function that converts a step map to YAML string (engine-specific) +// +// Returns: +// - []GitHubActionStep: Array of custom steps ready to be included in the execution pipeline +func InjectCustomEngineSteps( + workflowData *WorkflowData, + convertStepFunc func(map[string]any) (string, error), +) []GitHubActionStep { + var steps []GitHubActionStep + + // Handle custom steps if they exist in engine config + if workflowData.EngineConfig != nil && len(workflowData.EngineConfig.Steps) > 0 { + for _, step := range workflowData.EngineConfig.Steps { + stepYAML, err := convertStepFunc(step) + if err != nil { + // Log error but continue with other steps + continue + } + steps = append(steps, GitHubActionStep{stepYAML}) + } + } + + return steps +} + +// RenderCustomMCPToolConfigHandler is a function type that engines must provide to render their specific MCP config +type RenderCustomMCPToolConfigHandler func(yaml *strings.Builder, toolName string, toolConfig map[string]any, isLast bool) error + +// HandleCustomMCPToolInSwitch processes custom MCP tools in the default case of a switch statement. +// This shared function extracts the common pattern used across all workflow engines. +// +// Parameters: +// - yaml: The string builder for YAML output +// - toolName: The name of the tool being processed +// - tools: The tools map containing tool configurations (supports both expanded and non-expanded tools) +// - isLast: Whether this is the last tool in the list +// - renderFunc: Engine-specific function to render the MCP configuration +// +// Returns: +// - bool: true if a custom MCP tool was handled, false otherwise +func HandleCustomMCPToolInSwitch( + yaml *strings.Builder, + toolName string, + tools map[string]any, + isLast bool, + renderFunc RenderCustomMCPToolConfigHandler, +) bool { + // Handle custom MCP tools (those with MCP-compatible type) + if toolConfig, ok := tools[toolName].(map[string]any); ok { + if hasMcp, _ := hasMCPConfig(toolConfig); hasMcp { + if err := renderFunc(yaml, toolName, toolConfig, isLast); err != nil { + fmt.Printf("Error generating custom MCP configuration for %s: %v\n", toolName, err) + } + return true + } + } + return false +} diff --git a/pkg/workflow/engine_shared_helpers_test.go b/pkg/workflow/engine_shared_helpers_test.go new file mode 100644 index 0000000000..7d8cbc8fa0 --- /dev/null +++ b/pkg/workflow/engine_shared_helpers_test.go @@ -0,0 +1,283 @@ +package workflow + +import ( + "fmt" + "strings" + "testing" +) + +// TestInjectCustomEngineSteps verifies that custom steps are properly injected +func TestInjectCustomEngineSteps(t *testing.T) { + tests := []struct { + name string + workflowData *WorkflowData + expectedSteps int + expectedErr bool + convertErrStep int // Which step should fail conversion (0 = none) + }{ + { + name: "No custom steps", + workflowData: &WorkflowData{ + EngineConfig: nil, + }, + expectedSteps: 0, + }, + { + name: "Empty custom steps", + workflowData: &WorkflowData{ + EngineConfig: &EngineConfig{ + Steps: []map[string]any{}, + }, + }, + expectedSteps: 0, + }, + { + name: "Single custom step", + workflowData: &WorkflowData{ + EngineConfig: &EngineConfig{ + Steps: []map[string]any{ + { + "name": "Test Step", + "run": "echo 'test'", + }, + }, + }, + }, + expectedSteps: 1, + }, + { + name: "Multiple custom steps", + workflowData: &WorkflowData{ + EngineConfig: &EngineConfig{ + Steps: []map[string]any{ + { + "name": "Step 1", + "run": "echo 'step1'", + }, + { + "name": "Step 2", + "run": "echo 'step2'", + }, + { + "name": "Step 3", + "run": "echo 'step3'", + }, + }, + }, + }, + expectedSteps: 3, + }, + { + name: "Step conversion error - should continue", + workflowData: &WorkflowData{ + EngineConfig: &EngineConfig{ + Steps: []map[string]any{ + { + "name": "Step 1", + "run": "echo 'step1'", + }, + { + "name": "Step 2 - will fail", + "run": "echo 'step2'", + }, + { + "name": "Step 3", + "run": "echo 'step3'", + }, + }, + }, + }, + expectedSteps: 2, // Only 2 steps should succeed + convertErrStep: 2, // Second step fails + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a mock convert function + stepCounter := 0 + convertStepFunc := func(stepMap map[string]any) (string, error) { + stepCounter++ + // Simulate conversion error for specific step + if tt.convertErrStep > 0 && stepCounter == tt.convertErrStep { + return "", fmt.Errorf("conversion error for step %d", stepCounter) + } + // Return a simple YAML representation + name := stepMap["name"] + return fmt.Sprintf(" - name: %v\n run: test\n", name), nil + } + + steps := InjectCustomEngineSteps(tt.workflowData, convertStepFunc) + + if len(steps) != tt.expectedSteps { + t.Errorf("Expected %d steps, got %d", tt.expectedSteps, len(steps)) + } + + // Verify each step contains valid YAML + for i, step := range steps { + if len(step) == 0 { + t.Errorf("Step %d is empty", i) + } + } + }) + } +} + +// TestHandleCustomMCPToolInSwitch verifies custom MCP tool handling in switch statements +func TestHandleCustomMCPToolInSwitch(t *testing.T) { + tests := []struct { + name string + toolName string + tools map[string]any + isLast bool + shouldHandle bool + renderCalled bool + simulateError bool + }{ + { + name: "Valid custom MCP tool", + toolName: "custom-tool", + tools: map[string]any{ + "custom-tool": map[string]any{ + "type": "stdio", + "command": "node", + "args": []string{"server.js"}, + }, + }, + isLast: false, + shouldHandle: true, + renderCalled: true, + }, + { + name: "Valid custom MCP tool - last in list", + toolName: "custom-tool", + tools: map[string]any{ + "custom-tool": map[string]any{ + "type": "http", + "url": "https://example.com", + "headers": map[string]string{"key": "value"}, + }, + }, + isLast: true, + shouldHandle: true, + renderCalled: true, + }, + { + name: "Tool config is not a map", + toolName: "invalid-tool", + tools: map[string]any{ + "invalid-tool": "just a string", + }, + isLast: false, + shouldHandle: false, + renderCalled: false, + }, + { + name: "Tool has no MCP config", + toolName: "non-mcp-tool", + tools: map[string]any{ + "non-mcp-tool": map[string]any{ + "some-key": "some-value", + }, + }, + isLast: false, + shouldHandle: false, + renderCalled: false, + }, + { + name: "Render function returns error", + toolName: "error-tool", + tools: map[string]any{ + "error-tool": map[string]any{ + "type": "stdio", + "command": "node", + }, + }, + isLast: false, + shouldHandle: true, + renderCalled: true, + simulateError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var yaml strings.Builder + renderCalled := false + + // Create a mock render function + renderFunc := func(yaml *strings.Builder, toolName string, toolConfig map[string]any, isLast bool) error { + renderCalled = true + if tt.simulateError { + return fmt.Errorf("simulated render error") + } + // Write some output to verify it was called + yaml.WriteString(fmt.Sprintf("rendered: %s, isLast: %v\n", toolName, isLast)) + return nil + } + + handled := HandleCustomMCPToolInSwitch(&yaml, tt.toolName, tt.tools, tt.isLast, renderFunc) + + if handled != tt.shouldHandle { + t.Errorf("Expected handled=%v, got %v", tt.shouldHandle, handled) + } + + if renderCalled != tt.renderCalled { + t.Errorf("Expected renderCalled=%v, got %v", tt.renderCalled, renderCalled) + } + + // If render was called and no error, verify output + if tt.renderCalled && !tt.simulateError { + output := yaml.String() + if !strings.Contains(output, tt.toolName) { + t.Errorf("Expected output to contain tool name %q, got: %q", tt.toolName, output) + } + if !strings.Contains(output, fmt.Sprintf("isLast: %v", tt.isLast)) { + t.Errorf("Expected output to contain isLast=%v, got: %q", tt.isLast, output) + } + } + }) + } +} + +// TestInjectCustomEngineStepsWithRealConversion tests with actual ConvertStepToYAML function +func TestInjectCustomEngineStepsWithRealConversion(t *testing.T) { + workflowData := &WorkflowData{ + EngineConfig: &EngineConfig{ + Steps: []map[string]any{ + { + "name": "Install dependencies", + "run": "npm install", + }, + { + "name": "Run tests", + "run": "npm test", + }, + }, + }, + } + + steps := InjectCustomEngineSteps(workflowData, ConvertStepToYAML) + + if len(steps) != 2 { + t.Fatalf("Expected 2 steps, got %d", len(steps)) + } + + // Verify the YAML content of the first step + firstStepYAML := steps[0][0] + if !strings.Contains(firstStepYAML, "Install dependencies") { + t.Errorf("First step should contain 'Install dependencies', got: %s", firstStepYAML) + } + if !strings.Contains(firstStepYAML, "npm install") { + t.Errorf("First step should contain 'npm install', got: %s", firstStepYAML) + } + + // Verify the YAML content of the second step + secondStepYAML := steps[1][0] + if !strings.Contains(secondStepYAML, "Run tests") { + t.Errorf("Second step should contain 'Run tests', got: %s", secondStepYAML) + } + if !strings.Contains(secondStepYAML, "npm test") { + t.Errorf("Second step should contain 'npm test', got: %s", secondStepYAML) + } +}