Move parallelism input to stage input #233

Merged 2 commits on Dec 9, 2024

108 changes: 64 additions & 44 deletions internal/step/foreach/provider.go
@@ -53,8 +53,9 @@ var executeLifecycleStage = step.LifecycleStage{
RunningName: "executing",
FinishedName: "finished",
InputFields: map[string]struct{}{
"items": {},
"wait_for": {},
"items": {},
"parallelism": {},
"wait_for": {},
},
NextStages: map[string]dgraph.DependencyType{
string(StageIDOutputs): dgraph.AndDependency,
@@ -152,24 +153,6 @@ func (l *forEachProvider) ProviderSchema() map[string]*schema.PropertySchema {
nil,
[]string{"\"subworkflow.yaml\""},
),
"parallelism": schema.NewPropertySchema(
schema.NewIntSchema(
schema.PointerTo[int64](1),
nil,
nil,
),
schema.NewDisplayValue(
schema.PointerTo("Parallelism"),
schema.PointerTo("How many subworkflows to run in parallel."),
nil,
),
false,
nil,
nil,
nil,
schema.PointerTo("1"),
nil,
),
}
}

@@ -212,18 +195,35 @@ func (l *forEachProvider) LoadSchema(inputs map[string]any, workflowContext map[
}

return &runnableStep{
workflow: preparedWorkflow,
parallelism: inputs["parallelism"].(int64),
logger: l.logger,
workflow: preparedWorkflow,
logger: l.logger,
}, nil
}

type runnableStep struct {
workflow workflow.ExecutableWorkflow
parallelism int64
logger log.Logger
workflow workflow.ExecutableWorkflow
logger log.Logger
}

var parallelismSchema = schema.NewPropertySchema(
schema.NewIntSchema(
schema.PointerTo[int64](1),
nil,
nil,
),
schema.NewDisplayValue(
schema.PointerTo("Parallelism"),
schema.PointerTo("How many subworkflows to run in parallel."),
nil,
),
false,
nil,
nil,
nil,
schema.PointerTo("1"),
nil,
)

func (r *runnableStep) Lifecycle(_ map[string]any) (step.Lifecycle[step.LifecycleStageWithSchema], error) {
workflowOutput := r.workflow.OutputSchema()

@@ -265,6 +265,7 @@ func (r *runnableStep) Lifecycle(_ map[string]any) (step.Lifecycle[step.Lifecycl
nil,
nil,
),
"parallelism": parallelismSchema,
},
},
{
@@ -442,33 +443,36 @@ func (r *runnableStep) Start(_ map[string]any, runID string, stageChangeHandler
lock: &sync.Mutex{},
currentStage: StageIDEnabling,
currentState: step.RunningStepStateStarting,
inputData: make(chan []any, 1),
executeInput: make(chan executeInput, 1),
enabledInput: make(chan bool, 1),
workflow: r.workflow,
stageChangeHandler: stageChangeHandler,
parallelism: r.parallelism,
logger: r.logger,
}
go rs.run()
return rs, nil
}

type executeInput struct {
data []any
parallelism int64
}

type runningStep struct {
runID string
workflow workflow.ExecutableWorkflow
currentStage StageID
lock *sync.Mutex
currentState step.RunningStepState
executionInputAvailable bool
inputData chan []any
executeInput chan executeInput
enabledInput chan bool
enabledInputAvailable bool
ctx context.Context
closed atomic.Bool
wg sync.WaitGroup
cancel context.CancelFunc
stageChangeHandler step.StageChangeHandler
parallelism int64
logger log.Logger
}

@@ -483,23 +487,39 @@ func (r *runningStep) ProvideStageInput(stage string, input map[string]any) erro
case string(StageIDExecute):
items := input["items"]
v := reflect.ValueOf(items)
input := make([]any, v.Len())
subworkflowInputs := make([]any, v.Len())
for i := 0; i < v.Len(); i++ {
item := v.Index(i).Interface()
_, err := r.workflow.Input().Unserialize(item)
if err != nil {
return fmt.Errorf("invalid input item %d for subworkflow (%w) for run/step %s", i, err, r.runID)
}
input[i] = item
subworkflowInputs[i] = item
}
if r.executionInputAvailable {
return fmt.Errorf("input for execute workflow provided twice for run/step %s", r.runID)
}
parallelismInput := input["parallelism"]
var parallelism int64
if parallelismInput != nil {
serializedParallelismInput, err := parallelismSchema.Unserialize(parallelismInput)
if err != nil {
return fmt.Errorf("failed to unserialized parallelism input for run/step %s: %w", r.runID, err)
}
parallelism = serializedParallelismInput.(int64)
} else {
parallelism = int64(1)
}

if r.currentState == step.RunningStepStateWaitingForInput && r.currentStage == StageIDExecute {
r.currentState = step.RunningStepStateRunning
}
r.executionInputAvailable = true
r.inputData <- input // Send before unlock to ensure that it never gets closed before sending.
// Send before unlock to ensure that it never gets closed before sending.
r.executeInput <- executeInput{
data: subworkflowInputs,
parallelism: parallelism,
}
return nil
case string(StageIDOutputs):
return nil
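
When the execute stage's input omits parallelism, the code above falls back to 1; otherwise the value is validated through parallelismSchema. A minimal sketch of that default-or-validate pattern, where parallelismOrDefault is a hypothetical helper name and the sketch assumes the parallelismSchema variable defined earlier in this file plus the fmt import:

// parallelismOrDefault validates an optional parallelism value taken from
// stage input, defaulting to 1 when the field is absent.
func parallelismOrDefault(input map[string]any) (int64, error) {
	raw := input["parallelism"]
	if raw == nil {
		// No parallelism given for the execute stage: run one subworkflow at a time.
		return 1, nil
	}
	// Validate and coerce through the same property schema used by the lifecycle stage.
	v, err := parallelismSchema.Unserialize(raw)
	if err != nil {
		return 0, fmt.Errorf("invalid parallelism input: %w", err)
	}
	return v.(int64), nil
}
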
@@ -548,7 +568,7 @@ func (r *runningStep) Close() error {
r.cancel()
r.wg.Wait()
r.logger.Debugf("Closing inputData channel in foreach step provider")
close(r.inputData)
close(r.executeInput)
return nil
}

@@ -759,7 +779,7 @@ func (r *runningStep) markStageFailures(firstStage StageID, err error) {

func (r *runningStep) runOnInput() {
select {
case loopData, ok := <-r.inputData:
case loopData, ok := <-r.executeInput:
if !ok {
r.logger.Debugf("aborted waiting for result in foreach")
return
@@ -771,9 +791,9 @@ }
}
}

func (r *runningStep) processInput(inputData []any) {
func (r *runningStep) processInput(input executeInput) {
r.logger.Debugf("Executing subworkflow for step %s...", r.runID)
outputs, errors := r.executeSubWorkflows(inputData)
outputs, errors := r.executeSubWorkflows(input)

r.logger.Debugf("Subworkflow %s complete.", r.runID)
r.lock.Lock()
@@ -788,7 +808,7 @@ }
unresolvableStage = StageIDOutputs
unresolvableError = fmt.Errorf("foreach subworkflow failed with errors (%v)", errors)
outputID = "error"
dataMap := make(map[int]any, len(inputData))
dataMap := make(map[int]any, len(input.data))
for i, entry := range outputs {
if entry != nil {
dataMap[i] = entry
@@ -832,13 +852,13 @@ }
}

// returns true if there is an error.
func (r *runningStep) executeSubWorkflows(inputData []any) ([]any, map[int]string) {
itemOutputs := make([]any, len(inputData))
itemErrors := make(map[int]string, len(inputData))
func (r *runningStep) executeSubWorkflows(input executeInput) ([]any, map[int]string) {
itemOutputs := make([]any, len(input.data))
itemErrors := make(map[int]string, len(input.data))
wg := &sync.WaitGroup{}
wg.Add(len(inputData))
sem := make(chan struct{}, r.parallelism)
for i, input := range inputData {
wg.Add(len(input.data))
sem := make(chan struct{}, input.parallelism)
for i, input := range input.data {
i := i
input := input
go func() {
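
The hunk above is truncated, but the parallelism value it carries ends up as the capacity of a buffered channel that acts as a semaphore around the per-item goroutines (the sem := make(chan struct{}, input.parallelism) line). A standalone Go sketch of that bounded-concurrency pattern, with illustrative names rather than code from this PR:

package main

import (
	"fmt"
	"sync"
)

// runBounded starts one goroutine per item but allows at most parallelism
// of them to run task at the same time, using a buffered channel as a semaphore.
func runBounded(items []string, parallelism int64, task func(string)) {
	wg := &sync.WaitGroup{}
	wg.Add(len(items))
	sem := make(chan struct{}, parallelism)
	for _, item := range items {
		item := item
		go func() {
			defer wg.Done()
			sem <- struct{}{}        // acquire a slot; blocks once parallelism tasks are running
			defer func() { <-sem }() // release the slot when this task finishes
			task(item)
		}()
	}
	wg.Wait()
}

func main() {
	runBounded([]string{"a", "b", "c"}, 2, func(s string) { fmt.Println("running", s) })
}
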
83 changes: 83 additions & 0 deletions workflow/workflow_test.go
@@ -602,6 +602,89 @@ func TestWaitForSerial_Foreach(t *testing.T) {
}
}

var parallelismForeachWf = `
version: v0.2.0
input:
root: RootObject
objects:
RootObject:
id: RootObject
properties:
parallelism:
type:
type_id: integer
steps:
subwf_step:
kind: foreach
items:
- wait_time_ms: 0
- wait_time_ms: 0
- wait_time_ms: 0
workflow: subworkflow.yaml
parallelism: !expr $.input.parallelism
outputs:
success:
first_step_output: !expr $.steps.subwf_step.outputs
`

func TestForeachWithParallelismExpr(t *testing.T) {
// This test involves a workflow where the parallelism value is
// set by an expression as opposed to a literal.

logConfig := log.Config{
Level: log.LevelDebug,
Destination: log.DestinationStdout,
}
logger := log.New(
logConfig,
)
cfg := &config.Config{
Log: logConfig,
}
factories := workflowFactory{
config: cfg,
}
deployerRegistry := deployerregistry.New(
deployer.Any(testimpl.NewFactory()),
)

pluginProvider := assert.NoErrorR[step.Provider](t)(
plugin.New(logger, deployerRegistry, map[string]interface{}{
"builtin": map[string]any{
"deployer_name": "test-impl",
"deploy_time": "0",
},
}),
)
stepRegistry, err := stepregistry.New(
pluginProvider,
lang.Must2(foreach.New(logger, factories.createYAMLParser, factories.createWorkflow)),
)
assert.NoError(t, err)

factories.stepRegistry = stepRegistry
executor := lang.Must2(workflow.NewExecutor(
logger,
cfg,
stepRegistry,
builtinfunctions.GetFunctions(),
))
wf := lang.Must2(workflow.NewYAMLConverter(stepRegistry).FromYAML([]byte(parallelismForeachWf)))
preparedWorkflow := lang.Must2(executor.Prepare(wf, map[string][]byte{
"subworkflow.yaml": []byte(waitForSerialForeachSubwf),
}))
outputID, _, err := preparedWorkflow.Execute(context.Background(), map[string]any{
"parallelism": 1,
})
assert.NoError(t, err)
assert.Equals(t, outputID, "success")
outputID, _, err = preparedWorkflow.Execute(context.Background(), map[string]any{
"parallelism": 10,
})
assert.NoError(t, err)
assert.Equals(t, outputID, "success")
}

var waitForStartedForeachWf = `
version: v0.2.0
input: