diff --git a/internal/step/foreach/provider.go b/internal/step/foreach/provider.go
index d7ddb6a..083688d 100644
--- a/internal/step/foreach/provider.go
+++ b/internal/step/foreach/provider.go
@@ -53,8 +53,9 @@ var executeLifecycleStage = step.LifecycleStage{
 	RunningName:  "executing",
 	FinishedName: "finished",
 	InputFields: map[string]struct{}{
-		"items":    {},
-		"wait_for": {},
+		"items":       {},
+		"parallelism": {},
+		"wait_for":    {},
 	},
 	NextStages: map[string]dgraph.DependencyType{
 		string(StageIDOutputs): dgraph.AndDependency,
@@ -152,24 +153,6 @@ func (l *forEachProvider) ProviderSchema() map[string]*schema.PropertySchema {
 			nil,
 			[]string{"\"subworkflow.yaml\""},
 		),
-		"parallelism": schema.NewPropertySchema(
-			schema.NewIntSchema(
-				schema.PointerTo[int64](1),
-				nil,
-				nil,
-			),
-			schema.NewDisplayValue(
-				schema.PointerTo("Parallelism"),
-				schema.PointerTo("How many subworkflows to run in parallel."),
-				nil,
-			),
-			false,
-			nil,
-			nil,
-			nil,
-			schema.PointerTo("1"),
-			nil,
-		),
 	}
 }
 
@@ -212,18 +195,35 @@ func (l *forEachProvider) LoadSchema(inputs map[string]any, workflowContext map[
 	}
 
 	return &runnableStep{
-		workflow:    preparedWorkflow,
-		parallelism: inputs["parallelism"].(int64),
-		logger:      l.logger,
+		workflow: preparedWorkflow,
+		logger:   l.logger,
 	}, nil
 }
 
 type runnableStep struct {
-	workflow    workflow.ExecutableWorkflow
-	parallelism int64
-	logger      log.Logger
+	workflow workflow.ExecutableWorkflow
+	logger   log.Logger
 }
 
+var parallelismSchema = schema.NewPropertySchema(
+	schema.NewIntSchema(
+		schema.PointerTo[int64](1),
+		nil,
+		nil,
+	),
+	schema.NewDisplayValue(
+		schema.PointerTo("Parallelism"),
+		schema.PointerTo("How many subworkflows to run in parallel."),
+		nil,
+	),
+	false,
+	nil,
+	nil,
+	nil,
+	schema.PointerTo("1"),
+	nil,
+)
+
 func (r *runnableStep) Lifecycle(_ map[string]any) (step.Lifecycle[step.LifecycleStageWithSchema], error) {
 	workflowOutput := r.workflow.OutputSchema()
 
@@ -265,6 +265,7 @@ func (r *runnableStep) Lifecycle(_ map[string]any) (step.Lifecycle[step.Lifecycl
 						nil,
 						nil,
 					),
+					"parallelism": parallelismSchema,
 				},
 			},
 			{
@@ -442,17 +443,21 @@ func (r *runnableStep) Start(_ map[string]any, runID string, stageChangeHandler
 		lock:               &sync.Mutex{},
 		currentStage:       StageIDEnabling,
 		currentState:       step.RunningStepStateStarting,
-		inputData:          make(chan []any, 1),
+		executeInput:       make(chan executeInput, 1),
 		enabledInput:       make(chan bool, 1),
 		workflow:           r.workflow,
 		stageChangeHandler: stageChangeHandler,
-		parallelism:        r.parallelism,
 		logger:             r.logger,
 	}
 	go rs.run()
 	return rs, nil
 }
 
+type executeInput struct {
+	data        []any
+	parallelism int64
+}
+
 type runningStep struct {
 	runID                   string
 	workflow                workflow.ExecutableWorkflow
@@ -460,7 +465,7 @@ type runningStep struct {
 	lock                    *sync.Mutex
 	currentState            step.RunningStepState
 	executionInputAvailable bool
-	inputData               chan []any
+	executeInput            chan executeInput
 	enabledInput            chan bool
 	enabledInputAvailable   bool
 	ctx                     context.Context
@@ -468,7 +473,6 @@ type runningStep struct {
 	wg                      sync.WaitGroup
 	cancel                  context.CancelFunc
 	stageChangeHandler      step.StageChangeHandler
-	parallelism             int64
 	logger                  log.Logger
 }
 
@@ -483,23 +487,39 @@ func (r *runningStep) ProvideStageInput(stage string, input map[string]any) erro
 	case string(StageIDExecute):
 		items := input["items"]
 		v := reflect.ValueOf(items)
-		input := make([]any, v.Len())
+		subworkflowInputs := make([]any, v.Len())
 		for i := 0; i < v.Len(); i++ {
 			item := v.Index(i).Interface()
 			_, err := r.workflow.Input().Unserialize(item)
 			if err != nil {
 				return fmt.Errorf("invalid input item %d for subworkflow (%w) for run/step %s", i, err, r.runID)
 			}
-			input[i] = item
+			subworkflowInputs[i] = item
 		}
 		if r.executionInputAvailable {
 			return fmt.Errorf("input for execute workflow provided twice for run/step %s", r.runID)
 		}
+		parallelismInput := input["parallelism"]
+		var parallelism int64
+		if parallelismInput != nil {
+			serializedParallelismInput, err := parallelismSchema.Unserialize(parallelismInput)
+			if err != nil {
+				return fmt.Errorf("failed to unserialize parallelism input for run/step %s: %w", r.runID, err)
+			}
+			parallelism = serializedParallelismInput.(int64)
+		} else {
+			parallelism = int64(1)
+		}
+
 		if r.currentState == step.RunningStepStateWaitingForInput && r.currentStage == StageIDExecute {
 			r.currentState = step.RunningStepStateRunning
 		}
 		r.executionInputAvailable = true
-		r.inputData <- input // Send before unlock to ensure that it never gets closed before sending.
+		// Send before unlock to ensure that it never gets closed before sending.
+		r.executeInput <- executeInput{
+			data:        subworkflowInputs,
+			parallelism: parallelism,
+		}
 		return nil
 	case string(StageIDOutputs):
 		return nil
@@ -548,7 +568,7 @@ func (r *runningStep) Close() error {
 	r.cancel()
 	r.wg.Wait()
 	r.logger.Debugf("Closing inputData channel in foreach step provider")
-	close(r.inputData)
+	close(r.executeInput)
 	return nil
 }
 
@@ -759,7 +779,7 @@ func (r *runningStep) markStageFailures(firstStage StageID, err error) {
 
 func (r *runningStep) runOnInput() {
 	select {
-	case loopData, ok := <-r.inputData:
+	case loopData, ok := <-r.executeInput:
 		if !ok {
 			r.logger.Debugf("aborted waiting for result in foreach")
 			return
@@ -771,9 +791,9 @@
 	}
 }
 
-func (r *runningStep) processInput(inputData []any) {
+func (r *runningStep) processInput(input executeInput) {
 	r.logger.Debugf("Executing subworkflow for step %s...", r.runID)
-	outputs, errors := r.executeSubWorkflows(inputData)
+	outputs, errors := r.executeSubWorkflows(input)
 	r.logger.Debugf("Subworkflow %s complete.", r.runID)
 
 	r.lock.Lock()
@@ -788,7 +808,7 @@
 		unresolvableStage = StageIDOutputs
 		unresolvableError = fmt.Errorf("foreach subworkflow failed with errors (%v)", errors)
 		outputID = "error"
-		dataMap := make(map[int]any, len(inputData))
+		dataMap := make(map[int]any, len(input.data))
 		for i, entry := range outputs {
 			if entry != nil {
 				dataMap[i] = entry
@@ -832,13 +852,13 @@
 }
 
 // returns true if there is an error.
-func (r *runningStep) executeSubWorkflows(inputData []any) ([]any, map[int]string) {
-	itemOutputs := make([]any, len(inputData))
-	itemErrors := make(map[int]string, len(inputData))
+func (r *runningStep) executeSubWorkflows(input executeInput) ([]any, map[int]string) {
+	itemOutputs := make([]any, len(input.data))
+	itemErrors := make(map[int]string, len(input.data))
 	wg := &sync.WaitGroup{}
-	wg.Add(len(inputData))
-	sem := make(chan struct{}, r.parallelism)
-	for i, input := range inputData {
+	wg.Add(len(input.data))
+	sem := make(chan struct{}, input.parallelism)
+	for i, input := range input.data {
 		i := i
 		input := input
 		go func() {
diff --git a/workflow/workflow_test.go b/workflow/workflow_test.go
index 0bba298..ffef9c3 100644
--- a/workflow/workflow_test.go
+++ b/workflow/workflow_test.go
@@ -602,6 +602,89 @@ func TestWaitForSerial_Foreach(t *testing.T) {
 	}
 }
 
+var parallelismForeachWf = `
+version: v0.2.0
+input:
+  root: RootObject
+  objects:
+    RootObject:
+      id: RootObject
+      properties:
+        parallelism:
+          type:
+            type_id: integer
+steps:
+  subwf_step:
+    kind: foreach
+    items:
+    - wait_time_ms: 0
+    - wait_time_ms: 0
+    - wait_time_ms: 0
+    workflow: subworkflow.yaml
+    parallelism: !expr $.input.parallelism
+outputs:
+  success:
+    first_step_output: !expr $.steps.subwf_step.outputs
+`
+
+func TestForeachWithParallelismExpr(t *testing.T) {
+	// This test involves a workflow where the parallelism value is
+	// set by an expression as opposed to a literal.
+
+	logConfig := log.Config{
+		Level:       log.LevelDebug,
+		Destination: log.DestinationStdout,
+	}
+	logger := log.New(
+		logConfig,
+	)
+	cfg := &config.Config{
+		Log: logConfig,
+	}
+	factories := workflowFactory{
+		config: cfg,
+	}
+	deployerRegistry := deployerregistry.New(
+		deployer.Any(testimpl.NewFactory()),
+	)
+
+	pluginProvider := assert.NoErrorR[step.Provider](t)(
+		plugin.New(logger, deployerRegistry, map[string]interface{}{
+			"builtin": map[string]any{
+				"deployer_name": "test-impl",
+				"deploy_time":   "0",
+			},
+		}),
+	)
+	stepRegistry, err := stepregistry.New(
+		pluginProvider,
+		lang.Must2(foreach.New(logger, factories.createYAMLParser, factories.createWorkflow)),
+	)
+	assert.NoError(t, err)
+
+	factories.stepRegistry = stepRegistry
+	executor := lang.Must2(workflow.NewExecutor(
+		logger,
+		cfg,
+		stepRegistry,
+		builtinfunctions.GetFunctions(),
+	))
+	wf := lang.Must2(workflow.NewYAMLConverter(stepRegistry).FromYAML([]byte(parallelismForeachWf)))
+	preparedWorkflow := lang.Must2(executor.Prepare(wf, map[string][]byte{
+		"subworkflow.yaml": []byte(waitForSerialForeachSubwf),
+	}))
+	outputID, _, err := preparedWorkflow.Execute(context.Background(), map[string]any{
+		"parallelism": 1,
+	})
+	assert.NoError(t, err)
+	assert.Equals(t, outputID, "success")
+	outputID, _, err = preparedWorkflow.Execute(context.Background(), map[string]any{
+		"parallelism": 10,
+	})
+	assert.NoError(t, err)
+	assert.Equals(t, outputID, "success")
+}
+
 var waitForStartedForeachWf = `
 version: v0.2.0
 input: