diff --git a/base.go b/base.go index 9b01114..d140eff 100644 --- a/base.go +++ b/base.go @@ -73,8 +73,14 @@ func (b *Base) SetRetry(retry int, maxRetry int) { } func (b *Base) Cancel() { - b.SetState(StateCanceling) - b.cancel() + switch b.State { + case StateSucceeded, StateCanceled, StateFailed: + return + } + if !isCanceled(b.ctx) { + b.SetState(StateCanceling) + b.cancel() + } } func (b *Base) Ctx() context.Context { diff --git a/manager.go b/manager.go index c53f716..16eb968 100644 --- a/manager.go +++ b/manager.go @@ -9,8 +9,8 @@ import ( "runtime" "sync/atomic" - nanoid "github.com/matoous/go-nanoid/v2" "github.com/OpenListTeam/gsync" + nanoid "github.com/matoous/go-nanoid/v2" ) // Manager is the manager of all tasks @@ -73,16 +73,15 @@ func (m *Manager[T]) Add(task T) { if _, maxRetry := task.GetRetry(); maxRetry == 0 { task.SetRetry(0, m.opts.MaxRetry) } - if sliceContains([]State{StateRunning}, task.GetState()) { + switch task.GetState() { + case StateRunning: task.SetState(StatePending) - } - if sliceContains([]State{StateCanceling}, task.GetState()) { + case StateFailing: + task.SetState(StateFailed) + case StateCanceling: task.SetState(StateCanceled) task.SetErr(context.Canceled) } - if task.GetState() == StateFailing { - task.SetState(StateFailed) - } m.tasks.Store(task.GetID(), task) if !sliceContains([]State{StateSucceeded, StateCanceled, StateErrored, StateFailed}, task.GetState()) { m.queue.Push(task) @@ -116,7 +115,7 @@ func (m *Manager[T]) next() { m.workers.Put(worker) m.next() }() - if task.GetState() == StateCanceling { + if isCanceled(task.Ctx()) { task.SetState(StateCanceled) task.SetErr(context.Canceled) return diff --git a/manager_test.go b/manager_test.go index d66fb96..a5cca7b 100644 --- a/manager_test.go +++ b/manager_test.go @@ -1,12 +1,15 @@ package tache_test import ( - "github.com/OpenListTeam/tache" + "context" + "errors" "log/slog" "os" "sync/atomic" "testing" "time" + + "github.com/OpenListTeam/tache" ) type TestTask struct { @@ -26,6 +29,48 @@ func TestManager_Add(t *testing.T) { t.Logf("%+v", task) } +func TestCancelRunningTask(t *testing.T) { + tm := tache.NewManager[*TestTask](tache.WithWorks(1)) + start := make(chan struct{}) + task := &TestTask{ + do: func(task *TestTask) error { + close(start) + <-task.Ctx().Done() + return task.Ctx().Err() + }, + } + tm.Add(task) + <-start + task.Cancel() + tm.Wait() + if task.GetState() != tache.StateCanceled { + t.Fatalf("task should be canceled, got %v", task.GetState()) + } + if !errors.Is(task.GetErr(), context.Canceled) { + t.Fatalf("task error should be context.Canceled, got %v", task.GetErr()) + } +} + +func TestCancelFinishedTask(t *testing.T) { + tm := tache.NewManager[*TestTask]() + task := &TestTask{ + do: func(task *TestTask) error { + return nil + }, + } + tm.Add(task) + tm.Wait() + + if task.GetState() != tache.StateSucceeded { + t.Fatalf("task should be succeeded before cancel, got %v", task.GetState()) + } + + task.Cancel() + if task.GetState() != tache.StateSucceeded { + t.Fatalf("cancel should not change finished task state, got %v", task.GetState()) + } +} + func TestWithRetry(t *testing.T) { tm := tache.NewManager[*TestTask](tache.WithMaxRetry(3), tache.WithWorks(1)) var num atomic.Int64 diff --git a/state.go b/state.go index 3b0e058..a7c6c7c 100644 --- a/state.go +++ b/state.go @@ -5,7 +5,7 @@ type State int const ( // StatePending is the state of a task when it is pending - StatePending = iota + StatePending State = iota // StateRunning is the state of a task when it is running StateRunning // StateSucceeded is the state of a task when it succeeded diff --git a/worker.go b/worker.go index 9da9682..f901930 100644 --- a/worker.go +++ b/worker.go @@ -3,7 +3,6 @@ package tache import ( "container/list" "context" - "errors" "fmt" "log" "sync" @@ -14,11 +13,22 @@ type Worker[T Task] struct { ID int } +func isCanceled(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} + // Execute executes the task func (w Worker[T]) Execute(task T) { onError := func(err error) { task.SetErr(err) - if errors.Is(err, context.Canceled) { + finalState := StateFailed + if isCanceled(task.Ctx()) { + finalState = StateCanceled task.SetState(StateCanceled) } else { task.SetState(StateErrored) @@ -28,7 +38,7 @@ func (w Worker[T]) Execute(task T) { task.SetState(StateFailing) hook.OnFailed() } - task.SetState(StateFailed) + task.SetState(finalState) } } // Retry immediately in the same worker until success or max retry exhausted