Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pkg/workflow/compiler_safe_outputs_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ func (c *Compiler) addAllSafeOutputConfigEnvVars(steps *[]string, data *Workflow
if data.SafeOutputs.AddLabels != nil {
cfg := data.SafeOutputs.AddLabels
// Add staged flag if needed (but not if target-repo is specified or we're in trial mode)
if !c.trialMode && data.SafeOutputs.Staged && !stagedFlagAdded && cfg.TargetRepoSlug == "" {
// Check both global staged flag and per-handler staged flag
if !c.trialMode && (data.SafeOutputs.Staged || cfg.Staged) && !stagedFlagAdded && cfg.TargetRepoSlug == "" {
*steps = append(*steps, " GH_AW_SAFE_OUTPUTS_STAGED: \"true\"\n")
stagedFlagAdded = true
}
Expand All @@ -59,7 +60,8 @@ func (c *Compiler) addAllSafeOutputConfigEnvVars(steps *[]string, data *Workflow
if data.SafeOutputs.RemoveLabels != nil {
cfg := data.SafeOutputs.RemoveLabels
// Add staged flag if needed (but not if target-repo is specified or we're in trial mode)
if !c.trialMode && data.SafeOutputs.Staged && !stagedFlagAdded && cfg.TargetRepoSlug == "" {
// Check both global staged flag and per-handler staged flag
if !c.trialMode && (data.SafeOutputs.Staged || cfg.Staged) && !stagedFlagAdded && cfg.TargetRepoSlug == "" {
*steps = append(*steps, " GH_AW_SAFE_OUTPUTS_STAGED: \"true\"\n")
stagedFlagAdded = true
}
Expand Down
130 changes: 130 additions & 0 deletions pkg/workflow/compiler_safe_outputs_env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,133 @@ func TestAddLabelsTargetRepoStagedBehavior(t *testing.T) {
// Should not add staged flag when target-repo is set
assert.NotContains(t, stepsContent, "GH_AW_SAFE_OUTPUTS_STAGED")
}

// TestAddLabelsPerHandlerStagedFlag tests per-handler staged flag for add_labels
func TestAddLabelsPerHandlerStagedFlag(t *testing.T) {
tests := []struct {
name string
globalStaged bool
perHandlerStaged bool
shouldIncludeStagedFlag bool
}{
{
name: "per-handler staged true, global false",
globalStaged: false,
perHandlerStaged: true,
shouldIncludeStagedFlag: true,
},
{
name: "per-handler staged false, global true",
globalStaged: true,
perHandlerStaged: false,
shouldIncludeStagedFlag: true,
},
{
name: "both per-handler and global staged true",
globalStaged: true,
perHandlerStaged: true,
shouldIncludeStagedFlag: true,
},
{
name: "both per-handler and global staged false",
globalStaged: false,
perHandlerStaged: false,
shouldIncludeStagedFlag: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := NewCompiler()

workflowData := &WorkflowData{
Name: "Test Workflow",
SafeOutputs: &SafeOutputsConfig{
Staged: tt.globalStaged,
AddLabels: &AddLabelsConfig{
BaseSafeOutputConfig: BaseSafeOutputConfig{
Staged: tt.perHandlerStaged,
},
Allowed: []string{"bug"},
},
},
}

var steps []string
compiler.addAllSafeOutputConfigEnvVars(&steps, workflowData)

stepsContent := strings.Join(steps, "")

if tt.shouldIncludeStagedFlag {
assert.Contains(t, stepsContent, "GH_AW_SAFE_OUTPUTS_STAGED: \"true\"", "Expected staged flag to be set")
} else {
assert.NotContains(t, stepsContent, "GH_AW_SAFE_OUTPUTS_STAGED:", "Expected staged flag not to be set")
}
})
}
}

// TestRemoveLabelsPerHandlerStagedFlag tests per-handler staged flag for remove_labels
func TestRemoveLabelsPerHandlerStagedFlag(t *testing.T) {
tests := []struct {
name string
globalStaged bool
perHandlerStaged bool
shouldIncludeStagedFlag bool
}{
{
name: "per-handler staged true, global false",
globalStaged: false,
perHandlerStaged: true,
shouldIncludeStagedFlag: true,
},
{
name: "per-handler staged false, global true",
globalStaged: true,
perHandlerStaged: false,
shouldIncludeStagedFlag: true,
},
{
name: "both per-handler and global staged true",
globalStaged: true,
perHandlerStaged: true,
shouldIncludeStagedFlag: true,
},
{
name: "both per-handler and global staged false",
globalStaged: false,
perHandlerStaged: false,
shouldIncludeStagedFlag: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := NewCompiler()

workflowData := &WorkflowData{
Name: "Test Workflow",
SafeOutputs: &SafeOutputsConfig{
Staged: tt.globalStaged,
RemoveLabels: &RemoveLabelsConfig{
BaseSafeOutputConfig: BaseSafeOutputConfig{
Staged: tt.perHandlerStaged,
},
Allowed: []string{"bug"},
},
},
}

var steps []string
compiler.addAllSafeOutputConfigEnvVars(&steps, workflowData)

stepsContent := strings.Join(steps, "")

if tt.shouldIncludeStagedFlag {
assert.Contains(t, stepsContent, "GH_AW_SAFE_OUTPUTS_STAGED: \"true\"", "Expected staged flag to be set")
} else {
assert.NotContains(t, stepsContent, "GH_AW_SAFE_OUTPUTS_STAGED:", "Expected staged flag not to be set")
}
})
}
}
4 changes: 3 additions & 1 deletion pkg/workflow/safe_output_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,9 @@ func (c *Compiler) BuildListSafeOutputJob(data *WorkflowData, mainJobName string
customEnvVars := BuildListJobEnvVars(builderConfig.EnvPrefix, listJobConfig, maxCount)

// Add standard environment variables (metadata + staged/target repo)
customEnvVars = append(customEnvVars, c.buildStandardSafeOutputEnvVars(data, listJobConfig.TargetRepoSlug)...)
// Check both global and per-handler staged flags
perHandlerStaged := baseSafeOutputConfig.Staged
customEnvVars = append(customEnvVars, c.buildStandardSafeOutputEnvVarsWithPerHandlerStaged(data, listJobConfig.TargetRepoSlug, perHandlerStaged)...)

// Create outputs for the job
outputs := map[string]string{
Expand Down
13 changes: 12 additions & 1 deletion pkg/workflow/safe_outputs_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ func buildSafeOutputJobEnvVars(trialMode bool, trialLogicalRepoSlug string, stag
// that all safe-output job builders need: metadata + staged/target repo handling
// This reduces duplication in safe-output job builders
func (c *Compiler) buildStandardSafeOutputEnvVars(data *WorkflowData, targetRepoSlug string) []string {
return c.buildStandardSafeOutputEnvVarsWithPerHandlerStaged(data, targetRepoSlug, false)
}

// buildStandardSafeOutputEnvVarsWithPerHandlerStaged builds the standard set of environment variables
// with support for per-handler staged flag. The staged flag is set if either the global flag
// or the per-handler flag is true.
func (c *Compiler) buildStandardSafeOutputEnvVarsWithPerHandlerStaged(data *WorkflowData, targetRepoSlug string, perHandlerStaged bool) []string {
var customEnvVars []string

// Add workflow metadata (name, source, and tracker-id)
Expand All @@ -128,11 +135,15 @@ func (c *Compiler) buildStandardSafeOutputEnvVars(data *WorkflowData, targetRepo
// Add engine metadata (id, version, model) for XML comment marker
customEnvVars = append(customEnvVars, buildEngineMetadataEnvVars(data.EngineConfig)...)

// Check both global and per-handler staged flags (OR operation)
globalStaged := data.SafeOutputs.Staged
effectiveStaged := globalStaged || perHandlerStaged

// Add common safe output job environment variables (staged/target repo)
customEnvVars = append(customEnvVars, buildSafeOutputJobEnvVars(
c.trialMode,
c.trialLogicalRepoSlug,
data.SafeOutputs.Staged,
effectiveStaged,
targetRepoSlug,
)...)

Expand Down
63 changes: 63 additions & 0 deletions pkg/workflow/staged_add_issue_labels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,69 @@ func TestAddLabelsJobWithStagedFlag(t *testing.T) {

}

func TestAddLabelsJobWithPerHandlerStagedFlag(t *testing.T) {
// Create a compiler instance
c := NewCompiler()

// Test with per-handler staged: true (global staged: false)
workflowData := &WorkflowData{
Name: "test-workflow",
SafeOutputs: &SafeOutputsConfig{
AddLabels: &AddLabelsConfig{
BaseSafeOutputConfig: BaseSafeOutputConfig{
Staged: true,
},
},
Staged: false,
},
}

job, err := c.buildAddLabelsJob(workflowData, "main_job")
if err != nil {
t.Fatalf("Unexpected error building add labels job: %v", err)
}

// Convert steps to a single string for testing
stepsContent := strings.Join(job.Steps, "")

// Check that GH_AW_SAFE_OUTPUTS_STAGED is included when per-handler staged is true
if !strings.Contains(stepsContent, " GH_AW_SAFE_OUTPUTS_STAGED: \"true\"\n") {
t.Error("Expected GH_AW_SAFE_OUTPUTS_STAGED environment variable to be set to true when per-handler staged is true")
}

// Test with per-handler staged: false, global staged: true
workflowData.SafeOutputs.AddLabels.Staged = false
workflowData.SafeOutputs.Staged = true

job, err = c.buildAddLabelsJob(workflowData, "main_job")
if err != nil {
t.Fatalf("Unexpected error building add labels job: %v", err)
}

stepsContent = strings.Join(job.Steps, "")

// Check that GH_AW_SAFE_OUTPUTS_STAGED is included when global staged is true
if !strings.Contains(stepsContent, " GH_AW_SAFE_OUTPUTS_STAGED: \"true\"\n") {
t.Error("Expected GH_AW_SAFE_OUTPUTS_STAGED environment variable to be set to true when global staged is true")
}

// Test with both per-handler and global staged: false
workflowData.SafeOutputs.AddLabels.Staged = false
workflowData.SafeOutputs.Staged = false

job, err = c.buildAddLabelsJob(workflowData, "main_job")
if err != nil {
t.Fatalf("Unexpected error building add labels job: %v", err)
}

stepsContent = strings.Join(job.Steps, "")

// Check that GH_AW_SAFE_OUTPUTS_STAGED is not included when both are false
if strings.Contains(stepsContent, " GH_AW_SAFE_OUTPUTS_STAGED:") {
t.Error("Expected GH_AW_SAFE_OUTPUTS_STAGED environment variable not to be set when both staged flags are false")
}
}

func TestAddLabelsJobWithNilSafeOutputs(t *testing.T) {
// Create a compiler instance
c := NewCompiler()
Expand Down