diff --git a/taskqueue/taskqueue.go b/taskqueue/taskqueue.go index 67810ca3..86ba5667 100644 --- a/taskqueue/taskqueue.go +++ b/taskqueue/taskqueue.go @@ -2,6 +2,7 @@ package taskqueue import ( "context" + "sync/atomic" "time" "github.com/ipfs/go-peertaskqueue" @@ -32,7 +33,9 @@ type WorkerTaskQueue struct { cancelFn func() peerTaskQueue *peertaskqueue.PeerTaskQueue workSignal chan struct{} + noTaskSignal chan struct{} ticker *time.Ticker + activeTasks int32 } // NewTaskQueue initializes a new queue @@ -43,6 +46,7 @@ func NewTaskQueue(ctx context.Context) *WorkerTaskQueue { cancelFn: cancelFn, peerTaskQueue: peertaskqueue.New(), workSignal: make(chan struct{}, 1), + noTaskSignal: make(chan struct{}, 1), ticker: time.NewTicker(thawSpeed), } } @@ -88,6 +92,16 @@ func (tq *WorkerTaskQueue) Shutdown() { tq.cancelFn() } +func (tq *WorkerTaskQueue) WaitForNoActiveTasks() { + for atomic.LoadInt32(&tq.activeTasks) > 0 { + select { + case <-tq.ctx.Done(): + return + case <-tq.noTaskSignal: + } + } +} + func (tq *WorkerTaskQueue) worker(executor Executor) { targetWork := 1 for { @@ -104,7 +118,14 @@ func (tq *WorkerTaskQueue) worker(executor Executor) { } } for _, task := range tasks { + atomic.AddInt32(&tq.activeTasks, 1) terminate := executor.ExecuteTask(tq.ctx, pid, task) + if atomic.AddInt32(&tq.activeTasks, -1) == 0 { + select { + case tq.noTaskSignal <- struct{}{}: + default: + } + } if terminate { return }