diff --git a/pkg/workflow/compiler_orchestrator.go b/pkg/workflow/compiler_orchestrator.go index 81b7c85185..c3db18943d 100644 --- a/pkg/workflow/compiler_orchestrator.go +++ b/pkg/workflow/compiler_orchestrator.go @@ -373,11 +373,16 @@ func (c *Compiler) ParseWorkflowFile(markdownPath string) (*WorkflowData, error) return nil, fmt.Errorf("failed to merge tools: %w", err) } - // Extract timeout setting from merged tools (defaults to 0 which means use engine defaults) - toolsTimeout := c.extractToolsTimeout(tools) + // Extract and validate tools timeout settings + toolsTimeout, err := c.extractToolsTimeout(tools) + if err != nil { + return nil, fmt.Errorf("invalid tools timeout configuration: %w", err) + } - // Extract startup-timeout setting from merged tools (defaults to 0 which means use engine defaults) - toolsStartupTimeout := c.extractToolsStartupTimeout(tools) + toolsStartupTimeout, err := c.extractToolsStartupTimeout(tools) + if err != nil { + return nil, fmt.Errorf("invalid tools startup timeout configuration: %w", err) + } // Remove meta fields (timeout, startup-timeout) from merged tools map // These are configuration fields, not actual tools diff --git a/pkg/workflow/frontmatter_extraction_metadata.go b/pkg/workflow/frontmatter_extraction_metadata.go index 617741019f..68befa9598 100644 --- a/pkg/workflow/frontmatter_extraction_metadata.go +++ b/pkg/workflow/frontmatter_extraction_metadata.go @@ -144,58 +144,80 @@ func safeUint64ToInt(u uint64) int { // extractToolsTimeout extracts the timeout setting from tools // Returns 0 if not set (engines will use their own defaults) -func (c *Compiler) extractToolsTimeout(tools map[string]any) int { +// Returns error if timeout is explicitly set but invalid (< 1) +func (c *Compiler) extractToolsTimeout(tools map[string]any) (int, error) { if tools == nil { - return 0 // Use engine defaults + return 0, nil // Use engine defaults } // Check if timeout is explicitly set in tools if timeoutValue, exists := tools["timeout"]; exists { + var timeout int // Handle different numeric types with safe conversions to prevent overflow switch v := timeoutValue.(type) { case int: - return v + timeout = v case int64: - return int(v) + timeout = int(v) case uint: - return safeUintToInt(v) // Safe conversion to prevent overflow (alert #418) + timeout = safeUintToInt(v) // Safe conversion to prevent overflow (alert #418) case uint64: - return safeUint64ToInt(v) // Safe conversion to prevent overflow (alert #416) + timeout = safeUint64ToInt(v) // Safe conversion to prevent overflow (alert #416) case float64: - return int(v) + timeout = int(v) + default: + return 0, fmt.Errorf("tools.timeout must be an integer, got %T", timeoutValue) } + + // Validate minimum value per schema constraint + if timeout < 1 { + return 0, fmt.Errorf("tools.timeout must be at least 1 second, got %d. Example:\ntools:\n timeout: 60", timeout) + } + + return timeout, nil } // Default to 0 (use engine defaults) - return 0 + return 0, nil } // extractToolsStartupTimeout extracts the startup-timeout setting from tools // Returns 0 if not set (engines will use their own defaults) -func (c *Compiler) extractToolsStartupTimeout(tools map[string]any) int { +// Returns error if startup-timeout is explicitly set but invalid (< 1) +func (c *Compiler) extractToolsStartupTimeout(tools map[string]any) (int, error) { if tools == nil { - return 0 // Use engine defaults + return 0, nil // Use engine defaults } // Check if startup-timeout is explicitly set in tools if timeoutValue, exists := tools["startup-timeout"]; exists { + var timeout int // Handle different numeric types with safe conversions to prevent overflow switch v := timeoutValue.(type) { case int: - return v + timeout = v case int64: - return int(v) + timeout = int(v) case uint: - return safeUintToInt(v) // Safe conversion to prevent overflow (alert #417) + timeout = safeUintToInt(v) // Safe conversion to prevent overflow (alert #417) case uint64: - return safeUint64ToInt(v) // Safe conversion to prevent overflow (alert #415) + timeout = safeUint64ToInt(v) // Safe conversion to prevent overflow (alert #415) case float64: - return int(v) + timeout = int(v) + default: + return 0, fmt.Errorf("tools.startup-timeout must be an integer, got %T", timeoutValue) } + + // Validate minimum value per schema constraint + if timeout < 1 { + return 0, fmt.Errorf("tools.startup-timeout must be at least 1 second, got %d. Example:\ntools:\n startup-timeout: 120", timeout) + } + + return timeout, nil } // Default to 0 (use engine defaults) - return 0 + return 0, nil } // extractMapFromFrontmatter is a generic helper to extract a map[string]any from frontmatter diff --git a/pkg/workflow/frontmatter_extraction_metadata_test.go b/pkg/workflow/frontmatter_extraction_metadata_test.go index ca9a471fe6..07b0989e99 100644 --- a/pkg/workflow/frontmatter_extraction_metadata_test.go +++ b/pkg/workflow/frontmatter_extraction_metadata_test.go @@ -99,9 +99,10 @@ func TestExtractToolsStartupTimeout(t *testing.T) { compiler := &Compiler{} tests := []struct { - name string - tools map[string]any - expected int + name string + tools map[string]any + expected int + shouldError bool }{ { name: "int timeout", @@ -156,33 +157,61 @@ func TestExtractToolsStartupTimeout(t *testing.T) { expected: 0, }, { - name: "invalid type (string)", + name: "invalid type (string) - should error", tools: map[string]any{ "startup-timeout": "not a number", }, - expected: 0, + expected: 0, + shouldError: true, }, { - name: "invalid type (array)", + name: "invalid type (array) - should error", tools: map[string]any{ "startup-timeout": []int{1, 2, 3}, }, - expected: 0, + expected: 0, + shouldError: true, }, { - name: "zero timeout", + name: "zero timeout - should fail validation", tools: map[string]any{ "startup-timeout": 0, }, - expected: 0, + expected: 0, + shouldError: true, + }, + { + name: "negative timeout - should fail validation", + tools: map[string]any{ + "startup-timeout": -5, + }, + expected: 0, + shouldError: true, + }, + { + name: "minimum valid timeout (1)", + tools: map[string]any{ + "startup-timeout": 1, + }, + expected: 1, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := compiler.extractToolsStartupTimeout(tt.tools) - if result != tt.expected { - t.Errorf("extractToolsStartupTimeout() = %d, want %d", result, tt.expected) + result, err := compiler.extractToolsStartupTimeout(tt.tools) + + if tt.shouldError { + if err == nil { + t.Errorf("Expected error but got none") + } + } else { + if err != nil { + t.Errorf("Expected no error but got: %v", err) + } + if result != tt.expected { + t.Errorf("extractToolsStartupTimeout() = %d, want %d", result, tt.expected) + } } }) } diff --git a/pkg/workflow/tools_timeout_test.go b/pkg/workflow/tools_timeout_test.go index 7965ac33ed..b7a49988e0 100644 --- a/pkg/workflow/tools_timeout_test.go +++ b/pkg/workflow/tools_timeout_test.go @@ -170,11 +170,13 @@ func TestExtractToolsTimeout(t *testing.T) { name string tools map[string]any expectedTimeout int + shouldError bool }{ { name: "no timeout specified", tools: map[string]any{}, expectedTimeout: 0, + shouldError: false, }, { name: "timeout as int", @@ -182,6 +184,7 @@ func TestExtractToolsTimeout(t *testing.T) { "timeout": 45, }, expectedTimeout: 45, + shouldError: false, }, { name: "timeout as int64", @@ -189,6 +192,7 @@ func TestExtractToolsTimeout(t *testing.T) { "timeout": int64(90), }, expectedTimeout: 90, + shouldError: false, }, { name: "timeout as uint", @@ -196,6 +200,7 @@ func TestExtractToolsTimeout(t *testing.T) { "timeout": uint(75), }, expectedTimeout: 75, + shouldError: false, }, { name: "timeout as uint64", @@ -203,6 +208,7 @@ func TestExtractToolsTimeout(t *testing.T) { "timeout": uint64(120), }, expectedTimeout: 120, + shouldError: false, }, { name: "timeout as float64", @@ -210,19 +216,55 @@ func TestExtractToolsTimeout(t *testing.T) { "timeout": 60.0, }, expectedTimeout: 60, + shouldError: false, }, { name: "nil tools", tools: nil, expectedTimeout: 0, + shouldError: false, + }, + { + name: "zero timeout - should fail", + tools: map[string]any{ + "timeout": 0, + }, + expectedTimeout: 0, + shouldError: true, + }, + { + name: "negative timeout - should fail", + tools: map[string]any{ + "timeout": -5, + }, + expectedTimeout: 0, + shouldError: true, + }, + { + name: "minimum valid timeout (1)", + tools: map[string]any{ + "timeout": 1, + }, + expectedTimeout: 1, + shouldError: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - timeout := compiler.extractToolsTimeout(tt.tools) - if timeout != tt.expectedTimeout { - t.Errorf("Expected timeout %d, got %d", tt.expectedTimeout, timeout) + timeout, err := compiler.extractToolsTimeout(tt.tools) + + if tt.shouldError { + if err == nil { + t.Errorf("Expected error but got none") + } + } else { + if err != nil { + t.Errorf("Expected no error but got: %v", err) + } + if timeout != tt.expectedTimeout { + t.Errorf("Expected timeout %d, got %d", tt.expectedTimeout, timeout) + } } }) } diff --git a/pkg/workflow/tools_timeout_validation_test.go b/pkg/workflow/tools_timeout_validation_test.go new file mode 100644 index 0000000000..43ac54da4e --- /dev/null +++ b/pkg/workflow/tools_timeout_validation_test.go @@ -0,0 +1,219 @@ +package workflow + +import ( + "os" + "strings" + "testing" + + "github.com/githubnext/gh-aw/pkg/stringutil" +) + +func TestToolsTimeoutValidation(t *testing.T) { + tests := []struct { + name string + workflowMd string + shouldCompile bool + errorContains string + }{ + { + name: "valid timeout - positive value", + workflowMd: `--- +on: workflow_dispatch +engine: claude +tools: + timeout: 60 + github: +--- + +# Test Timeout + +Test workflow. +`, + shouldCompile: true, + }, + { + name: "valid timeout - minimum value (1)", + workflowMd: `--- +on: workflow_dispatch +engine: claude +tools: + timeout: 1 + github: +--- + +# Test Timeout + +Test workflow. +`, + shouldCompile: true, + }, + { + name: "invalid timeout - zero", + workflowMd: `--- +on: workflow_dispatch +engine: claude +tools: + timeout: 0 + github: +--- + +# Test Timeout + +Test workflow. +`, + shouldCompile: false, + errorContains: "minimum: got 0, want 1", + }, + { + name: "invalid timeout - negative", + workflowMd: `--- +on: workflow_dispatch +engine: claude +tools: + timeout: -5 + github: +--- + +# Test Timeout + +Test workflow. +`, + shouldCompile: false, + errorContains: "minimum: got -5, want 1", + }, + { + name: "valid startup-timeout - positive value", + workflowMd: `--- +on: workflow_dispatch +engine: claude +tools: + startup-timeout: 120 + github: +--- + +# Test Startup Timeout + +Test workflow. +`, + shouldCompile: true, + }, + { + name: "valid startup-timeout - minimum value (1)", + workflowMd: `--- +on: workflow_dispatch +engine: claude +tools: + startup-timeout: 1 + github: +--- + +# Test Startup Timeout + +Test workflow. +`, + shouldCompile: true, + }, + { + name: "invalid startup-timeout - zero", + workflowMd: `--- +on: workflow_dispatch +engine: claude +tools: + startup-timeout: 0 + github: +--- + +# Test Startup Timeout + +Test workflow. +`, + shouldCompile: false, + errorContains: "minimum: got 0, want 1", + }, + { + name: "invalid startup-timeout - negative", + workflowMd: `--- +on: workflow_dispatch +engine: claude +tools: + startup-timeout: -10 + github: +--- + +# Test Startup Timeout + +Test workflow. +`, + shouldCompile: false, + errorContains: "minimum: got -10, want 1", + }, + { + name: "both timeouts valid", + workflowMd: `--- +on: workflow_dispatch +engine: claude +tools: + timeout: 60 + startup-timeout: 120 + github: +--- + +# Test Both Timeouts + +Test workflow. +`, + shouldCompile: true, + }, + { + name: "both timeouts invalid", + workflowMd: `--- +on: workflow_dispatch +engine: claude +tools: + timeout: 0 + startup-timeout: -5 + github: +--- + +# Test Both Timeouts Invalid + +Test workflow. +`, + shouldCompile: false, + errorContains: "minimum:", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Write to temporary file + tmpFile, err := os.CreateTemp("", "test-timeout-validation-*.md") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + defer os.Remove(stringutil.MarkdownToLockFile(tmpFile.Name())) + + if _, err := tmpFile.WriteString(tt.workflowMd); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + tmpFile.Close() + + // Compile the workflow + compiler := NewCompiler(false, "", "") + err = compiler.CompileWorkflow(tmpFile.Name()) + + if tt.shouldCompile { + if err != nil { + t.Errorf("Expected workflow to compile successfully, but got error: %v", err) + } + } else { + if err == nil { + t.Errorf("Expected workflow compilation to fail, but it succeeded") + } else if !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("Expected error to contain '%s', but got: %v", tt.errorContains, err) + } + } + }) + } +}