Skip to content

Commit 057d2c7

Browse files
committed
Enable cancellation of pipelines
1 parent 198c1c4 commit 057d2c7

File tree

8 files changed

+125
-47
lines changed

8 files changed

+125
-47
lines changed

examples/context_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@ type Data struct {
1919
func TestExample_Context(t *testing.T) {
2020
// Create pipeline with defaults
2121
p := pipeline.NewPipeline()
22-
p.WithContext(context.WithValue(context.Background(), "data", &Data{}))
2322
p.WithSteps(
2423
pipeline.NewStep("define random number", defineNumber),
2524
pipeline.NewStepFromFunc("print number", printNumber),
2625
)
27-
result := p.Run()
26+
result := p.RunWithContext(context.WithValue(context.Background(), "data", &Data{}))
2827
if !result.IsSuccessful() {
2928
t.Fatal(result.Err)
3029
}

parallel/fanout.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ NewFanOutStep creates a pipeline step that runs nested pipelines in their own Go
1212
The function provided as PipelineSupplier is expected to close the given channel when no more pipelines should be executed, otherwise this step blocks forever.
1313
The step waits until all pipelines are finished.
1414
If the given ResultHandler is non-nil it will be called after all pipelines were run, otherwise the step is considered successful.
15-
The given pipelines have to define their own context.Context, it's not passed "down" from parent pipeline.
16-
However, The context.Context for the ResultHandler will be the one from parent pipeline.
1715
*/
1816
func NewFanOutStep(name string, pipelineSupplier PipelineSupplier, handler ResultHandler) pipeline.Step {
1917
step := pipeline.Step{Name: name}

parallel/pool.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ The step waits until all pipelines are finished.
1616
* The pipelines are executed in a pool of a number of Go routines indicated by size.
1717
* If size is 1, the pipelines are effectively run in sequence.
1818
* If size is 0 or less, the function panics.
19-
The given pipelines have to define their own context.Context, it's not passed "down" from parent pipeline.
20-
However, The context.Context for the ResultHandler will be the one from parent pipeline.
2119
*/
2220
func NewWorkerPoolStep(name string, size int, pipelineSupplier PipelineSupplier, handler ResultHandler) pipeline.Step {
2321
if size < 1 {
@@ -33,7 +31,7 @@ func NewWorkerPoolStep(name string, size int, pipelineSupplier PipelineSupplier,
3331
go pipelineSupplier(pipelineChan)
3432
for i := 0; i < size; i++ {
3533
wg.Add(1)
36-
go poolWork(pipelineChan, &wg, &count, &m)
34+
go poolWork(ctx, pipelineChan, &wg, &count, &m)
3735
}
3836

3937
wg.Wait()
@@ -42,10 +40,10 @@ func NewWorkerPoolStep(name string, size int, pipelineSupplier PipelineSupplier,
4240
return step
4341
}
4442

45-
func poolWork(pipelineChan chan *pipeline.Pipeline, wg *sync.WaitGroup, i *uint64, m *sync.Map) {
43+
func poolWork(ctx context.Context, pipelineChan chan *pipeline.Pipeline, wg *sync.WaitGroup, i *uint64, m *sync.Map) {
4644
defer wg.Done()
4745
for pipe := range pipelineChan {
4846
n := atomic.AddUint64(i, 1) - 1
49-
m.Store(n, pipe.Run())
47+
m.Store(n, pipe.RunWithContext(ctx))
5048
}
5149
}

parallel/pool_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func TestNewWorkerPoolStep(t *testing.T) {
4646
assert.Error(t, results[0].Err)
4747
return pipeline.Result{Err: results[0].Err}
4848
})
49-
result := step.F(nil)
49+
result := step.F(context.Background())
5050
assert.Error(t, result.Err)
5151
})
5252
}

pipeline.go

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,6 @@ type Pipeline struct {
1414
options options
1515
}
1616

17-
// Result is the object that is returned after each step and after running a pipeline.
18-
type Result struct {
19-
// Err contains the step's returned error, nil otherwise.
20-
// In an aborted pipeline with ErrAbort it will still be nil.
21-
Err error
22-
// Name is an optional identifier for a result.
23-
// ActionFunc may set this property before returning to help a ResultHandler with further processing.
24-
Name string
25-
26-
aborted bool
27-
}
28-
2917
// Step is an intermediary action and part of a Pipeline.
3018
type Step struct {
3119
// Name describes the step's human-readable name.
@@ -84,24 +72,24 @@ func (p *Pipeline) WithSteps(steps ...Step) *Pipeline {
8472

8573
// WithNestedSteps is similar to AsNestedStep, but it accepts the steps given directly as parameters.
8674
func (p *Pipeline) WithNestedSteps(name string, steps ...Step) Step {
87-
return NewStep(name, func(_ context.Context) Result {
75+
return NewStep(name, func(ctx context.Context) Result {
8876
nested := &Pipeline{beforeHooks: p.beforeHooks, steps: steps, options: p.options}
89-
return nested.Run()
77+
return nested.RunWithContext(ctx)
9078
})
9179
}
9280

9381
// AsNestedStep converts the Pipeline instance into a Step that can be used in other pipelines.
9482
// The properties are passed to the nested pipeline.
9583
func (p *Pipeline) AsNestedStep(name string) Step {
96-
return NewStep(name, func(_ context.Context) Result {
84+
return NewStep(name, func(ctx context.Context) Result {
9785
nested := &Pipeline{beforeHooks: p.beforeHooks, steps: p.steps, options: p.options}
98-
return nested.Run()
86+
return nested.RunWithContext(ctx)
9987
})
10088
}
10189

10290
// WithFinalizer returns itself while setting the finalizer for the pipeline.
10391
// The finalizer is a handler that gets called after the last step is in the pipeline is completed.
104-
// If a pipeline aborts early then it is also called.
92+
// If a pipeline aborts early or gets canceled then it is also called.
10593
func (p *Pipeline) WithFinalizer(handler ResultHandler) *Pipeline {
10694
p.finalizer = handler
10795
return p
@@ -116,6 +104,9 @@ func (p *Pipeline) Run() Result {
116104
}
117105

118106
// RunWithContext is like Run but with a given context.Context.
107+
// Upon cancellation of the context, the pipeline does not terminate a currently running step, instead it skips the remaining steps in the execution order.
108+
// The context is passed to each Step.F and each Step may need to listen to the context cancellation event to truly cancel a long-running step.
109+
// If the pipeline gets canceled, Result.IsCanceled returns true and Result.Err contains the context's error.
119110
func (p *Pipeline) RunWithContext(ctx context.Context) Result {
120111
result := p.doRun(ctx)
121112
if p.finalizer != nil {
@@ -125,28 +116,45 @@ func (p *Pipeline) RunWithContext(ctx context.Context) Result {
125116
}
126117

127118
func (p *Pipeline) doRun(ctx context.Context) Result {
119+
name := ""
128120
for _, step := range p.steps {
129-
for _, hooks := range p.beforeHooks {
130-
hooks(step)
131-
}
121+
name = step.Name
122+
select {
123+
case <-ctx.Done():
124+
result := p.fail(ctx.Err(), step)
125+
return result
126+
default:
127+
for _, hooks := range p.beforeHooks {
128+
hooks(step)
129+
}
132130

133-
result := step.F(ctx)
134-
var err error
135-
if step.H != nil {
136-
err = step.H(ctx, result)
137-
} else {
138-
err = result.Err
139-
}
140-
if err != nil {
141-
if errors.Is(err, ErrAbort) {
142-
// Abort pipeline without error
143-
return Result{aborted: true}
131+
result := step.F(ctx)
132+
var err error
133+
if step.H != nil {
134+
result.name = step.Name
135+
err = step.H(ctx, result)
136+
} else {
137+
err = result.Err
144138
}
145-
if p.options.disableErrorWrapping {
146-
return Result{Err: err}
139+
if err != nil {
140+
if errors.Is(err, ErrAbort) {
141+
// Abort pipeline without error
142+
return Result{aborted: true, name: step.Name}
143+
}
144+
return p.fail(err, step)
147145
}
148-
return Result{Err: fmt.Errorf("step '%s' failed: %w", step.Name, err)}
149146
}
150147
}
151-
return Result{}
148+
return Result{name: name}
149+
}
150+
151+
func (p *Pipeline) fail(err error, step Step) Result {
152+
result := Result{name: step.Name}
153+
if p.options.disableErrorWrapping {
154+
result.Err = err
155+
} else {
156+
result.Err = fmt.Errorf("step %q failed: %w", step.Name, err)
157+
}
158+
result.canceled = errors.Is(err, ErrCanceled) || errors.Is(err, context.DeadlineExceeded)
159+
return result
152160
}

pipeline_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package pipeline
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"testing"
8+
"time"
79

810
"github.com/stretchr/testify/assert"
911
"github.com/stretchr/testify/require"
@@ -165,6 +167,55 @@ func TestPipeline_Run(t *testing.T) {
165167
}
166168
}
167169

170+
func TestPipeline_RunWithContext_CancelLongRunningStep(t *testing.T) {
171+
p := NewPipeline().WithSteps(
172+
NewStepFromFunc("long running", func(ctx context.Context) error {
173+
for {
174+
select {
175+
case <-ctx.Done():
176+
return ErrCanceled
177+
default:
178+
// doing nothing
179+
}
180+
}
181+
return nil
182+
}),
183+
)
184+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
185+
result := p.RunWithContext(ctx)
186+
cancel()
187+
assert.True(t, result.IsCanceled(), "IsCanceled()")
188+
assert.Equal(t, "long running", result.Name())
189+
assert.EqualError(t, result.Err, "step \"long running\" failed: canceled")
190+
}
191+
192+
func ExamplePipeline_RunWithContext() {
193+
// prepare pipeline
194+
p := NewPipeline().WithSteps(
195+
NewStepFromFunc("short step", func(ctx context.Context) error {
196+
fmt.Println("short step")
197+
return nil
198+
}),
199+
NewStepFromFunc("long running step", func(ctx context.Context) error {
200+
time.Sleep(100 * time.Millisecond)
201+
return nil
202+
}),
203+
NewStepFromFunc("canceled step", func(ctx context.Context) error {
204+
return errors.New("shouldn't execute")
205+
}),
206+
)
207+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
208+
result := p.RunWithContext(ctx)
209+
// cancel the pipeline
210+
cancel()
211+
// inspect the result
212+
fmt.Println(result.IsCanceled())
213+
fmt.Println(result.Err)
214+
// Output: short step
215+
// true
216+
// step "canceled step" failed: context deadline exceeded
217+
}
218+
168219
func TestNewStepFromFunc(t *testing.T) {
169220
called := false
170221
step := NewStepFromFunc("name", func(ctx context.Context) error {

result.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,28 @@ package pipeline
22

33
import "errors"
44

5-
// ErrAbort indicates that the pipeline should be terminated immediately without returning an error.
5+
// ErrAbort indicates that the pipeline should be terminated immediately without being marked as failed (returning an error).
66
var ErrAbort = errors.New("abort")
77

8+
// ErrCanceled indicates that the pipeline has been canceled.
9+
var ErrCanceled = errors.New("canceled")
10+
11+
// Result is the object that is returned after each step and after running a pipeline.
12+
type Result struct {
13+
// Err contains the step's returned error, nil otherwise.
14+
// In an aborted pipeline with ErrAbort it will still be nil.
15+
Err error
16+
17+
name string
18+
aborted bool
19+
canceled bool
20+
}
21+
22+
// Name retrieves the name of the (last) step that has been executed.
23+
func (r Result) Name() string {
24+
return r.name
25+
}
26+
827
// IsSuccessful returns true if the contained error is nil.
928
// Aborted pipelines (with ErrAbort) are still reported as success.
1029
// To query if a pipeline is aborted early, use IsAborted.
@@ -21,3 +40,8 @@ func (r Result) IsFailed() bool {
2140
func (r Result) IsAborted() bool {
2241
return r.aborted
2342
}
43+
44+
// IsCanceled returns true if the pipeline's context has been canceled.
45+
func (r Result) IsCanceled() bool {
46+
return r.canceled
47+
}

step.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ func NewStep(name string, action ActionFunc) Step {
1414
func NewStepFromFunc(name string, fn func(ctx context.Context) error) Step {
1515
return NewStep(name, func(ctx context.Context) Result {
1616
err := fn(ctx)
17-
return Result{Err: err, Name: name}
17+
return Result{Err: err, name: name}
1818
})
1919
}
2020

0 commit comments

Comments
 (0)