diff --git a/consumer.go b/consumer.go index a4d404f..c61b28b 100644 --- a/consumer.go +++ b/consumer.go @@ -48,8 +48,8 @@ func (s *Consumer) handle(m *job.Message) error { // run custom process function var err error - shouldRetry := true - for shouldRetry { + loop: + for { if m.Task != nil { err = m.Task(ctx) } else { @@ -62,7 +62,12 @@ func (s *Consumer) handle(m *job.Message) error { } m.RetryCount-- - <-time.After(m.RetryDelay) + select { + case <-time.After(m.RetryDelay): // retry delay time + case <-ctx.Done(): // timeout reached + err = ctx.Err() + break loop + } } done <- err diff --git a/consumer_test.go b/consumer_test.go index e8c37b4..78fe3a6 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -458,6 +458,7 @@ func TestRetryCountWithNewMessage(t *testing.T) { m.EXPECT().Bytes().Return([]byte("test")).AnyTimes() messages := make(chan string, 10) + keep := make(chan struct{}) count := 1 w := NewConsumer( @@ -466,6 +467,7 @@ func TestRetryCountWithNewMessage(t *testing.T) { count++ return errors.New("count not correct") } + close(keep) messages <- string(m.Bytes()) return nil }), @@ -485,6 +487,8 @@ func TestRetryCountWithNewMessage(t *testing.T) { )) assert.Len(t, messages, 0) q.Start() + // wait retry twice. + <-keep q.Release() assert.Len(t, messages, 1) } @@ -502,12 +506,15 @@ func TestRetryCountWithNewTask(t *testing.T) { ) assert.NoError(t, err) + keep := make(chan struct{}) + assert.NoError(t, q.QueueTask( func(ctx context.Context) error { if count%3 != 0 { count++ return errors.New("count not correct") } + close(keep) messages <- "foobar" return nil }, @@ -516,6 +523,83 @@ func TestRetryCountWithNewTask(t *testing.T) { )) assert.Len(t, messages, 0) q.Start() + // wait retry twice. + <-keep q.Release() assert.Len(t, messages, 1) } + +func TestCancelRetryCountWithNewTask(t *testing.T) { + messages := make(chan string, 10) + count := 1 + + w := NewConsumer() + + q, err := NewQueue( + WithLogger(NewLogger()), + WithWorker(w), + WithWorkerCount(1), + ) + assert.NoError(t, err) + + assert.NoError(t, q.QueueTask( + func(ctx context.Context) error { + if count%3 != 0 { + count++ + q.logger.Info("add count") + return errors.New("count not correct") + } + messages <- "foobar" + return nil + }, + job.WithRetryCount(3), + job.WithRetryDelay(100*time.Millisecond), + )) + assert.Len(t, messages, 0) + q.Start() + time.Sleep(50 * time.Millisecond) + q.Release() + assert.Len(t, messages, 0) + assert.Equal(t, 2, count) +} + +func TestCancelRetryCountWithNewMessage(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + m := mocks.NewMockQueuedMessage(controller) + m.EXPECT().Bytes().Return([]byte("test")).AnyTimes() + + messages := make(chan string, 10) + count := 1 + + w := NewConsumer( + WithFn(func(ctx context.Context, m core.QueuedMessage) error { + if count%3 != 0 { + count++ + return errors.New("count not correct") + } + messages <- string(m.Bytes()) + return nil + }), + ) + + q, err := NewQueue( + WithLogger(NewLogger()), + WithWorker(w), + WithWorkerCount(1), + ) + assert.NoError(t, err) + + assert.NoError(t, q.Queue( + m, + job.WithRetryCount(3), + job.WithRetryDelay(100*time.Millisecond), + )) + assert.Len(t, messages, 0) + q.Start() + time.Sleep(50 * time.Millisecond) + q.Release() + assert.Len(t, messages, 0) + assert.Equal(t, 2, count) +}