diff --git a/README.md b/README.md index d94bb0c..6f4db61 100644 --- a/README.md +++ b/README.md @@ -11,35 +11,39 @@ Small Go utility that executes business actions in a pipeline. ```go import ( + "context" pipeline "github.com/ccremer/go-command-pipeline" "github.com/ccremer/go-command-pipeline/predicate" ) +type Data struct { + Number int +} + func main() { - number := 0 // define arbitrary data to pass around in the steps. + data := &Data // define arbitrary data to pass around in the steps. p := pipeline.NewPipeline() - p.WithContext(&number) p.WithSteps( pipeline.NewStep("define random number", defineNumber), pipeline.NewStepFromFunc("print number", printNumber), ) - result := p.Run() + result := p.RunWithContext(context.WithValue(context.Background, "data", data)) if !result.IsSuccessful() { log.Fatal(result.Err) } } -func defineNumber(ctx pipeline.Context) pipeline.Result { - ctx.(*int) = 10 +func defineNumber(ctx context.Context) pipeline.Result { + ctx.Value("data").(*Data).Number = 10 return pipeline.Result{} } // Let's assume this is a business function that can fail. // You can enable "automatic" fail-on-first-error pipelines by having more small functions that return errors. -func printNumber(ctx pipeline.Context) error { - number := ctx.(*int) - _, err := fmt.Println(*number) - return err +func printNumber(ctx context.Context) error { + number := ctx.Value("data").(*Data).Number + fmt.Println(number) + return nil } ``` @@ -70,18 +74,18 @@ We have tons of `if err != nil` that bloats the function with more error handlin It could be simplified to something like this: ```go -func Persist(data Data) error { - p := pipeline.NewPipeline().WithContext(data).WithSteps( +func Persist(data *Data) error { + p := pipeline.NewPipeline().WithSteps( pipeline.NewStep("prepareTransaction", prepareTransaction()), pipeline.NewStep("executeQuery", executeQuery()), pipeline.NewStep("commitTransaction", commit()), ) - return p.Run().Err + return p.RunWithContext(context.WithValue(context.Background(), myKey, data).Err } func executeQuery() pipeline.ActionFunc { - return func(ctx pipeline.Context) pipeline.Result { - data := ctx.(Data) + return func(ctx context.Context) pipeline.Result { + data := ctx.Value(myKey).(*Data) err := database.executeQuery("SOME QUERY", data) return pipeline.Result{Err: err} ) diff --git a/examples/abort_test.go b/examples/abort_test.go index 89ef693..a0547a7 100644 --- a/examples/abort_test.go +++ b/examples/abort_test.go @@ -4,6 +4,7 @@ package examples import ( + "context" "errors" "testing" @@ -22,11 +23,11 @@ func TestExample_Abort(t *testing.T) { assert.True(t, result.IsAborted()) } -func doNotExecute(_ pipeline.Context) error { +func doNotExecute(_ context.Context) error { return errors.New("should not execute") } -func abort(_ pipeline.Context) error { +func abort(_ context.Context) error { // some logic that can handle errors, but you don't want to bubble up the error. // terminate pipeline gracefully diff --git a/examples/context_test.go b/examples/context_test.go index 423c9d5..54a1350 100644 --- a/examples/context_test.go +++ b/examples/context_test.go @@ -4,6 +4,7 @@ package examples import ( + "context" "fmt" "math/rand" "testing" @@ -15,26 +16,27 @@ type Data struct { Number int } +var key = struct{}{} + func TestExample_Context(t *testing.T) { // Create pipeline with defaults p := pipeline.NewPipeline() - p.WithContext(&Data{}) p.WithSteps( pipeline.NewStep("define random number", defineNumber), pipeline.NewStepFromFunc("print number", printNumber), ) - result := p.Run() + result := p.RunWithContext(context.WithValue(context.Background(), key, &Data{})) if !result.IsSuccessful() { t.Fatal(result.Err) } } -func defineNumber(ctx pipeline.Context) pipeline.Result { - ctx.(*Data).Number = rand.Int() +func defineNumber(ctx context.Context) pipeline.Result { + ctx.Value(key).(*Data).Number = rand.Int() return pipeline.Result{} } -func printNumber(ctx pipeline.Context) error { - _, err := fmt.Println(ctx.(*Data).Number) +func printNumber(ctx context.Context) error { + _, err := fmt.Println(ctx.Value(key).(*Data).Number) return err } diff --git a/examples/git_test.go b/examples/git_test.go index 4f9350c..2efc77f 100644 --- a/examples/git_test.go +++ b/examples/git_test.go @@ -4,6 +4,7 @@ package examples import ( + "context" "log" "os" "os/exec" @@ -26,27 +27,27 @@ func TestExample_Git(t *testing.T) { } } -func logSuccess(ctx pipeline.Context, result pipeline.Result) error { +func logSuccess(_ context.Context, result pipeline.Result) error { log.Println("handler called") return result.Err } func CloneGitRepository() pipeline.ActionFunc { - return func(_ pipeline.Context) pipeline.Result { + return func(_ context.Context) pipeline.Result { err := execGitCommand("clone", "git@github.com/ccremer/go-command-pipeline") return pipeline.Result{Err: err} } } func Pull() pipeline.ActionFunc { - return func(_ pipeline.Context) pipeline.Result { + return func(_ context.Context) pipeline.Result { err := execGitCommand("pull") return pipeline.Result{Err: err} } } func CheckoutBranch() pipeline.ActionFunc { - return func(_ pipeline.Context) pipeline.Result { + return func(_ context.Context) pipeline.Result { err := execGitCommand("checkout", "master") return pipeline.Result{Err: err} } @@ -61,7 +62,7 @@ func execGitCommand(args ...string) error { } func DirExists(path string) predicate.Predicate { - return func(_ pipeline.Context) bool { + return func(_ context.Context) bool { if info, err := os.Stat(path); err != nil || !info.IsDir() { return false } diff --git a/examples/hooks_test.go b/examples/hooks_test.go index ce1c7da..52cc252 100644 --- a/examples/hooks_test.go +++ b/examples/hooks_test.go @@ -4,6 +4,7 @@ package examples import ( + "context" "fmt" "testing" @@ -16,7 +17,7 @@ func TestExample_Hooks(t *testing.T) { fmt.Println(fmt.Sprintf("Executing step: %s", step.Name)) }) p.WithSteps( - pipeline.NewStep("hook demo", AfterHookAction()), + pipeline.NewStepFromFunc("hook demo", AfterHookAction), ) result := p.Run() if !result.IsSuccessful() { @@ -24,9 +25,7 @@ func TestExample_Hooks(t *testing.T) { } } -func AfterHookAction() pipeline.ActionFunc { - return func(ctx pipeline.Context) pipeline.Result { - fmt.Println("I am called in an action after the hooks") - return pipeline.Result{} - } +func AfterHookAction(_ context.Context) error { + fmt.Println("I am called in an action after the hooks") + return nil } diff --git a/go.mod b/go.mod index c9f03a3..513eff8 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,10 @@ module github.com/ccremer/go-command-pipeline go 1.17 -require github.com/stretchr/testify v1.7.0 +require ( + github.com/stretchr/testify v1.7.0 + go.uber.org/goleak v1.1.12 +) require ( github.com/davecgh/go-spew v1.1.0 // indirect diff --git a/go.sum b/go.sum index acb88a4..c6dbbef 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,47 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/options_test.go b/options_test.go index 8ea59da..770b7ec 100644 --- a/options_test.go +++ b/options_test.go @@ -1,6 +1,7 @@ package pipeline import ( + "context" "errors" "testing" @@ -11,7 +12,7 @@ func TestPipeline_WithOptions(t *testing.T) { t.Run("DisableErrorWrapping", func(t *testing.T) { p := NewPipeline().WithOptions(DisableErrorWrapping) p.WithSteps( - NewStepFromFunc("disabled error wrapping", func(_ Context) error { + NewStepFromFunc("disabled error wrapping", func(_ context.Context) error { return errors.New("some error") }), ) diff --git a/parallel/fanout.go b/parallel/fanout.go index e0fb385..99ed28c 100644 --- a/parallel/fanout.go +++ b/parallel/fanout.go @@ -1,6 +1,7 @@ package parallel import ( + "context" "sync" pipeline "github.com/ccremer/go-command-pipeline" @@ -11,18 +12,20 @@ NewFanOutStep creates a pipeline step that runs nested pipelines in their own Go The function provided as PipelineSupplier is expected to close the given channel when no more pipelines should be executed, otherwise this step blocks forever. The step waits until all pipelines are finished. If the given ResultHandler is non-nil it will be called after all pipelines were run, otherwise the step is considered successful. -The given pipelines have to define their own pipeline.Context, it's not passed "down" from parent pipeline. -However, The pipeline.Context for the ResultHandler will be the one from parent pipeline. + +If the context is canceled, no new pipelines will be retrieved from the channel and the PipelineSupplier is expected to stop supplying new instances. +Also, once canceled, the step waits for the remaining children pipelines and collects their result via given ResultHandler. +However, the error returned from ResultHandler is wrapped in context.Canceled. */ func NewFanOutStep(name string, pipelineSupplier PipelineSupplier, handler ResultHandler) pipeline.Step { step := pipeline.Step{Name: name} - step.F = func(ctx pipeline.Context) pipeline.Result { + step.F = func(ctx context.Context) pipeline.Result { pipelineChan := make(chan *pipeline.Pipeline) m := sync.Map{} var wg sync.WaitGroup i := uint64(0) - go pipelineSupplier(pipelineChan) + go pipelineSupplier(ctx, pipelineChan) for pipe := range pipelineChan { p := pipe wg.Add(1) @@ -30,12 +33,12 @@ func NewFanOutStep(name string, pipelineSupplier PipelineSupplier, handler Resul i++ go func() { defer wg.Done() - m.Store(n, p.Run()) + m.Store(n, p.RunWithContext(ctx)) }() } - wg.Wait() - return collectResults(ctx, handler, &m) + res := collectResults(ctx, handler, &m) + return setResultErrorFromContext(ctx, res) } return step } diff --git a/parallel/fanout_test.go b/parallel/fanout_test.go index 246b36a..19dc126 100644 --- a/parallel/fanout_test.go +++ b/parallel/fanout_test.go @@ -1,6 +1,7 @@ package parallel import ( + "context" "fmt" "sync/atomic" "testing" @@ -8,6 +9,7 @@ import ( pipeline "github.com/ccremer/go-command-pipeline" "github.com/stretchr/testify/assert" + "go.uber.org/goleak" ) func TestNewFanOutStep(t *testing.T) { @@ -29,7 +31,7 @@ func TestNewFanOutStep(t *testing.T) { "GivenPipelineWith_WhenRunningStep_ThenReturnSuccessButRunErrorHandler": { jobs: 1, returnErr: fmt.Errorf("should be called"), - givenResultHandler: func(ctx pipeline.Context, _ map[uint64]pipeline.Result) pipeline.Result { + givenResultHandler: func(ctx context.Context, _ map[uint64]pipeline.Result) pipeline.Result { atomic.AddUint64(&counts, 1) return pipeline.Result{} }, @@ -39,43 +41,79 @@ func TestNewFanOutStep(t *testing.T) { for name, tt := range tests { counts = 0 t.Run(name, func(t *testing.T) { + goleak.VerifyNone(t) handler := tt.givenResultHandler if handler == nil { - handler = func(ctx pipeline.Context, results map[uint64]pipeline.Result) pipeline.Result { + handler = func(ctx context.Context, results map[uint64]pipeline.Result) pipeline.Result { assert.NoError(t, results[0].Err) return pipeline.Result{} } } - step := NewFanOutStep("fanout", func(funcs chan *pipeline.Pipeline) { + step := NewFanOutStep("fanout", func(_ context.Context, funcs chan *pipeline.Pipeline) { defer close(funcs) for i := 0; i < tt.jobs; i++ { - funcs <- pipeline.NewPipeline().WithSteps(pipeline.NewStep("step", func(_ pipeline.Context) pipeline.Result { + funcs <- pipeline.NewPipeline().WithSteps(pipeline.NewStep("step", func(_ context.Context) pipeline.Result { atomic.AddUint64(&counts, 1) return pipeline.Result{Err: tt.returnErr} })) } }, handler) - result := step.F(nil) + result := step.F(context.Background()) assert.NoError(t, result.Err) - assert.Equal(t, uint64(tt.expectedCounts), counts) + assert.Equal(t, tt.expectedCounts, int(counts)) }) } } +func TestNewFanOutStep_Cancel(t *testing.T) { + defer goleak.VerifyNone(t) + var counts uint64 + step := NewFanOutStep("fanout", func(ctx context.Context, pipelines chan *pipeline.Pipeline) { + defer close(pipelines) + for i := 0; i < 10000; i++ { + select { + case <-ctx.Done(): + return + default: + pipelines <- pipeline.NewPipeline().WithSteps(pipeline.NewStepFromFunc("increase", func(_ context.Context) error { + atomic.AddUint64(&counts, 1) + return nil + })) + time.Sleep(10 * time.Millisecond) + } + } + t.Fail() // should not reach this + }, func(ctx context.Context, results map[uint64]pipeline.Result) pipeline.Result { + assert.Len(t, results, 3) + return pipeline.Result{Err: fmt.Errorf("some error")} + }) + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + result := pipeline.NewPipeline().WithSteps(step).RunWithContext(ctx) + assert.Equal(t, 3, int(counts)) + assert.True(t, result.IsCanceled(), "canceled flag") + assert.EqualError(t, result.Err, `step "fanout" failed: context deadline exceeded, collection error: some error`) +} + func ExampleNewFanOutStep() { p := pipeline.NewPipeline() - fanout := NewFanOutStep("fanout", func(pipelines chan *pipeline.Pipeline) { + fanout := NewFanOutStep("fanout", func(ctx context.Context, pipelines chan *pipeline.Pipeline) { defer close(pipelines) // create some pipelines for i := 0; i < 3; i++ { n := i - pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep(fmt.Sprintf("i = %d", n), func(_ pipeline.Context) pipeline.Result { - time.Sleep(time.Duration(n * 10000000)) // fake some load - fmt.Println(fmt.Sprintf("I am worker %d", n)) - return pipeline.Result{} - })) + select { + case <-ctx.Done(): + return // parent pipeline has been canceled, let's not create more pipelines. + default: + pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep(fmt.Sprintf("i = %d", n), func(_ context.Context) pipeline.Result { + time.Sleep(time.Duration(n * 10000000)) // fake some load + fmt.Println(fmt.Sprintf("I am worker %d", n)) + return pipeline.Result{} + })) + } } - }, func(ctx pipeline.Context, results map[uint64]pipeline.Result) pipeline.Result { + }, func(ctx context.Context, results map[uint64]pipeline.Result) pipeline.Result { for worker, result := range results { if result.IsFailed() { fmt.Println(fmt.Sprintf("Worker %d failed: %v", worker, result.Err)) diff --git a/parallel/pool.go b/parallel/pool.go index 1b8dd22..32f067f 100644 --- a/parallel/pool.go +++ b/parallel/pool.go @@ -1,6 +1,7 @@ package parallel import ( + "context" "sync" "sync/atomic" @@ -15,36 +16,35 @@ The step waits until all pipelines are finished. * The pipelines are executed in a pool of a number of Go routines indicated by size. * If size is 1, the pipelines are effectively run in sequence. * If size is 0 or less, the function panics. -The given pipelines have to define their own pipeline.Context, it's not passed "down" from parent pipeline. -However, The pipeline.Context for the ResultHandler will be the one from parent pipeline. */ func NewWorkerPoolStep(name string, size int, pipelineSupplier PipelineSupplier, handler ResultHandler) pipeline.Step { if size < 1 { panic("pool size cannot be lower than 1") } step := pipeline.Step{Name: name} - step.F = func(ctx pipeline.Context) pipeline.Result { + step.F = func(ctx context.Context) pipeline.Result { pipelineChan := make(chan *pipeline.Pipeline, size) m := sync.Map{} var wg sync.WaitGroup count := uint64(0) - go pipelineSupplier(pipelineChan) + go pipelineSupplier(ctx, pipelineChan) for i := 0; i < size; i++ { wg.Add(1) - go poolWork(pipelineChan, &wg, &count, &m) + go poolWork(ctx, pipelineChan, &wg, &count, &m) } wg.Wait() - return collectResults(ctx, handler, &m) + res := collectResults(ctx, handler, &m) + return setResultErrorFromContext(ctx, res) } return step } -func poolWork(pipelineChan chan *pipeline.Pipeline, wg *sync.WaitGroup, i *uint64, m *sync.Map) { +func poolWork(ctx context.Context, pipelineChan chan *pipeline.Pipeline, wg *sync.WaitGroup, i *uint64, m *sync.Map) { defer wg.Done() for pipe := range pipelineChan { n := atomic.AddUint64(i, 1) - 1 - m.Store(n, pipe.Run()) + m.Store(n, pipe.RunWithContext(ctx)) } } diff --git a/parallel/pool_test.go b/parallel/pool_test.go index 1fc9e15..60b8eca 100644 --- a/parallel/pool_test.go +++ b/parallel/pool_test.go @@ -1,6 +1,7 @@ package parallel import ( + "context" "errors" "fmt" "sync/atomic" @@ -9,6 +10,8 @@ import ( pipeline "github.com/ccremer/go-command-pipeline" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" ) func TestNewWorkerPoolStep(t *testing.T) { @@ -28,6 +31,7 @@ func TestNewWorkerPoolStep(t *testing.T) { } for name, tt := range tests { t.Run(name, func(t *testing.T) { + goleak.VerifyNone(t) counts = 0 if tt.expectPanic { assert.Panics(t, func() { @@ -35,36 +39,86 @@ func TestNewWorkerPoolStep(t *testing.T) { }) return } - step := NewWorkerPoolStep("pool", 1, func(pipelines chan *pipeline.Pipeline) { + step := NewWorkerPoolStep("pool", 1, func(ctx context.Context, pipelines chan *pipeline.Pipeline) { defer close(pipelines) - pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep("step", func(_ pipeline.Context) pipeline.Result { + pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep("step", func(_ context.Context) pipeline.Result { atomic.AddUint64(&counts, 1) return pipeline.Result{Err: tt.expectedError} })) - }, func(ctx pipeline.Context, results map[uint64]pipeline.Result) pipeline.Result { + }, func(ctx context.Context, results map[uint64]pipeline.Result) pipeline.Result { assert.Error(t, results[0].Err) return pipeline.Result{Err: results[0].Err} }) - result := step.F(nil) + result := step.F(context.Background()) assert.Error(t, result.Err) }) } } +func TestNewWorkerPoolStep_Cancel(t *testing.T) { + defer goleak.VerifyNone(t) + var counts uint64 + step := NewWorkerPoolStep("workerpool", 2, func(ctx context.Context, pipelines chan *pipeline.Pipeline) { + defer close(pipelines) + for i := 0; i < 10000; i++ { + select { + case <-ctx.Done(): + return + default: + pipelines <- pipeline.NewPipeline().WithSteps( + pipeline.NewStepFromFunc("noop", func(_ context.Context) error { return nil }), + pipeline.NewStepFromFunc("increase", func(_ context.Context) error { + atomic.AddUint64(&counts, 1) + time.Sleep(10 * time.Millisecond) + return nil + })) + } + } + t.Fail() // should not reach this + }, func(ctx context.Context, results map[uint64]pipeline.Result) pipeline.Result { + require.Len(t, results, 9) + for r := uint64(0); r < 6; r++ { + // The first 6 jobs are successful + assert.Equal(t, "increase", results[r].Name()) + assert.False(t, results[r].IsCanceled()) + assert.NoError(t, results[r].Err) + } + for r := uint64(6); r < 9; r++ { + // remaining jobs were cancelled + assert.Equal(t, "noop", results[r].Name()) + assert.True(t, results[r].IsCanceled()) + assert.EqualError(t, results[r].Err, `step "noop" failed: context deadline exceeded`) + } + return pipeline.Result{} + }) + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + result := pipeline.NewPipeline().WithSteps(step).RunWithContext(ctx) + assert.Equal(t, 6, int(counts), "successful increments") + assert.True(t, result.IsCanceled(), "overall canceled flag") + assert.False(t, result.IsSuccessful(), "overall success flag") + assert.EqualError(t, result.Err, `step "workerpool" failed: context deadline exceeded`) +} + func ExampleNewWorkerPoolStep() { p := pipeline.NewPipeline() - pool := NewWorkerPoolStep("pool", 2, func(pipelines chan *pipeline.Pipeline) { + pool := NewWorkerPoolStep("pool", 2, func(ctx context.Context, pipelines chan *pipeline.Pipeline) { defer close(pipelines) // create some pipelines for i := 0; i < 3; i++ { n := i - pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep(fmt.Sprintf("i = %d", n), func(_ pipeline.Context) pipeline.Result { - time.Sleep(time.Duration(n * 100000000)) // fake some load - fmt.Println(fmt.Sprintf("This is job item %d", n)) - return pipeline.Result{} - })) + select { + case <-ctx.Done(): + return // parent pipeline has been canceled, let's not create more pipelines. + default: + pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep(fmt.Sprintf("i = %d", n), func(_ context.Context) pipeline.Result { + time.Sleep(time.Duration(n * 100000000)) // fake some load + fmt.Println(fmt.Sprintf("This is job item %d", n)) + return pipeline.Result{} + })) + } } - }, func(ctx pipeline.Context, results map[uint64]pipeline.Result) pipeline.Result { + }, func(ctx context.Context, results map[uint64]pipeline.Result) pipeline.Result { for jobIndex, result := range results { if result.IsFailed() { fmt.Println(fmt.Sprintf("Job %d failed: %v", jobIndex, result.Err)) diff --git a/parallel/resulthandler.go b/parallel/resulthandler.go new file mode 100644 index 0000000..c589948 --- /dev/null +++ b/parallel/resulthandler.go @@ -0,0 +1,40 @@ +package parallel + +import ( + "context" + "fmt" + "sync" + + pipeline "github.com/ccremer/go-command-pipeline" +) + +// ResultHandler is a callback that provides a result map and expect a single, combined pipeline.Result object. +// The map key is a zero-based index of n-th pipeline.Pipeline spawned, e.g. pipeline number 3 will have index 2. +// Return an empty pipeline.Result if you want to ignore errors, or reduce multiple errors into a single one to make the parent pipeline fail. +type ResultHandler func(ctx context.Context, results map[uint64]pipeline.Result) pipeline.Result + +func collectResults(ctx context.Context, handler ResultHandler, m *sync.Map) pipeline.Result { + collectiveResult := pipeline.Result{} + if handler != nil { + // convert sync.Map to conventional map for easier access + resultMap := make(map[uint64]pipeline.Result) + m.Range(func(key, value interface{}) bool { + resultMap[key.(uint64)] = value.(pipeline.Result) + return true + }) + collectiveResult = handler(ctx, resultMap) + } + return collectiveResult +} + +func setResultErrorFromContext(ctx context.Context, result pipeline.Result) pipeline.Result { + if ctx.Err() != nil { + if result.Err != nil { + result.Err = fmt.Errorf("%w, collection error: %v", ctx.Err(), result.Err) + return pipeline.Canceled(result) + } + result.Err = ctx.Err() + return pipeline.Canceled(result) + } + return result +} diff --git a/parallel/supplier.go b/parallel/supplier.go new file mode 100644 index 0000000..128c762 --- /dev/null +++ b/parallel/supplier.go @@ -0,0 +1,20 @@ +/* +Package parallel extends the command-pipeline core with concurrency steps. +*/ +package parallel + +import ( + "context" + + pipeline "github.com/ccremer/go-command-pipeline" +) + +// PipelineSupplier is a function that spawns pipeline.Pipeline for consumption. +// Supply new pipelines by putting new pipeline.Pipeline instances into the given channel. +// The function must close the channel once all pipelines are spawned (`defer close()` recommended). +// +// The parent pipeline may get canceled, thus the given context is provided to stop putting more pipeline.Pipeline instances into the channel. +// Use +// select { case <-ctx.Done(): return, default: pipelinesChan <- ... } +// to cancel the supply, otherwise you may leak an orphaned goroutine. +type PipelineSupplier func(ctx context.Context, pipelinesChan chan *pipeline.Pipeline) diff --git a/parallel/types.go b/parallel/types.go deleted file mode 100644 index 3f792e7..0000000 --- a/parallel/types.go +++ /dev/null @@ -1,35 +0,0 @@ -/* -Package parallel extends the command-pipeline core with concurrency steps. -*/ -package parallel - -import ( - "sync" - - pipeline "github.com/ccremer/go-command-pipeline" -) - -type ( - // ResultHandler is a callback that provides a result map and expect a single, combined pipeline.Result object. - // The map key is a zero-based index of n-th pipeline.Pipeline spawned, e.g. pipeline number 3 will have index 2. - // Context may be nil. - // Return an empty pipeline.Result if you want to ignore errors, or reduce multiple errors into a single one to make the parent pipeline fail. - ResultHandler func(ctx pipeline.Context, results map[uint64]pipeline.Result) pipeline.Result - // PipelineSupplier is a function that spawns pipeline.Pipeline for consumption. - // The function must close the channel once all pipelines are spawned (`defer close()` recommended). - PipelineSupplier func(chan *pipeline.Pipeline) -) - -func collectResults(ctx pipeline.Context, handler ResultHandler, m *sync.Map) pipeline.Result { - collectiveResult := pipeline.Result{} - if handler != nil { - // convert sync.Map to conventional map for easier access - resultMap := make(map[uint64]pipeline.Result) - m.Range(func(key, value interface{}) bool { - resultMap[key.(uint64)] = value.(pipeline.Result) - return true - }) - collectiveResult = handler(ctx, resultMap) - } - return collectiveResult -} diff --git a/pipeline.go b/pipeline.go index 564d41f..cf072b5 100644 --- a/pipeline.go +++ b/pipeline.go @@ -1,65 +1,47 @@ package pipeline import ( + "context" "errors" "fmt" ) -type ( - // Pipeline holds and runs intermediate actions, called "steps". - Pipeline struct { - steps []Step - context Context - beforeHooks []Listener - finalizer ResultHandler - options options - } - // Result is the object that is returned after each step and after running a pipeline. - Result struct { - // Err contains the step's returned error, nil otherwise. - // In an aborted pipeline with ErrAbort it will still be nil. - Err error - // Name is an optional identifier for a result. - // ActionFunc may set this property before returning to help a ResultHandler with further processing. - Name string - - aborted bool - } - // Step is an intermediary action and part of a Pipeline. - Step struct { - // Name describes the step's human-readable name. - // It has no other uses other than easily identifying a step for debugging or logging. - Name string - // F is the ActionFunc assigned to a pipeline Step. - // This is required. - F ActionFunc - // H is the ResultHandler assigned to a pipeline Step. - // This is optional, and it will be called in any case if it is set after F completed. - // Use cases could be logging, updating a GUI or handle errors while continuing the pipeline. - // The function may return nil even if the Result contains an error, in which case the pipeline will continue. - // This function is called before the next step's F is invoked. - H ResultHandler - } - // Context contains arbitrary data relevant for the pipeline execution. - Context interface{} - // Listener is a simple func that listens to Pipeline events. - Listener func(step Step) - // ActionFunc is the func that contains your business logic. - // The context is a user-defined arbitrary data of type interface{} that gets provided in every Step, but may be nil if not set. - ActionFunc func(ctx Context) Result - // ResultHandler is a func that gets called when a step's ActionFunc has finished with any Result. - // Context may be nil. - ResultHandler func(ctx Context, result Result) error -) +// Pipeline holds and runs intermediate actions, called "steps". +type Pipeline struct { + steps []Step + beforeHooks []Listener + finalizer ResultHandler + options options +} -// NewPipeline returns a new quiet Pipeline instance with KeyValueContext. -func NewPipeline() *Pipeline { - return &Pipeline{} +// Step is an intermediary action and part of a Pipeline. +type Step struct { + // Name describes the step's human-readable name. + // It has no other uses other than easily identifying a step for debugging or logging. + Name string + // F is the ActionFunc assigned to a pipeline Step. + // This is required. + F ActionFunc + // H is the ResultHandler assigned to a pipeline Step. + // This is optional, and it will be called in any case if it is set after F completed. + // Use cases could be logging, updating a GUI or handle errors while continuing the pipeline. + // The function may return nil even if the Result contains an error, in which case the pipeline will continue. + // This function is called before the next step's F is invoked. + H ResultHandler } -// NewPipelineWithContext returns a new Pipeline instance with the given context. -func NewPipelineWithContext(ctx Context) *Pipeline { - return &Pipeline{context: ctx} +// Listener is a simple func that listens to Pipeline events. +type Listener func(step Step) + +// ActionFunc is the func that contains your business logic. +type ActionFunc func(ctx context.Context) Result + +// ResultHandler is a func that gets called when a step's ActionFunc has finished with any Result. +type ResultHandler func(ctx context.Context, result Result) error + +// NewPipeline returns a new Pipeline instance. +func NewPipeline() *Pipeline { + return &Pipeline{} } // WithBeforeHooks takes a list of listeners. @@ -82,7 +64,7 @@ func (p *Pipeline) AddStep(step Step) *Pipeline { return p } -// WithSteps appends the given arrway of steps to the Pipeline at the end and returns itself. +// WithSteps appends the given array of steps to the Pipeline at the end and returns itself. func (p *Pipeline) WithSteps(steps ...Step) *Pipeline { p.steps = steps return p @@ -90,70 +72,89 @@ func (p *Pipeline) WithSteps(steps ...Step) *Pipeline { // WithNestedSteps is similar to AsNestedStep, but it accepts the steps given directly as parameters. func (p *Pipeline) WithNestedSteps(name string, steps ...Step) Step { - return NewStep(name, func(_ Context) Result { - nested := &Pipeline{beforeHooks: p.beforeHooks, steps: steps, context: p.context, options: p.options} - return nested.Run() + return NewStep(name, func(ctx context.Context) Result { + nested := &Pipeline{beforeHooks: p.beforeHooks, steps: steps, options: p.options} + return nested.RunWithContext(ctx) }) } // AsNestedStep converts the Pipeline instance into a Step that can be used in other pipelines. // The properties are passed to the nested pipeline. func (p *Pipeline) AsNestedStep(name string) Step { - return NewStep(name, func(_ Context) Result { - nested := &Pipeline{beforeHooks: p.beforeHooks, steps: p.steps, context: p.context, options: p.options} - return nested.Run() + return NewStep(name, func(ctx context.Context) Result { + nested := &Pipeline{beforeHooks: p.beforeHooks, steps: p.steps, options: p.options} + return nested.RunWithContext(ctx) }) } -// WithContext returns itself while setting the context for the pipeline steps. -func (p *Pipeline) WithContext(ctx Context) *Pipeline { - p.context = ctx - return p -} - // WithFinalizer returns itself while setting the finalizer for the pipeline. // The finalizer is a handler that gets called after the last step is in the pipeline is completed. -// If a pipeline aborts early then it is also called. +// If a pipeline aborts early or gets canceled then it is also called. func (p *Pipeline) WithFinalizer(handler ResultHandler) *Pipeline { p.finalizer = handler return p } -// Run executes the pipeline and returns the result. +// Run executes the pipeline with context.Background and returns the result. // Steps are executed sequentially as they were added to the Pipeline. // If a Step returns a Result with a non-nil error, the Pipeline is aborted and its Result contains the affected step's error. // However, if Result.Err is wrapped in ErrAbort, then the pipeline is aborted, but the final Result.Err will be nil. func (p *Pipeline) Run() Result { - result := p.doRun() + return p.RunWithContext(context.Background()) +} + +// RunWithContext is like Run but with a given context.Context. +// Upon cancellation of the context, the pipeline does not terminate a currently running step, instead it skips the remaining steps in the execution order. +// The context is passed to each Step.F and each Step may need to listen to the context cancellation event to truly cancel a long-running step. +// If the pipeline gets canceled, Result.IsCanceled returns true and Result.Err contains the context's error. +func (p *Pipeline) RunWithContext(ctx context.Context) Result { + result := p.doRun(ctx) if p.finalizer != nil { - result.Err = p.finalizer(p.context, result) + result.Err = p.finalizer(ctx, result) } return result } -func (p *Pipeline) doRun() Result { +func (p *Pipeline) doRun(ctx context.Context) Result { + name := "" for _, step := range p.steps { - for _, hooks := range p.beforeHooks { - hooks(step) - } + name = step.Name + select { + case <-ctx.Done(): + result := p.fail(ctx.Err(), step) + return result + default: + for _, hooks := range p.beforeHooks { + hooks(step) + } - result := step.F(p.context) - var err error - if step.H != nil { - err = step.H(p.context, result) - } else { - err = result.Err - } - if err != nil { - if errors.Is(err, ErrAbort) { - // Abort pipeline without error - return Result{aborted: true} + result := step.F(ctx) + var err error + if step.H != nil { + result.name = step.Name + err = step.H(ctx, result) + } else { + err = result.Err } - if p.options.disableErrorWrapping { - return Result{Err: err} + if err != nil { + if errors.Is(err, ErrAbort) { + // Abort pipeline without error + return Result{aborted: true, name: step.Name} + } + return p.fail(err, step) } - return Result{Err: fmt.Errorf("step '%s' failed: %w", step.Name, err)} } } - return Result{} + return Result{name: name} +} + +func (p *Pipeline) fail(err error, step Step) Result { + result := Result{name: step.Name} + if p.options.disableErrorWrapping { + result.Err = err + } else { + result.Err = fmt.Errorf("step %q failed: %w", step.Name, err) + } + result.canceled = errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) + return result } diff --git a/pipeline_test.go b/pipeline_test.go index 0c1db81..398c2c5 100644 --- a/pipeline_test.go +++ b/pipeline_test.go @@ -1,8 +1,11 @@ package pipeline import ( + "context" "errors" + "fmt" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -29,7 +32,7 @@ func TestPipeline_Run(t *testing.T) { }{ "GivenSingleStep_WhenRunning_ThenCallStep": { givenSteps: []Step{ - NewStep("test-step", func(_ Context) Result { + NewStep("test-step", func(_ context.Context) Result { callCount += 1 return Result{} }), @@ -38,7 +41,7 @@ func TestPipeline_Run(t *testing.T) { }, "GivenSingleStep_WhenBeforeHookGiven_ThenCallBeforeHook": { givenSteps: []Step{ - NewStep("test-step", func(_ Context) Result { + NewStep("test-step", func(_ context.Context) Result { callCount += hook.calls + 1 return Result{} }), @@ -47,7 +50,7 @@ func TestPipeline_Run(t *testing.T) { expectedCalls: 2, }, "GivenPipelineWithFinalizer_WhenRunning_ThenCallHandler": { - givenFinalizer: func(_ Context, result Result) error { + givenFinalizer: func(_ context.Context, result Result) error { callCount += 1 return nil }, @@ -55,7 +58,7 @@ func TestPipeline_Run(t *testing.T) { }, "GivenSingleStepWithoutHandler_WhenRunningWithError_ThenReturnError": { givenSteps: []Step{ - NewStep("test-step", func(_ Context) Result { + NewStep("test-step", func(_ context.Context) Result { callCount += 1 return Result{Err: errors.New("step failed")} }), @@ -65,11 +68,11 @@ func TestPipeline_Run(t *testing.T) { }, "GivenStepWithErrAbort_WhenRunningWithErrAbort_ThenDoNotExecuteNextSteps": { givenSteps: []Step{ - NewStepFromFunc("test-step", func(_ Context) error { + NewStepFromFunc("test-step", func(_ context.Context) error { callCount += 1 return ErrAbort }), - NewStepFromFunc("step-should-not-execute", func(_ Context) error { + NewStepFromFunc("step-should-not-execute", func(_ context.Context) error { callCount += 1 return errors.New("should not execute") }), @@ -81,14 +84,14 @@ func TestPipeline_Run(t *testing.T) { }, "GivenSingleStepWithHandler_WhenRunningWithError_ThenAbortWithError": { givenSteps: []Step{ - NewStep("test-step", func(_ Context) Result { + NewStep("test-step", func(_ context.Context) Result { callCount += 1 return Result{} - }).WithResultHandler(func(_ Context, result Result) error { + }).WithResultHandler(func(_ context.Context, result Result) error { callCount += 1 return errors.New("handler") }), - NewStep("don't run this step", func(_ Context) Result { + NewStep("don't run this step", func(_ context.Context) Result { callCount += 1 return Result{} }), @@ -98,14 +101,14 @@ func TestPipeline_Run(t *testing.T) { }, "GivenSingleStepWithHandler_WhenNullifyingError_ThenContinuePipeline": { givenSteps: []Step{ - NewStep("test-step", func(_ Context) Result { + NewStep("test-step", func(_ context.Context) Result { callCount += 1 return Result{Err: errors.New("failed step")} - }).WithResultHandler(func(_ Context, result Result) error { + }).WithResultHandler(func(_ context.Context, result Result) error { callCount += 1 return nil }), - NewStep("continue", func(_ Context) Result { + NewStep("continue", func(_ context.Context) Result { callCount += 1 return Result{} }), @@ -114,12 +117,12 @@ func TestPipeline_Run(t *testing.T) { }, "GivenNestedPipeline_WhenParentPipelineRuns_ThenRunNestedAsWell": { givenSteps: []Step{ - NewStep("test-step", func(_ Context) Result { + NewStep("test-step", func(_ context.Context) Result { callCount += 1 return Result{} }), NewPipeline(). - AddStep(NewStep("nested-step", func(_ Context) Result { + AddStep(NewStep("nested-step", func(_ context.Context) Result { callCount += 1 return Result{} })).AsNestedStep("nested-pipeline"), @@ -130,7 +133,7 @@ func TestPipeline_Run(t *testing.T) { givenSteps: []Step{ NewPipeline(). WithNestedSteps("nested-pipeline", - NewStep("nested-step", func(_ Context) Result { + NewStep("nested-step", func(_ context.Context) Result { callCount += 1 return Result{} })), @@ -164,31 +167,59 @@ func TestPipeline_Run(t *testing.T) { } } -func TestPipeline_RunWithContext(t *testing.T) { - t.Run("custom type", func(t *testing.T) { - context := "some type" - p := NewPipelineWithContext(context) - p.AddStep(NewStep("context", func(ctx Context) Result { - assert.Equal(t, context, ctx) - return Result{} - })) - result := p.Run() - require.NoError(t, result.Err) - }) - t.Run("nil context", func(t *testing.T) { - p := NewPipeline().WithContext(nil) - p.AddStep(NewStep("context", func(ctx Context) Result { - assert.Nil(t, ctx) - return Result{} - })) - result := p.Run() - require.NoError(t, result.Err) - }) +func TestPipeline_RunWithContext_CancelLongRunningStep(t *testing.T) { + p := NewPipeline().WithSteps( + NewStepFromFunc("long running", func(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // doing nothing + } + } + }), + ) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + go func() { + time.Sleep(5 * time.Millisecond) + cancel() + }() + result := p.RunWithContext(ctx) + assert.True(t, result.IsCanceled(), "IsCanceled()") + assert.Equal(t, "long running", result.Name()) + assert.EqualError(t, result.Err, "step \"long running\" failed: context canceled") +} + +func ExamplePipeline_RunWithContext() { + // prepare pipeline + p := NewPipeline().WithSteps( + NewStepFromFunc("short step", func(ctx context.Context) error { + fmt.Println("short step") + return nil + }), + NewStepFromFunc("long running step", func(ctx context.Context) error { + time.Sleep(100 * time.Millisecond) + return nil + }), + NewStepFromFunc("canceled step", func(ctx context.Context) error { + return errors.New("shouldn't execute") + }), + ) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + result := p.RunWithContext(ctx) + // inspect the result + fmt.Println(result.IsCanceled()) + fmt.Println(result.Err) + // Output: short step + // true + // step "canceled step" failed: context deadline exceeded } func TestNewStepFromFunc(t *testing.T) { called := false - step := NewStepFromFunc("name", func(ctx Context) error { + step := NewStepFromFunc("name", func(ctx context.Context) error { called = true return nil }) diff --git a/predicate/predicate.go b/predicate/predicate.go index 8a9dccd..c6dd843 100644 --- a/predicate/predicate.go +++ b/predicate/predicate.go @@ -1,22 +1,24 @@ package predicate import ( + "context" + pipeline "github.com/ccremer/go-command-pipeline" ) type ( // Predicate is a function that expects 'true' if a pipeline.ActionFunc should run. // It is evaluated lazily resp. only when needed. - Predicate func(ctx pipeline.Context) bool + Predicate func(ctx context.Context) bool ) // ToStep wraps the given action func in its own step. // When the step's function is called, the given Predicate will evaluate whether the action should actually run. // It returns the action's pipeline.Result, otherwise an empty (successful) pipeline.Result struct. -// The pipeline.Context from the pipeline is passed through the given action. +// The context.Context from the pipeline is passed through the given action. func ToStep(name string, action pipeline.ActionFunc, predicate Predicate) pipeline.Step { step := pipeline.Step{Name: name} - step.F = func(ctx pipeline.Context) pipeline.Result { + step.F = func(ctx context.Context) pipeline.Result { if predicate(ctx) { return action(ctx) } @@ -28,10 +30,10 @@ func ToStep(name string, action pipeline.ActionFunc, predicate Predicate) pipeli // ToNestedStep wraps the given pipeline in its own step. // When the step's function is called, the given Predicate will evaluate whether the nested pipeline.Pipeline should actually run. // It returns the pipeline's pipeline.Result, otherwise an empty (successful) pipeline.Result struct. -// The given pipeline has to define its own pipeline.Context, it's not passed "down". +// The given pipeline has to define its own context.Context, it's not passed "down". func ToNestedStep(name string, predicate Predicate, p *pipeline.Pipeline) pipeline.Step { step := pipeline.Step{Name: name} - step.F = func(ctx pipeline.Context) pipeline.Result { + step.F = func(ctx context.Context) pipeline.Result { if predicate(ctx) { return p.Run() } @@ -41,10 +43,10 @@ func ToNestedStep(name string, predicate Predicate, p *pipeline.Pipeline) pipeli } // If returns a new step that wraps the given step and executes its action only if the given Predicate evaluates true. -// The pipeline.Context from the pipeline is passed through the given action. +// The context.Context from the pipeline is passed through the given action. func If(predicate Predicate, originalStep pipeline.Step) pipeline.Step { wrappedStep := pipeline.Step{Name: originalStep.Name} - wrappedStep.F = func(ctx pipeline.Context) pipeline.Result { + wrappedStep.F = func(ctx context.Context) pipeline.Result { if predicate(ctx) { return originalStep.F(ctx) } @@ -56,7 +58,7 @@ func If(predicate Predicate, originalStep pipeline.Step) pipeline.Step { // Bool returns a Predicate that simply returns v when evaluated. // Use BoolPtr() over Bool() if the value can change between setting up the pipeline and evaluating the predicate. func Bool(v bool) Predicate { - return func(_ pipeline.Context) bool { + return func(_ context.Context) bool { return v } } @@ -64,14 +66,14 @@ func Bool(v bool) Predicate { // BoolPtr returns a Predicate that returns *v when evaluated. // Use BoolPtr() over Bool() if the value can change between setting up the pipeline and evaluating the predicate. func BoolPtr(v *bool) Predicate { - return func(_ pipeline.Context) bool { + return func(_ context.Context) bool { return *v } } // Not returns a Predicate that evaluates, but then negates the given Predicate. func Not(predicate Predicate) Predicate { - return func(ctx pipeline.Context) bool { + return func(ctx context.Context) bool { return !predicate(ctx) } } @@ -79,7 +81,7 @@ func Not(predicate Predicate) Predicate { // And returns a Predicate that does logical AND of the given predicates. // p2 is not evaluated if p1 evaluates already to false. func And(p1, p2 Predicate) Predicate { - return func(ctx pipeline.Context) bool { + return func(ctx context.Context) bool { return p1(ctx) && p2(ctx) } } @@ -87,7 +89,7 @@ func And(p1, p2 Predicate) Predicate { // Or returns a Predicate that does logical OR of the given predicates. // p2 is not evaluated if p1 evaluates already to true. func Or(p1, p2 Predicate) Predicate { - return func(ctx pipeline.Context) bool { + return func(ctx context.Context) bool { return p1(ctx) || p2(ctx) } } diff --git a/predicate/predicate_test.go b/predicate/predicate_test.go index 8e1e7a7..46939c6 100644 --- a/predicate/predicate_test.go +++ b/predicate/predicate_test.go @@ -1,6 +1,7 @@ package predicate import ( + "context" "testing" pipeline "github.com/ccremer/go-command-pipeline" @@ -58,7 +59,7 @@ func Test_Predicates(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { counter = 0 - step := ToStep("name", func(_ pipeline.Context) pipeline.Result { + step := ToStep("name", func(_ context.Context) pipeline.Result { counter += 1 return pipeline.Result{} }, tt.givenPredicate) @@ -88,7 +89,7 @@ func TestToNestedStep(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { counter = 0 - p := pipeline.NewPipeline().AddStep(pipeline.NewStep("nested step", func(_ pipeline.Context) pipeline.Result { + p := pipeline.NewPipeline().AddStep(pipeline.NewStep("nested step", func(_ context.Context) pipeline.Result { counter++ return pipeline.Result{} })) @@ -118,7 +119,7 @@ func TestIf(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { counter = 0 - step := pipeline.NewStep("step", func(_ pipeline.Context) pipeline.Result { + step := pipeline.NewStep("step", func(_ context.Context) pipeline.Result { counter++ return pipeline.Result{} }) @@ -135,7 +136,7 @@ func TestBoolPtr(t *testing.T) { called := false b := false p := pipeline.NewPipeline().WithSteps( - If(BoolPtr(&b), pipeline.NewStepFromFunc("boolptr", func(_ pipeline.Context) error { + If(BoolPtr(&b), pipeline.NewStepFromFunc("boolptr", func(_ context.Context) error { called = true return nil })), @@ -146,14 +147,14 @@ func TestBoolPtr(t *testing.T) { } func truePredicate(counter *int) Predicate { - return func(_ pipeline.Context) bool { + return func(_ context.Context) bool { *counter++ return true } } func falsePredicate(counter *int) Predicate { - return func(_ pipeline.Context) bool { + return func(_ context.Context) bool { *counter-- return false } diff --git a/result.go b/result.go index 2c20d26..0e05660 100644 --- a/result.go +++ b/result.go @@ -2,9 +2,25 @@ package pipeline import "errors" -// ErrAbort indicates that the pipeline should be terminated immediately without returning an error. +// ErrAbort indicates that the pipeline should be terminated immediately without being marked as failed (returning an error). var ErrAbort = errors.New("abort") +// Result is the object that is returned after each step and after running a pipeline. +type Result struct { + // Err contains the step's returned error, nil otherwise. + // In an aborted pipeline with ErrAbort it will still be nil. + Err error + + name string + aborted bool + canceled bool +} + +// Name retrieves the name of the (last) step that has been executed. +func (r Result) Name() string { + return r.name +} + // IsSuccessful returns true if the contained error is nil. // Aborted pipelines (with ErrAbort) are still reported as success. // To query if a pipeline is aborted early, use IsAborted. @@ -21,3 +37,14 @@ func (r Result) IsFailed() bool { func (r Result) IsAborted() bool { return r.aborted } + +// IsCanceled returns true if the pipeline's context has been canceled. +func (r Result) IsCanceled() bool { + return r.canceled +} + +// Canceled sets Result.IsCanceled to true. +func Canceled(result Result) Result { + result.canceled = true + return result +} diff --git a/step.go b/step.go index 1ebff06..b2dcddb 100644 --- a/step.go +++ b/step.go @@ -1,5 +1,7 @@ package pipeline +import "context" + // NewStep returns a new Step with given name and action. func NewStep(name string, action ActionFunc) Step { return Step{ @@ -9,10 +11,10 @@ func NewStep(name string, action ActionFunc) Step { } // NewStepFromFunc returns a new Step with given name using a function that expects an error. -func NewStepFromFunc(name string, fn func(ctx Context) error) Step { - return NewStep(name, func(ctx Context) Result { +func NewStepFromFunc(name string, fn func(ctx context.Context) error) Step { + return NewStep(name, func(ctx context.Context) Result { err := fn(ctx) - return Result{Err: err, Name: name} + return Result{Err: err, name: name} }) } @@ -24,8 +26,8 @@ func (s Step) WithResultHandler(handler ResultHandler) Step { // WithErrorHandler wraps given errorHandler and sets the ResultHandler of this specific step and returns the step itself. // The difference to WithResultHandler is that errorHandler only gets called if Result.Err is non-nil. -func (s Step) WithErrorHandler(errorHandler func(ctx Context, err error) error) Step { - s.H = func(ctx Context, result Result) error { +func (s Step) WithErrorHandler(errorHandler func(ctx context.Context, err error) error) Step { + s.H = func(ctx context.Context, result Result) error { if result.IsFailed() { return errorHandler(ctx, result.Err) } diff --git a/step_test.go b/step_test.go index da0d44d..f62b78c 100644 --- a/step_test.go +++ b/step_test.go @@ -1,6 +1,7 @@ package pipeline import ( + "context" "errors" "testing" @@ -24,9 +25,9 @@ func TestStep_WithErrorHandler(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { executed := false - s := NewStepFromFunc("test", func(_ Context) error { + s := NewStepFromFunc("test", func(_ context.Context) error { return nil - }).WithErrorHandler(func(_ Context, err error) error { + }).WithErrorHandler(func(_ context.Context, err error) error { executed = true return err })