Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions base.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,14 @@ func (b *Base) SetRetry(retry int, maxRetry int) {
}

func (b *Base) Cancel() {
b.SetState(StateCanceling)
b.cancel()
switch b.State {
case StateSucceeded, StateCanceled, StateFailed:
return
}
if !isCanceled(b.ctx) {
b.SetState(StateCanceling)
b.cancel()
}
}

func (b *Base) Ctx() context.Context {
Expand Down
15 changes: 7 additions & 8 deletions manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"runtime"
"sync/atomic"

nanoid "github.com/matoous/go-nanoid/v2"
"github.com/OpenListTeam/gsync"
nanoid "github.com/matoous/go-nanoid/v2"
)

// Manager is the manager of all tasks
Expand Down Expand Up @@ -73,16 +73,15 @@ func (m *Manager[T]) Add(task T) {
if _, maxRetry := task.GetRetry(); maxRetry == 0 {
task.SetRetry(0, m.opts.MaxRetry)
}
if sliceContains([]State{StateRunning}, task.GetState()) {
switch task.GetState() {
case StateRunning:
task.SetState(StatePending)
}
if sliceContains([]State{StateCanceling}, task.GetState()) {
case StateFailing:
task.SetState(StateFailed)
case StateCanceling:
task.SetState(StateCanceled)
task.SetErr(context.Canceled)
}
if task.GetState() == StateFailing {
task.SetState(StateFailed)
}
m.tasks.Store(task.GetID(), task)
if !sliceContains([]State{StateSucceeded, StateCanceled, StateErrored, StateFailed}, task.GetState()) {
m.queue.Push(task)
Expand Down Expand Up @@ -116,7 +115,7 @@ func (m *Manager[T]) next() {
m.workers.Put(worker)
m.next()
}()
if task.GetState() == StateCanceling {
if isCanceled(task.Ctx()) {
task.SetState(StateCanceled)
task.SetErr(context.Canceled)
return
Expand Down
47 changes: 46 additions & 1 deletion manager_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package tache_test

import (
"github.com/OpenListTeam/tache"
"context"
"errors"
"log/slog"
"os"
"sync/atomic"
"testing"
"time"

"github.com/OpenListTeam/tache"
)

type TestTask struct {
Expand All @@ -26,6 +29,48 @@ func TestManager_Add(t *testing.T) {
t.Logf("%+v", task)
}

func TestCancelRunningTask(t *testing.T) {
tm := tache.NewManager[*TestTask](tache.WithWorks(1))
start := make(chan struct{})
task := &TestTask{
do: func(task *TestTask) error {
close(start)
<-task.Ctx().Done()
return task.Ctx().Err()
},
}
tm.Add(task)
<-start
task.Cancel()
tm.Wait()
if task.GetState() != tache.StateCanceled {
t.Fatalf("task should be canceled, got %v", task.GetState())
}
if !errors.Is(task.GetErr(), context.Canceled) {
t.Fatalf("task error should be context.Canceled, got %v", task.GetErr())
}
}

func TestCancelFinishedTask(t *testing.T) {
tm := tache.NewManager[*TestTask]()
task := &TestTask{
do: func(task *TestTask) error {
return nil
},
}
tm.Add(task)
tm.Wait()

if task.GetState() != tache.StateSucceeded {
t.Fatalf("task should be succeeded before cancel, got %v", task.GetState())
}

task.Cancel()
if task.GetState() != tache.StateSucceeded {
t.Fatalf("cancel should not change finished task state, got %v", task.GetState())
}
}

func TestWithRetry(t *testing.T) {
tm := tache.NewManager[*TestTask](tache.WithMaxRetry(3), tache.WithWorks(1))
var num atomic.Int64
Expand Down
2 changes: 1 addition & 1 deletion state.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ type State int

const (
// StatePending is the state of a task when it is pending
StatePending = iota
StatePending State = iota
// StateRunning is the state of a task when it is running
StateRunning
// StateSucceeded is the state of a task when it succeeded
Expand Down
16 changes: 13 additions & 3 deletions worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package tache
import (
"container/list"
"context"
"errors"
"fmt"
"log"
"sync"
Expand All @@ -14,11 +13,22 @@ type Worker[T Task] struct {
ID int
}

func isCanceled(ctx context.Context) bool {
select {
case <-ctx.Done():
return true
default:
return false
}
}

// Execute executes the task
func (w Worker[T]) Execute(task T) {
onError := func(err error) {
task.SetErr(err)
if errors.Is(err, context.Canceled) {
finalState := StateFailed
if isCanceled(task.Ctx()) {
finalState = StateCanceled
task.SetState(StateCanceled)
} else {
task.SetState(StateErrored)
Expand All @@ -28,7 +38,7 @@ func (w Worker[T]) Execute(task T) {
task.SetState(StateFailing)
hook.OnFailed()
}
task.SetState(StateFailed)
task.SetState(finalState)
}
}
// Retry immediately in the same worker until success or max retry exhausted
Expand Down