Skip to content

Commit

Permalink
feat: concurrency control queue
Browse files Browse the repository at this point in the history
  • Loading branch information
kanzihuang committed May 2, 2024
1 parent 2b632b9 commit d3d176b
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 15 deletions.
50 changes: 37 additions & 13 deletions internal/rdb/rdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,31 @@ const statsTTL = 90 * 24 * time.Hour // 90 days
// LeaseDuration is the duration used to initially create a lease and to extend it thereafter.
const LeaseDuration = 30 * time.Second

type Option func(r *RDB)

func WithQueueConcurrency(queueConcurrency map[string]int) Option {
return func(r *RDB) {
r.queueConcurrency = queueConcurrency
}
}

// RDB is a client interface to query and mutate task queues.
type RDB struct {
client redis.UniversalClient
clock timeutil.Clock
client redis.UniversalClient
clock timeutil.Clock
queueConcurrency map[string]int
}

// NewRDB returns a new instance of RDB.
func NewRDB(client redis.UniversalClient) *RDB {
return &RDB{
func NewRDB(client redis.UniversalClient, opts ...Option) *RDB {
r := &RDB{
client: client,
clock: timeutil.NewRealClock(),
}
for _, opt := range opts {
opt(r)
}
return r
}

// Close closes the connection with redis server.
Expand Down Expand Up @@ -209,6 +222,7 @@ func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time
// --
// ARGV[1] -> initial lease expiration Unix time
// ARGV[2] -> task key prefix
// ARGV[3] -> queue concurrency
//
// Output:
// Returns nil if no processable task is found in the given queue.
Expand All @@ -217,15 +231,20 @@ func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time
// Note: dequeueCmd checks whether a queue is paused first, before
// calling RPOPLPUSH to pop a task from the queue.
var dequeueCmd = redis.NewScript(`
if redis.call("EXISTS", KEYS[2]) == 0 then
local id = redis.call("RPOPLPUSH", KEYS[1], KEYS[3])
if id then
local key = ARGV[2] .. id
redis.call("HSET", key, "state", "active")
redis.call("HDEL", key, "pending_since")
redis.call("ZADD", KEYS[4], ARGV[1], id)
return redis.call("HGET", key, "msg")
end
if redis.call("EXISTS", KEYS[2]) > 0 then
return nil
end
local count = redis.call("ZCARD", KEYS[4])
if (count >= tonumber(ARGV[3])) then
return nil
end
local id = redis.call("RPOPLPUSH", KEYS[1], KEYS[3])
if id then
local key = ARGV[2] .. id
redis.call("HSET", key, "state", "active")
redis.call("HDEL", key, "pending_since")
redis.call("ZADD", KEYS[4], ARGV[1], id)
return redis.call("HGET", key, "msg")
end
return nil`)

Expand All @@ -243,9 +262,14 @@ func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, leaseExpirationT
base.LeaseKey(qname),
}
leaseExpirationTime = r.clock.Now().Add(LeaseDuration)
queueConcurrency, ok := r.queueConcurrency[qname]
if !ok || queueConcurrency <= 0 {
queueConcurrency = math.MaxInt
}
argv := []interface{}{
leaseExpirationTime.Unix(),
base.TaskKeyPrefix(qname),
queueConcurrency,
}
res, err := dequeueCmd.Run(context.Background(), r.client, keys, argv...).Result()
if err == redis.Nil {
Expand Down
81 changes: 81 additions & 0 deletions internal/rdb/rdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ func TestDequeue(t *testing.T) {
wantPending map[string][]*base.TaskMessage
wantActive map[string][]*base.TaskMessage
wantLease map[string][]base.Z
queueConcurrency map[string]int
}{
{
pending: map[string][]*base.TaskMessage{
Expand Down Expand Up @@ -441,6 +442,86 @@ func TestDequeue(t *testing.T) {
}
}

func TestDequeueWithQueueConcurrency(t *testing.T) {
r := setup(t)
defer r.Close()
now := time.Now()
r.SetClock(timeutil.NewSimulatedClock(now))
const taskNum = 3
msgs := make([]*base.TaskMessage, 0, taskNum)
for i := 0; i < taskNum; i++ {
msg := &base.TaskMessage{
ID: uuid.NewString(),
Type: "send_email",
Payload: h.JSON(map[string]interface{}{"subject": "hello!"}),
Queue: "default",
Timeout: 1800,
Deadline: 0,
}
msgs = append(msgs, msg)
}

tests := []struct {
name string
pending map[string][]*base.TaskMessage
qnames []string // list of queues to query
queueConcurrency map[string]int
wantMsgs []*base.TaskMessage
}{
{
name: "without queue concurrency control",
pending: map[string][]*base.TaskMessage{
"default": msgs,
},
qnames: []string{"default"},
wantMsgs: msgs,
},
{
name: "with queue concurrency control",
pending: map[string][]*base.TaskMessage{
"default": msgs,
},
qnames: []string{"default"},
queueConcurrency: map[string]int{"default": 2},
wantMsgs: msgs[:2],
},
{
name: "with queue concurrency zero",
pending: map[string][]*base.TaskMessage{
"default": msgs,
},
qnames: []string{"default"},
queueConcurrency: map[string]int{"default": 0},
wantMsgs: msgs,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
h.FlushDB(t, r.client) // clean up db before each test case
h.SeedAllPendingQueues(t, r.client, tc.pending)

r.queueConcurrency = tc.queueConcurrency
gotMsgs := make([]*base.TaskMessage, 0, len(msgs))
for i := 0; i < len(msgs); i++ {
msg, _, err := r.Dequeue(tc.qnames...)
if errors.Is(err, errors.ErrNoProcessableTask) {
break
}
if err != nil {
t.Errorf("(*RDB).Dequeue(%v) returned error %v", tc.qnames, err)
continue
}
gotMsgs = append(gotMsgs, msg)
}
if diff := cmp.Diff(tc.wantMsgs, gotMsgs, h.SortZSetEntryOpt); diff != "" {
t.Errorf("(*RDB).Dequeue(%v) returned message %v; want %v",
tc.qnames, gotMsgs, tc.wantMsgs)
}
})
}
}

func TestDequeueError(t *testing.T) {
r := setup(t)
defer r.Close()
Expand Down
9 changes: 7 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ type Config struct {
// If BaseContext is nil, the default is context.Background().
// If this is defined, then it MUST return a non-nil context
BaseContext func() context.Context

// TaskCheckInterval specifies the interval between checks for new tasks to process when all queues are empty.
//
// If unset, zero or a negative value, the interval is set to 1 second.
Expand Down Expand Up @@ -239,6 +239,11 @@ type Config struct {
//
// If unset or nil, the group aggregation feature will be disabled on the server.
GroupAggregator GroupAggregator

// Maximum number of concurrent tasks of a queue.
//
// If set to a zero or not set, NewServer will not limit concurrency of the queue.
QueueConcurrency map[string]int
}

// GroupAggregator aggregates a group of tasks into one before the tasks are passed to the Handler.
Expand Down Expand Up @@ -478,7 +483,7 @@ func NewServer(r RedisConnOpt, cfg Config) *Server {
}
logger.SetLevel(toInternalLogLevel(loglevel))

rdb := rdb.NewRDB(c)
rdb := rdb.NewRDB(c, rdb.WithQueueConcurrency(cfg.QueueConcurrency))
starting := make(chan *workerInfo)
finished := make(chan *base.TaskMessage)
syncCh := make(chan *syncRequest)
Expand Down
94 changes: 94 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package asynq
import (
"context"
"fmt"
"github.com/redis/go-redis/v9"
"syscall"
"testing"
"time"
Expand Down Expand Up @@ -53,6 +54,99 @@ func TestServer(t *testing.T) {
srv.Shutdown()
}

func TestServerWithQueueConcurrency(t *testing.T) {
// https://github.com/go-redis/redis/issues/1029
ignoreOpt := goleak.IgnoreTopFunction("github.com/redis/go-redis/v9/internal/pool.(*ConnPool).reaper")
defer goleak.VerifyNone(t, ignoreOpt)

redisConnOpt := getRedisConnOpt(t)
r, ok := redisConnOpt.MakeRedisClient().(redis.UniversalClient)
if !ok {
t.Fatalf("asynq: unsupported RedisConnOpt type %T", r)
}

c := NewClient(redisConnOpt)
defer c.Close()

const taskNum = 8
const serverNum = 2
tests := []struct {
name string
concurrency int
queueConcurrency int
wantActiveNum int
}{
{
name: "based on client concurrency control",
concurrency: 2,
queueConcurrency: 6,
wantActiveNum: 2 * serverNum,
},
{
name: "no queue concurrency control",
concurrency: 2,
queueConcurrency: 0,
wantActiveNum: 2 * serverNum,
},
{
name: "based on queue concurrency control",
concurrency: 6,
queueConcurrency: 2,
wantActiveNum: 2,
},
}

// no-op handler
handle := func(ctx context.Context, task *Task) error {
time.Sleep(time.Second * 2)
return nil
}

const queueName = "default"
var servers [serverNum]*Server
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var err error
testutil.FlushDB(t, r)
for i := 0; i < taskNum; i++ {
_, err = c.Enqueue(NewTask("send_email",
testutil.JSON(map[string]interface{}{"recipient_id": i + 123})))
if err != nil {
t.Fatalf("could not enqueue a task: %v", err)
}
}

for i := 0; i < serverNum; i++ {
srv := NewServer(redisConnOpt, Config{
Concurrency: tc.concurrency,
LogLevel: testLogLevel,
QueueConcurrency: map[string]int{queueName: tc.queueConcurrency},
})
err = srv.Start(HandlerFunc(handle))
if err != nil {
t.Fatal(err)
}
servers[i] = srv
}
defer func() {
for _, srv := range servers {
srv.Shutdown()
}
}()

time.Sleep(time.Second)
inspector := NewInspector(redisConnOpt)
tasks, err := inspector.ListActiveTasks(queueName)
if err != nil {
t.Fatalf("could not list active tasks: %v", err)
}
if len(tasks) != tc.wantActiveNum {
t.Errorf("default queue has %d active tasks, want %d", len(tasks), tc.wantActiveNum)
}
})
}
}

func TestServerRun(t *testing.T) {
// https://github.com/go-redis/redis/issues/1029
ignoreOpt := goleak.IgnoreTopFunction("github.com/redis/go-redis/v9/internal/pool.(*ConnPool).reaper")
Expand Down

0 comments on commit d3d176b

Please sign in to comment.