diff --git a/.changeset/patch-aggregate-validation-errors.md b/.changeset/patch-aggregate-validation-errors.md new file mode 100644 index 0000000000..c33a6a0625 --- /dev/null +++ b/.changeset/patch-aggregate-validation-errors.md @@ -0,0 +1,5 @@ +--- +"gh-aw": patch +--- + +Aggregate validation errors so compilation reports all issues together and add the `--fail-fast` flag to preserve the legacy behavior when needed. diff --git a/cmd/gh-aw/main.go b/cmd/gh-aw/main.go index 434755dcad..d49d9058c9 100644 --- a/cmd/gh-aw/main.go +++ b/cmd/gh-aw/main.go @@ -221,6 +221,7 @@ Examples: jsonOutput, _ := cmd.Flags().GetBool("json") fix, _ := cmd.Flags().GetBool("fix") stats, _ := cmd.Flags().GetBool("stats") + failFast, _ := cmd.Flags().GetBool("fail-fast") noCheckUpdate, _ := cmd.Flags().GetBool("no-check-update") verbose, _ := cmd.Flags().GetBool("verbose") if err := validateEngine(engineOverride); err != nil { @@ -272,6 +273,7 @@ Examples: Actionlint: actionlint, JSONOutput: jsonOutput, Stats: stats, + FailFast: failFast, } if _, err := cli.CompileWorkflows(cmd.Context(), config); err != nil { errMsg := err.Error() @@ -502,6 +504,7 @@ Use "` + string(constants.CLIExtensionPrefix) + ` help all" to show help for all compileCmd.Flags().Bool("fix", false, "Apply automatic codemod fixes to workflows before compiling") compileCmd.Flags().BoolP("json", "j", false, "Output results in JSON format") compileCmd.Flags().Bool("stats", false, "Display statistics table sorted by file size (shows jobs, steps, scripts, and shells)") + compileCmd.Flags().Bool("fail-fast", false, "Stop at the first validation error instead of collecting all errors") compileCmd.Flags().Bool("no-check-update", false, "Skip checking for gh-aw updates") compileCmd.MarkFlagsMutuallyExclusive("dir", "workflows-dir") diff --git a/pkg/cli/compile_compiler_setup.go b/pkg/cli/compile_compiler_setup.go index cad36e297e..839060bbf7 100644 --- a/pkg/cli/compile_compiler_setup.go +++ b/pkg/cli/compile_compiler_setup.go @@ -93,6 +93,7 @@ func createAndConfigureCompiler(config CompileConfig) *workflow.Compiler { compiler := workflow.NewCompiler( workflow.WithVerbose(config.Verbose), workflow.WithEngineOverride(config.EngineOverride), + workflow.WithFailFast(config.FailFast), ) compileCompilerSetupLog.Print("Created compiler instance") diff --git a/pkg/cli/compile_config.go b/pkg/cli/compile_config.go index 0d1befe850..e689f4d589 100644 --- a/pkg/cli/compile_config.go +++ b/pkg/cli/compile_config.go @@ -32,6 +32,7 @@ type CompileConfig struct { ActionMode string // Action script inlining mode: inline, dev, or release ActionTag string // Override action SHA or tag for actions/setup (overrides action-mode to release) Stats bool // Display statistics table sorted by file size + FailFast bool // Stop at first error instead of collecting all errors } // WorkflowFailure represents a failed workflow with its error count diff --git a/pkg/workflow/compiler.go b/pkg/workflow/compiler.go index 0fbc9ad6c9..eef82a083f 100644 --- a/pkg/workflow/compiler.go +++ b/pkg/workflow/compiler.go @@ -185,13 +185,13 @@ func (c *Compiler) CompileWorkflowData(workflowData *WorkflowData, markdownPath // Validate safe-outputs allowed-domains configuration log.Printf("Validating safe-outputs allowed-domains") - if err := validateSafeOutputsAllowedDomains(workflowData.SafeOutputs); err != nil { + if err := c.validateSafeOutputsAllowedDomains(workflowData.SafeOutputs); err != nil { return formatCompilerError(markdownPath, "error", err.Error()) } // Validate network allowed domains configuration log.Printf("Validating network allowed domains") - if err := validateNetworkAllowedDomains(workflowData.NetworkPermissions); err != nil { + if err := c.validateNetworkAllowedDomains(workflowData.NetworkPermissions); err != nil { return formatCompilerError(markdownPath, "error", err.Error()) } diff --git a/pkg/workflow/compiler_types.go b/pkg/workflow/compiler_types.go index 1c10078fab..1be99b0d57 100644 --- a/pkg/workflow/compiler_types.go +++ b/pkg/workflow/compiler_types.go @@ -52,6 +52,11 @@ func WithStrictMode(strict bool) CompilerOption { return func(c *Compiler) { c.strictMode = strict } } +// WithFailFast configures whether to stop at first validation error +func WithFailFast(failFast bool) CompilerOption { + return func(c *Compiler) { c.failFast = failFast } +} + // WithForceRefreshActionPins configures whether to force refresh of action pins func WithForceRefreshActionPins(force bool) CompilerOption { return func(c *Compiler) { c.forceRefreshActionPins = force } @@ -101,6 +106,7 @@ type Compiler struct { trialLogicalRepoSlug string // If set in trial mode, the logical repository to checkout refreshStopTime bool // If true, regenerate stop-after times instead of preserving existing ones forceRefreshActionPins bool // If true, clear action cache and resolve all actions from GitHub API + failFast bool // If true, stop at first validation error instead of collecting all errors actionCacheCleared bool // Tracks if action cache has already been cleared (for forceRefreshActionPins) markdownPath string // Path to the markdown file being compiled (for context in dynamic tool generation) actionMode ActionMode // Mode for generating JavaScript steps (inline vs custom actions) diff --git a/pkg/workflow/error_aggregation.go b/pkg/workflow/error_aggregation.go new file mode 100644 index 0000000000..56e77965b9 --- /dev/null +++ b/pkg/workflow/error_aggregation.go @@ -0,0 +1,195 @@ +// This file provides error aggregation utilities for validation. +// +// # Error Aggregation +// +// This file implements error collection and aggregation for validation +// functions, allowing users to see all validation errors in a single run +// instead of discovering them one at a time. +// +// # Error Aggregation Functions +// +// - NewErrorCollector() - Creates a new error collector +// - ErrorCollector.Add() - Adds an error to the collection +// - ErrorCollector.HasErrors() - Checks if any errors were collected +// - ErrorCollector.Error() - Returns aggregated error using errors.Join +// - ErrorCollector.Count() - Returns the number of collected errors +// +// # Usage Pattern +// +// Use error collectors in validation functions to collect multiple errors: +// +// func validateMultipleThings(config Config, failFast bool) error { +// collector := NewErrorCollector(failFast) +// +// if err := validateThing1(config); err != nil { +// if returnErr := collector.Add(err); returnErr != nil { +// return returnErr // Fail-fast mode +// } +// } +// +// if err := validateThing2(config); err != nil { +// if returnErr := collector.Add(err); returnErr != nil { +// return returnErr // Fail-fast mode +// } +// } +// +// return collector.Error() +// } +// +// # Fail-Fast Mode +// +// When failFast is true, the collector returns immediately on the first error. +// When false, it collects all errors and returns them joined with errors.Join. + +package workflow + +import ( + "errors" + "fmt" + "strings" + + "github.com/githubnext/gh-aw/pkg/logger" +) + +var errorAggregationLog = logger.New("workflow:error_aggregation") + +// ErrorCollector collects multiple validation errors +type ErrorCollector struct { + errors []error + failFast bool +} + +// NewErrorCollector creates a new error collector +// If failFast is true, the collector will stop at the first error +func NewErrorCollector(failFast bool) *ErrorCollector { + errorAggregationLog.Printf("Creating error collector: fail_fast=%v", failFast) + return &ErrorCollector{ + errors: make([]error, 0), + failFast: failFast, + } +} + +// Add adds an error to the collector +// If failFast is enabled, returns the error immediately +// Otherwise, adds it to the collection and returns nil +func (c *ErrorCollector) Add(err error) error { + if err == nil { + return nil + } + + errorAggregationLog.Printf("Adding error to collector: %v", err) + + if c.failFast { + errorAggregationLog.Print("Fail-fast enabled, returning error immediately") + return err + } + + c.errors = append(c.errors, err) + return nil +} + +// HasErrors returns true if any errors have been collected +func (c *ErrorCollector) HasErrors() bool { + return len(c.errors) > 0 +} + +// Count returns the number of errors collected +func (c *ErrorCollector) Count() int { + return len(c.errors) +} + +// Error returns the aggregated error using errors.Join +// Returns nil if no errors were collected +func (c *ErrorCollector) Error() error { + if len(c.errors) == 0 { + return nil + } + + errorAggregationLog.Printf("Aggregating %d errors", len(c.errors)) + + if len(c.errors) == 1 { + return c.errors[0] + } + + return errors.Join(c.errors...) +} + +// FormattedError returns the aggregated error with a formatted header showing the count +// Returns nil if no errors were collected +// This method is preferred over Error() + FormatAggregatedError for better accuracy +func (c *ErrorCollector) FormattedError(category string) error { + if len(c.errors) == 0 { + return nil + } + + errorAggregationLog.Printf("Formatting %d errors for category: %s", len(c.errors), category) + + if len(c.errors) == 1 { + return c.errors[0] + } + + // Build formatted error with count header + var sb strings.Builder + fmt.Fprintf(&sb, "Found %d %s errors:", len(c.errors), category) + for _, err := range c.errors { + sb.WriteString("\n • ") + sb.WriteString(err.Error()) + } + + return fmt.Errorf("%s", sb.String()) +} + +// FormatAggregatedError formats aggregated errors with a summary header +// Returns a formatted error with count and categorization if multiple errors exist +func FormatAggregatedError(err error, category string) error { + if err == nil { + return nil + } + + // Check if this is a joined error by looking for newlines + errStr := err.Error() + lines := strings.Split(errStr, "\n") + + if len(lines) <= 1 { + return err + } + + // Format with count and category + header := fmt.Sprintf("Found %d %s errors:", len(lines), category) + + // Reconstruct with header + var sb strings.Builder + sb.WriteString(header) + for _, line := range lines { + if line != "" { + sb.WriteString("\n • ") + sb.WriteString(line) + } + } + + return fmt.Errorf("%s", sb.String()) +} + +// SplitJoinedErrors splits a joined error into individual error strings +func SplitJoinedErrors(err error) []error { + if err == nil { + return nil + } + + // errors.Join formats errors separated by newlines + errStr := err.Error() + lines := strings.Split(errStr, "\n") + + result := make([]error, 0, len(lines)) + for _, line := range lines { + if line != "" { + result = append(result, fmt.Errorf("%s", line)) + } + } + + if len(result) == 0 { + return []error{err} + } + + return result +} diff --git a/pkg/workflow/error_aggregation_test.go b/pkg/workflow/error_aggregation_test.go new file mode 100644 index 0000000000..9d895e2c63 --- /dev/null +++ b/pkg/workflow/error_aggregation_test.go @@ -0,0 +1,340 @@ +//go:build !integration + +package workflow + +import ( + "errors" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewErrorCollector(t *testing.T) { + tests := []struct { + name string + failFast bool + }{ + { + name: "fail-fast enabled", + failFast: true, + }, + { + name: "fail-fast disabled", + failFast: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + collector := NewErrorCollector(tt.failFast) + require.NotNil(t, collector, "Collector should be created") + assert.Equal(t, tt.failFast, collector.failFast, "Fail-fast setting should match") + assert.False(t, collector.HasErrors(), "New collector should have no errors") + assert.Equal(t, 0, collector.Count(), "New collector should have zero count") + }) + } +} + +func TestErrorCollectorAdd_FailFast(t *testing.T) { + collector := NewErrorCollector(true) + err1 := fmt.Errorf("first error") + err2 := fmt.Errorf("second error") + + // First error should be returned immediately + result := collector.Add(err1) + require.Error(t, result, "Should return error immediately in fail-fast mode") + assert.Equal(t, err1, result, "Should return the exact error") + assert.False(t, collector.HasErrors(), "Should not collect errors in fail-fast mode") + + // Second error should also be returned immediately + result = collector.Add(err2) + require.Error(t, result, "Should return error immediately in fail-fast mode") + assert.Equal(t, err2, result, "Should return the exact error") +} + +func TestErrorCollectorAdd_Aggregate(t *testing.T) { + collector := NewErrorCollector(false) + err1 := fmt.Errorf("first error") + err2 := fmt.Errorf("second error") + err3 := fmt.Errorf("third error") + + // Add errors should not return them + result := collector.Add(err1) + require.NoError(t, result, "Should not return error in aggregate mode") + assert.True(t, collector.HasErrors(), "Should have errors") + assert.Equal(t, 1, collector.Count(), "Should have 1 error") + + result = collector.Add(err2) + require.NoError(t, result, "Should not return error in aggregate mode") + assert.Equal(t, 2, collector.Count(), "Should have 2 errors") + + result = collector.Add(err3) + require.NoError(t, result, "Should not return error in aggregate mode") + assert.Equal(t, 3, collector.Count(), "Should have 3 errors") +} + +func TestErrorCollectorAdd_NilError(t *testing.T) { + collector := NewErrorCollector(false) + + result := collector.Add(nil) + require.NoError(t, result, "Should handle nil error") + assert.False(t, collector.HasErrors(), "Should not have errors") + assert.Equal(t, 0, collector.Count(), "Should have zero count") +} + +func TestErrorCollectorError_NoErrors(t *testing.T) { + collector := NewErrorCollector(false) + + err := collector.Error() + assert.NoError(t, err, "Should return nil when no errors collected") +} + +func TestErrorCollectorError_SingleError(t *testing.T) { + collector := NewErrorCollector(false) + err1 := fmt.Errorf("single error") + + _ = collector.Add(err1) + result := collector.Error() + + require.Error(t, result, "Should return error") + assert.Equal(t, err1, result, "Should return the single error as-is") +} + +func TestErrorCollectorError_MultipleErrors(t *testing.T) { + collector := NewErrorCollector(false) + err1 := fmt.Errorf("first error") + err2 := fmt.Errorf("second error") + err3 := fmt.Errorf("third error") + + _ = collector.Add(err1) + _ = collector.Add(err2) + _ = collector.Add(err3) + + result := collector.Error() + require.Error(t, result, "Should return aggregated error") + + // Check that all errors are included + errStr := result.Error() + assert.Contains(t, errStr, "first error", "Should contain first error") + assert.Contains(t, errStr, "second error", "Should contain second error") + assert.Contains(t, errStr, "third error", "Should contain third error") +} + +func TestFormatAggregatedError_NoError(t *testing.T) { + result := FormatAggregatedError(nil, "validation") + require.NoError(t, result, "Should handle nil error") +} + +func TestFormatAggregatedError_SingleError(t *testing.T) { + err := fmt.Errorf("single error") + result := FormatAggregatedError(err, "validation") + + require.Error(t, result, "Should return error") + assert.Equal(t, err, result, "Should return single error unchanged") +} + +func TestFormatAggregatedError_MultipleErrors(t *testing.T) { + err1 := fmt.Errorf("first error") + err2 := fmt.Errorf("second error") + err3 := fmt.Errorf("third error") + + joined := errors.Join(err1, err2, err3) + result := FormatAggregatedError(joined, "validation") + + require.Error(t, result, "Should return formatted error") + errStr := result.Error() + + // Should contain header with count + assert.True(t, strings.Contains(errStr, "Found") && strings.Contains(errStr, "validation errors:"), + "Should contain header with count and category") + + // Should contain all individual errors + assert.Contains(t, errStr, "first error", "Should contain first error") + assert.Contains(t, errStr, "second error", "Should contain second error") + assert.Contains(t, errStr, "third error", "Should contain third error") +} + +func TestSplitJoinedErrors_NoError(t *testing.T) { + result := SplitJoinedErrors(nil) + assert.Nil(t, result, "Should return nil for nil error") +} + +func TestSplitJoinedErrors_SingleError(t *testing.T) { + err := fmt.Errorf("single error") + result := SplitJoinedErrors(err) + + require.Len(t, result, 1, "Should have 1 error") + assert.Equal(t, err, result[0], "Should contain the single error") +} + +func TestSplitJoinedErrors_MultipleErrors(t *testing.T) { + err1 := fmt.Errorf("first error") + err2 := fmt.Errorf("second error") + err3 := fmt.Errorf("third error") + + joined := errors.Join(err1, err2, err3) + result := SplitJoinedErrors(joined) + + require.NotNil(t, result, "Should return error slice") + assert.Greater(t, len(result), 1, "Should have multiple errors") + + // Check that all errors are present in the result + errStr := joined.Error() + assert.Contains(t, errStr, "first error", "Should contain first error") + assert.Contains(t, errStr, "second error", "Should contain second error") + assert.Contains(t, errStr, "third error", "Should contain third error") +} + +// TestErrorCollectorIntegration tests the full flow of error collection +func TestErrorCollectorIntegration(t *testing.T) { + tests := []struct { + name string + failFast bool + errors []error + expectError bool + expectCount int + shouldContain []string + }{ + { + name: "no errors collected", + failFast: false, + errors: []error{}, + expectError: false, + expectCount: 0, + }, + { + name: "single error aggregated", + failFast: false, + errors: []error{fmt.Errorf("error 1")}, + expectError: true, + expectCount: 1, + shouldContain: []string{"error 1"}, + }, + { + name: "multiple errors aggregated", + failFast: false, + errors: []error{fmt.Errorf("error 1"), fmt.Errorf("error 2"), fmt.Errorf("error 3")}, + expectError: true, + expectCount: 3, + shouldContain: []string{"error 1", "error 2", "error 3"}, + }, + { + name: "fail-fast stops at first error", + failFast: true, + errors: []error{fmt.Errorf("error 1"), fmt.Errorf("error 2")}, + expectError: true, + expectCount: 0, // No errors collected in fail-fast mode + shouldContain: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + collector := NewErrorCollector(tt.failFast) + + // Add all errors + for _, err := range tt.errors { + result := collector.Add(err) + if tt.failFast && err != nil { + // In fail-fast mode, Add should return error immediately + assert.Error(t, result, "Should return error in fail-fast mode") + return // Stop test here for fail-fast mode + } + } + + // Get the aggregated error + err := collector.Error() + + if tt.expectError { + require.Error(t, err, "Should have aggregated error") + errStr := err.Error() + + for _, expected := range tt.shouldContain { + assert.Contains(t, errStr, expected, "Should contain expected error message") + } + } else { + require.NoError(t, err, "Should not have error") + } + + assert.Equal(t, tt.expectCount, collector.Count(), "Error count should match") + }) + } +} + +// TestErrorCollectorFormattedError tests the FormattedError method +func TestErrorCollectorFormattedError(t *testing.T) { + tests := []struct { + name string + errors []error + category string + expectError bool + shouldContain []string + }{ + { + name: "no errors", + errors: []error{}, + category: "validation", + expectError: false, + }, + { + name: "single error (no formatting)", + errors: []error{fmt.Errorf("single error")}, + category: "validation", + expectError: true, + shouldContain: []string{"single error"}, + }, + { + name: "multiple errors with formatted header", + errors: []error{fmt.Errorf("error 1"), fmt.Errorf("error 2"), fmt.Errorf("error 3")}, + category: "validation", + expectError: true, + shouldContain: []string{ + "Found 3 validation errors:", + "error 1", + "error 2", + "error 3", + }, + }, + { + name: "errors with newlines preserved", + errors: []error{fmt.Errorf("error with\nmultiple\nlines"), fmt.Errorf("simple error")}, + category: "test", + expectError: true, + shouldContain: []string{ + "Found 2 test errors:", + "error with", + "multiple", + "lines", + "simple error", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + collector := NewErrorCollector(false) + + // Add all errors + for _, err := range tt.errors { + _ = collector.Add(err) + } + + // Get the formatted error + err := collector.FormattedError(tt.category) + + if tt.expectError { + require.Error(t, err, "Should have formatted error") + errStr := err.Error() + + for _, expected := range tt.shouldContain { + assert.Contains(t, errStr, expected, "Should contain expected text") + } + } else { + require.NoError(t, err, "Should not have error") + } + }) + } +} diff --git a/pkg/workflow/safe_outputs_domains_protocol_validation_test.go b/pkg/workflow/safe_outputs_domains_protocol_validation_test.go index a18b6bf0ec..ad6d71d38a 100644 --- a/pkg/workflow/safe_outputs_domains_protocol_validation_test.go +++ b/pkg/workflow/safe_outputs_domains_protocol_validation_test.go @@ -163,7 +163,8 @@ func TestValidateSafeOutputsAllowedDomainsWithProtocol(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateSafeOutputsAllowedDomains(tt.config) + c := NewCompiler() + err := c.validateSafeOutputsAllowedDomains(tt.config) if (err != nil) != tt.wantErr { t.Errorf("validateSafeOutputsAllowedDomains() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/pkg/workflow/safe_outputs_domains_validation.go b/pkg/workflow/safe_outputs_domains_validation.go index 2638d9a4fe..dfd0844990 100644 --- a/pkg/workflow/safe_outputs_domains_validation.go +++ b/pkg/workflow/safe_outputs_domains_validation.go @@ -11,13 +11,15 @@ import ( var safeOutputsDomainsValidationLog = logger.New("workflow:safe_outputs_domains_validation") // validateNetworkAllowedDomains validates the allowed domains in network configuration -func validateNetworkAllowedDomains(network *NetworkPermissions) error { +func (c *Compiler) validateNetworkAllowedDomains(network *NetworkPermissions) error { if network == nil || len(network.Allowed) == 0 { return nil } safeOutputsDomainsValidationLog.Printf("Validating %d network allowed domains", len(network.Allowed)) + collector := NewErrorCollector(c.failFast) + for i, domain := range network.Allowed { // Skip ecosystem identifiers - they don't need domain pattern validation if isEcosystemIdentifier(domain) { @@ -25,11 +27,14 @@ func validateNetworkAllowedDomains(network *NetworkPermissions) error { } if err := validateDomainPattern(domain); err != nil { - return fmt.Errorf("network.allowed[%d]: %w", i, err) + wrappedErr := fmt.Errorf("network.allowed[%d]: %w", i, err) + if returnErr := collector.Add(wrappedErr); returnErr != nil { + return returnErr // Fail-fast mode + } } } - return nil + return collector.Error() } // isEcosystemIdentifier checks if a domain string is actually an ecosystem identifier @@ -50,20 +55,25 @@ func isEcosystemIdentifier(domain string) bool { var domainPattern = regexp.MustCompile(`^(\*\.)?[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`) // validateSafeOutputsAllowedDomains validates the allowed-domains configuration in safe-outputs -func validateSafeOutputsAllowedDomains(config *SafeOutputsConfig) error { +func (c *Compiler) validateSafeOutputsAllowedDomains(config *SafeOutputsConfig) error { if config == nil || len(config.AllowedDomains) == 0 { return nil } safeOutputsDomainsValidationLog.Printf("Validating %d allowed domains", len(config.AllowedDomains)) + collector := NewErrorCollector(c.failFast) + for i, domain := range config.AllowedDomains { if err := validateDomainPattern(domain); err != nil { - return fmt.Errorf("safe-outputs.allowed-domains[%d]: %w", i, err) + wrappedErr := fmt.Errorf("safe-outputs.allowed-domains[%d]: %w", i, err) + if returnErr := collector.Add(wrappedErr); returnErr != nil { + return returnErr // Fail-fast mode + } } } - return nil + return collector.Error() } // validateDomainPattern validates a single domain pattern diff --git a/pkg/workflow/safe_outputs_domains_validation_test.go b/pkg/workflow/safe_outputs_domains_validation_test.go index b8ed0d78f0..a211fc9ece 100644 --- a/pkg/workflow/safe_outputs_domains_validation_test.go +++ b/pkg/workflow/safe_outputs_domains_validation_test.go @@ -184,7 +184,8 @@ func TestValidateSafeOutputsAllowedDomains(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateSafeOutputsAllowedDomains(tt.config) + c := NewCompiler() + err := c.validateSafeOutputsAllowedDomains(tt.config) if tt.wantErr { require.Error(t, err, "Expected an error but got none") if tt.errMsg != "" { @@ -858,7 +859,8 @@ func TestValidateSafeOutputsAllowedDomainsIntegration(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateSafeOutputsAllowedDomains(tt.config) + c := NewCompiler() + err := c.validateSafeOutputsAllowedDomains(tt.config) if tt.wantErr { assert.Error(t, err) } else { diff --git a/pkg/workflow/strict_mode_validation.go b/pkg/workflow/strict_mode_validation.go index 4fe05e3f3e..384f37e6d1 100644 --- a/pkg/workflow/strict_mode_validation.go +++ b/pkg/workflow/strict_mode_validation.go @@ -234,33 +234,47 @@ func (c *Compiler) validateStrictMode(frontmatter map[string]any, networkPermiss strictModeValidationLog.Printf("Starting strict mode validation") + // Collect all strict mode validation errors + collector := NewErrorCollector(c.failFast) + // 1. Refuse write permissions if err := c.validateStrictPermissions(frontmatter); err != nil { - return err + if returnErr := collector.Add(err); returnErr != nil { + return returnErr // Fail-fast mode + } } // 2. Require network configuration and refuse "*" wildcard if err := c.validateStrictNetwork(networkPermissions); err != nil { - return err + if returnErr := collector.Add(err); returnErr != nil { + return returnErr // Fail-fast mode + } } // 3. Require network configuration on custom MCP servers if err := c.validateStrictMCPNetwork(frontmatter, networkPermissions); err != nil { - return err + if returnErr := collector.Add(err); returnErr != nil { + return returnErr // Fail-fast mode + } } // 4. Validate tools configuration if err := c.validateStrictTools(frontmatter); err != nil { - return err + if returnErr := collector.Add(err); returnErr != nil { + return returnErr // Fail-fast mode + } } // 5. Refuse deprecated fields if err := c.validateStrictDeprecatedFields(frontmatter); err != nil { - return err + if returnErr := collector.Add(err); returnErr != nil { + return returnErr // Fail-fast mode + } } - strictModeValidationLog.Printf("Strict mode validation completed successfully") - return nil + strictModeValidationLog.Printf("Strict mode validation completed: error_count=%d", collector.Count()) + + return collector.FormattedError("strict mode") } // validateStrictFirewall requires firewall to be enabled in strict mode for copilot and codex engines