diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc index 6ba6f1d91ab..42c9553d873 100644 --- a/src/commands/cmd_zset.cc +++ b/src/commands/cmd_zset.cc @@ -73,6 +73,8 @@ class CommandZAdd : public Commander { return {Status::RedisExecErr, s.ToString()}; } + svr->WakeupBlockingConns(args_[1], member_scores_.size()); + if (flags_.HasIncr()) { auto new_score = member_scores_[0].score; if ((flags_.HasNX() || flags_.HasXX() || flags_.HasLT() || flags_.HasGT()) && old_score == new_score && @@ -273,6 +275,187 @@ class CommandZPopMax : public CommandZPop { CommandZPopMax() : CommandZPop(false) {} }; +static rocksdb::Status PopFromMultipleZsets(redis::ZSet *zset_db, const std::vector &keys, bool min, + int count, std::string *user_key, std::vector *member_scores) { + rocksdb::Status s; + for (auto &key : keys) { + s = zset_db->Pop(key, count, min, member_scores); + if (!s.ok()) { + return s; + } + + if (!member_scores->empty()) { + *user_key = key; + break; + } + } + + return rocksdb::Status::OK(); +} + +class CommandBZPop : public Commander, + private EvbufCallbackBase, + private EventCallbackBase { + public: + explicit CommandBZPop(bool min) : min_(min) {} + + Status Parse(const std::vector &args) override { + auto parse_result = ParseInt(args[args.size() - 1], 10); + if (!parse_result) { + return {Status::RedisParseErr, "timeout is not an integer or out of range"}; + } + if (*parse_result < 0) { + return {Status::RedisParseErr, errTimeoutIsNegative}; + } + timeout_ = *parse_result; + + keys_ = std::vector(args.begin() + 1, args.end() - 1); + return Commander::Parse(args); + } + + Status Execute(Server *svr, Connection *conn, std::string *output) override { + svr_ = svr; + conn_ = conn; + + std::string user_key; + std::vector member_scores; + + redis::ZSet zset_db(svr->storage, conn->GetNamespace()); + auto s = PopFromMultipleZsets(&zset_db, keys_, min_, 1, &user_key, &member_scores); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + if (!member_scores.empty()) { + SendMembersWithScores(member_scores, user_key); + return Status::OK(); + } + + // all sorted sets are empty + if (conn->IsInExec()) { + *output = redis::MultiLen(-1); + return Status::OK(); // no blocking in multi-exec + } + + for (const auto &key : keys_) { + svr_->BlockOnKey(key, conn_); + } + + auto bev = conn->GetBufferEvent(); + SetCB(bev); + + if (timeout_) { + timer_.reset(NewTimer(bufferevent_get_base(bev))); + timeval tm = {timeout_, 0}; + evtimer_add(timer_.get(), &tm); + } + + return {Status::BlockingCmd}; + } + + void SendMembersWithScores(const std::vector &member_scores, const std::string &user_key) { + std::string output; + output.append(redis::MultiLen(member_scores.size() * 2 + 1)); + output.append(redis::BulkString(user_key)); + for (const auto &ms : member_scores) { + output.append(redis::BulkString(ms.member)); + output.append(redis::BulkString(util::Float2String(ms.score))); + } + conn_->Reply(output); + } + + void OnWrite(bufferevent *bev) { + std::string user_key; + std::vector member_scores; + + redis::ZSet zset_db(svr_->storage, conn_->GetNamespace()); + auto s = PopFromMultipleZsets(&zset_db, keys_, min_, 1, &user_key, &member_scores); + if (!s.ok()) { + conn_->Reply(redis::Error("ERR " + s.ToString())); + return; + } + + if (member_scores.empty()) { + // The connection may be waked up but can't pop from a zset. For example, connection A is blocked on zset and + // connection B added a new element; then connection A was unblocked, but this element may be taken by + // another connection C. So we need to block connection A again and wait for the element being added + // by disabling the WRITE event. + bufferevent_disable(bev, EV_WRITE); + return; + } + + SendMembersWithScores(member_scores, user_key); + + if (timer_) { + timer_.reset(); + } + + unblockOnAllKeys(); + conn_->SetCB(bev); + bufferevent_enable(bev, EV_READ); + // We need to manually trigger the read event since we will stop processing commands + // in connection after the blocking command, so there may have some commands to be processed. + // Related issue: https://github.com/apache/incubator-kvrocks/issues/831 + bufferevent_trigger(bev, EV_READ, BEV_TRIG_IGNORE_WATERMARKS); + } + + void OnEvent(bufferevent *bev, int16_t events) { + if (events & (BEV_EVENT_EOF | BEV_EVENT_ERROR)) { + if (timer_ != nullptr) { + timer_.reset(); + } + unblockOnAllKeys(); + } + conn_->OnEvent(bev, events); + } + + void TimerCB(int, int16_t) { + conn_->Reply(redis::MultiLen(-1)); + timer_.reset(); + unblockOnAllKeys(); + auto bev = conn_->GetBufferEvent(); + conn_->SetCB(bev); + bufferevent_enable(bev, EV_READ); + } + + private: + bool min_; + int timeout_; + std::vector keys_; + Server *svr_ = nullptr; + Connection *conn_ = nullptr; + UniqueEvent timer_; + + void unblockOnAllKeys() { + for (const auto &key : keys_) { + svr_->UnblockOnKey(key, conn_); + } + } +}; + +class CommandBZPopMin : public CommandBZPop { + public: + CommandBZPopMin() : CommandBZPop(true) {} +}; + +class CommandBZPopMax : public CommandBZPop { + public: + CommandBZPopMax() : CommandBZPop(false) {} +}; + +static void SendMembersWithScoresForZMpop(Connection *conn, const std::string &user_key, + const std::vector &member_scores) { + std::string output; + output.append(redis::MultiLen(2)); + output.append(redis::BulkString(user_key)); + output.append(redis::MultiLen(member_scores.size() * 2)); + for (const auto &ms : member_scores) { + output.append(redis::BulkString(ms.member)); + output.append(redis::BulkString(util::Float2String(ms.score))); + } + conn->Reply(output); +} + class CommandZMPop : public Commander { public: CommandZMPop() = default; @@ -313,16 +496,10 @@ class CommandZMPop : public Commander { continue; } - output->append(redis::MultiLen(2)); - output->append(redis::BulkString(user_key)); - output->append(redis::MultiLen(member_scores.size() * 2)); - for (const auto &ms : member_scores) { - output->append(redis::BulkString(ms.member)); - output->append(redis::BulkString(util::Float2String(ms.score))); - } + SendMembersWithScoresForZMpop(conn, user_key, member_scores); return Status::OK(); } - *output = redis::NilString(); + *output = redis::MultiLen(-1); return Status::OK(); } @@ -338,6 +515,158 @@ class CommandZMPop : public Commander { int count_ = 1; }; +class CommandBZMPop : public Commander, + private EvbufCallbackBase, + private EventCallbackBase { + public: + Status Parse(const std::vector &args) override { + CommandParser parser(args, 1); + + timeout_ = GET_OR_RET(parser.TakeInt(NumericRange{0, std::numeric_limits::max()})); + if (timeout_ < 0) { + return {Status::RedisParseErr, errTimeoutIsNegative}; + } + + num_keys_ = GET_OR_RET(parser.TakeInt(NumericRange{1, std::numeric_limits::max()})); + for (int i = 0; i < num_keys_; ++i) { + keys_.emplace_back(GET_OR_RET(parser.TakeStr())); + } + + while (parser.Good()) { + if (parser.EatEqICase("min")) { + flag_ = ZSET_MIN; + } else if (parser.EatEqICase("max")) { + flag_ = ZSET_MAX; + } else if (parser.EatEqICase("count")) { + count_ = GET_OR_RET(parser.TakeInt(NumericRange{1, std::numeric_limits::max()})); + } else { + return parser.InvalidSyntax(); + } + } + + if (flag_ == ZSET_NONE) { + return parser.InvalidSyntax(); + } + + return Commander::Parse(args); + } + + Status Execute(Server *svr, Connection *conn, std::string *output) override { + svr_ = svr; + conn_ = conn; + + std::string user_key; + std::vector member_scores; + + redis::ZSet zset_db(svr->storage, conn->GetNamespace()); + auto s = PopFromMultipleZsets(&zset_db, keys_, flag_ == ZSET_MIN, count_, &user_key, &member_scores); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + if (!member_scores.empty()) { + SendMembersWithScoresForZMpop(conn_, user_key, member_scores); + return Status::OK(); + } + + // all sorted sets are empty + if (conn->IsInExec()) { + *output = redis::MultiLen(-1); + return Status::OK(); // no blocking in multi-exec + } + + for (const auto &key : keys_) { + svr_->BlockOnKey(key, conn_); + } + + auto bev = conn->GetBufferEvent(); + SetCB(bev); + + if (timeout_) { + timer_.reset(NewTimer(bufferevent_get_base(bev))); + timeval tm = {timeout_, 0}; + evtimer_add(timer_.get(), &tm); + } + + return {Status::BlockingCmd}; + } + + void OnWrite(bufferevent *bev) { + std::string user_key; + std::vector member_scores; + + redis::ZSet zset_db(svr_->storage, conn_->GetNamespace()); + auto s = PopFromMultipleZsets(&zset_db, keys_, flag_ == ZSET_MIN, count_, &user_key, &member_scores); + if (!s.ok()) { + conn_->Reply(redis::Error("ERR " + s.ToString())); + return; + } + + if (member_scores.empty()) { + // The connection may be waked up but can't pop from a zset. For example, connection A is blocked on zset and + // connection B added a new element; then connection A was unblocked, but this element may be taken by + // another connection C. So we need to block connection A again and wait for the element being added + // by disabling the WRITE event. + bufferevent_disable(bev, EV_WRITE); + return; + } + + SendMembersWithScoresForZMpop(conn_, user_key, member_scores); + + if (timer_) { + timer_.reset(); + } + + unblockOnAllKeys(); + conn_->SetCB(bev); + bufferevent_enable(bev, EV_READ); + // We need to manually trigger the read event since we will stop processing commands + // in connection after the blocking command, so there may have some commands to be processed. + // Related issue: https://github.com/apache/incubator-kvrocks/issues/831 + bufferevent_trigger(bev, EV_READ, BEV_TRIG_IGNORE_WATERMARKS); + } + + void OnEvent(bufferevent *bev, int16_t events) { + if (events & (BEV_EVENT_EOF | BEV_EVENT_ERROR)) { + if (timer_ != nullptr) { + timer_.reset(); + } + unblockOnAllKeys(); + } + conn_->OnEvent(bev, events); + } + + void TimerCB(int, int16_t events) { + conn_->Reply(redis::NilString()); + timer_.reset(); + unblockOnAllKeys(); + auto bev = conn_->GetBufferEvent(); + conn_->SetCB(bev); + bufferevent_enable(bev, EV_READ); + } + + static CommandKeyRange Range(const std::vector &args) { + int num_key = *ParseInt(args[2], 10); + return {3, 1 + num_key, 1}; + } + + private: + int timeout_ = 0; // seconds + int num_keys_; + std::vector keys_; + enum { ZSET_MIN, ZSET_MAX, ZSET_NONE } flag_ = ZSET_NONE; + int count_ = 1; + Server *svr_ = nullptr; + Connection *conn_ = nullptr; + UniqueEvent timer_; + + void unblockOnAllKeys() { + for (const auto &key : keys_) { + svr_->UnblockOnKey(key, conn_); + } + } +}; + class CommandZRangeStore : public Commander { public: explicit CommandZRangeStore() : range_type_(kZRangeRank), direction_(kZRangeDirectionForward) {} @@ -936,7 +1265,10 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zlexcount", 4, "read-only", 1, 1, 1), MakeCmdAttr("zpopmax", -2, "write", 1, 1, 1), MakeCmdAttr("zpopmin", -2, "write", 1, 1, 1), + MakeCmdAttr("bzpopmax", -3, "write", 1, -2, 1), + MakeCmdAttr("bzpopmin", -3, "write", 1, -2, 1), MakeCmdAttr("zmpop", -4, "write", CommandZMPop::Range), + MakeCmdAttr("bzmpop", -5, "write", CommandBZMPop::Range), MakeCmdAttr("zrangestore", -4, "write", 1, 1, 1), MakeCmdAttr("zrange", -4, "read-only", 1, 1, 1), MakeCmdAttr("zrevrange", -4, "read-only", 1, 1, 1), diff --git a/tests/gocase/unit/type/zset/zset_test.go b/tests/gocase/unit/type/zset/zset_test.go index 94a16b61901..ec215271441 100644 --- a/tests/gocase/unit/type/zset/zset_test.go +++ b/tests/gocase/unit/type/zset/zset_test.go @@ -67,7 +67,7 @@ func createDefaultLexZset(rdb *redis.Client, ctx context.Context) { {0, "omega"}}) } -func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding string) { +func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding string, srv *util.KvrocksServer) { t.Run(fmt.Sprintf("Check encoding - %s", encoding), func(t *testing.T) { rdb.Del(ctx, "ztmp") rdb.ZAdd(ctx, "ztmp", redis.Z{Score: 10, Member: "x"}) @@ -318,6 +318,60 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s require.Equal(t, []redis.Z{{Score: 10, Member: "a"}}, rdb.ZPopMax(ctx, "ztmp", 3).Val()) }) + t.Run(fmt.Sprintf("BZPOPMIN basics - %s", encoding), func(t *testing.T) { + rdb.Del(ctx, "zseta") + rdb.Del(ctx, "zsetb") + rdb.ZAdd(ctx, "zseta", redis.Z{Score: 1, Member: "a"}, redis.Z{Score: 2, Member: "b"}, redis.Z{Score: 3, Member: "c"}) + rdb.ZAdd(ctx, "zsetb", redis.Z{Score: 1, Member: "d"}, redis.Z{Score: 2, Member: "e"}) + require.EqualValues(t, 3, rdb.ZCard(ctx, "zseta").Val()) + require.EqualValues(t, 2, rdb.ZCard(ctx, "zsetb").Val()) + resultz := rdb.BZPopMin(ctx, 0, "zseta", "zsetb").Val().Z + require.Equal(t, redis.Z{Score: 1, Member: "a"}, resultz) + resultz = rdb.BZPopMin(ctx, 0, "zseta", "zsetb").Val().Z + require.Equal(t, redis.Z{Score: 2, Member: "b"}, resultz) + resultz = rdb.BZPopMin(ctx, 0, "zsetb", "zseta").Val().Z + require.Equal(t, redis.Z{Score: 1, Member: "d"}, resultz) + resultz = rdb.BZPopMin(ctx, 0, "zsetb", "zseta").Val().Z + require.Equal(t, redis.Z{Score: 2, Member: "e"}, resultz) + resultz = rdb.BZPopMin(ctx, 0, "zseta", "zsetb").Val().Z + require.Equal(t, redis.Z{Score: 3, Member: "c"}, resultz) + var err = rdb.BZPopMin(ctx, time.Millisecond*1000, "zseta", "zsetb").Err() + require.Equal(t, redis.Nil, err) + + rd := srv.NewTCPClient() + defer func() { require.NoError(t, rd.Close()) }() + require.NoError(t, rd.WriteArgs("bzpopmin", "zseta", "0")) + rdb.ZAdd(ctx, "zseta", redis.Z{Score: 1, Member: "a"}) + rd.MustReadStrings(t, []string{"zseta", "a", "1"}) + }) + + t.Run(fmt.Sprintf("BZPOPMAX basics - %s", encoding), func(t *testing.T) { + rdb.Del(ctx, "zseta") + rdb.Del(ctx, "zsetb") + rdb.ZAdd(ctx, "zseta", redis.Z{Score: 1, Member: "a"}, redis.Z{Score: 2, Member: "b"}, redis.Z{Score: 3, Member: "c"}) + rdb.ZAdd(ctx, "zsetb", redis.Z{Score: 1, Member: "d"}, redis.Z{Score: 2, Member: "e"}) + require.EqualValues(t, 3, rdb.ZCard(ctx, "zseta").Val()) + require.EqualValues(t, 2, rdb.ZCard(ctx, "zsetb").Val()) + resultz := rdb.BZPopMax(ctx, 0, "zseta", "zsetb").Val().Z + require.Equal(t, redis.Z{Score: 3, Member: "c"}, resultz) + resultz = rdb.BZPopMax(ctx, 0, "zseta", "zsetb").Val().Z + require.Equal(t, redis.Z{Score: 2, Member: "b"}, resultz) + resultz = rdb.BZPopMax(ctx, 0, "zsetb", "zseta").Val().Z + require.Equal(t, redis.Z{Score: 2, Member: "e"}, resultz) + resultz = rdb.BZPopMax(ctx, 0, "zsetb", "zseta").Val().Z + require.Equal(t, redis.Z{Score: 1, Member: "d"}, resultz) + resultz = rdb.BZPopMax(ctx, 0, "zseta", "zsetb").Val().Z + require.Equal(t, redis.Z{Score: 1, Member: "a"}, resultz) + var err = rdb.BZPopMin(ctx, time.Millisecond*1000, "zseta", "zsetb").Err() + require.Equal(t, redis.Nil, err) + + rd := srv.NewTCPClient() + defer func() { require.NoError(t, rd.Close()) }() + require.NoError(t, rd.WriteArgs("bzpopmax", "zseta", "0")) + rdb.ZAdd(ctx, "zseta", redis.Z{Score: 1, Member: "a"}) + rd.MustReadStrings(t, []string{"zseta", "a", "1"}) + }) + t.Run(fmt.Sprintf("ZMPOP basics - %s", encoding), func(t *testing.T) { rdb.Del(ctx, "zseta") rdb.Del(ctx, "zsetb") @@ -338,6 +392,41 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s require.EqualValues(t, 0, rdb.Exists(ctx, "zseta", "zsetb").Val()) }) + t.Run(fmt.Sprintf("BZMPOP basics - %s", encoding), func(t *testing.T) { + rdb.Del(ctx, "zseta") + rdb.Del(ctx, "zsetb") + rdb.ZAdd(ctx, "zseta", redis.Z{Score: 1, Member: "a"}, redis.Z{Score: 2, Member: "b"}, redis.Z{Score: 3, Member: "c"}) + rdb.ZAdd(ctx, "zsetb", redis.Z{Score: 1, Member: "d"}, redis.Z{Score: 2, Member: "e"}) + require.EqualValues(t, 3, rdb.ZCard(ctx, "zseta").Val()) + require.EqualValues(t, 2, rdb.ZCard(ctx, "zsetb").Val()) + var key, zset = rdb.BZMPop(ctx, 0, "min", 1, "zseta").Val() + require.Equal(t, "zseta", key) + require.Equal(t, []redis.Z{{Score: 1, Member: "a"}}, zset) + key, zset = rdb.BZMPop(ctx, 0, "max", 2, "zsetb").Val() + require.Equal(t, "zsetb", key) + require.Equal(t, []redis.Z{{Score: 2, Member: "e"}, {Score: 1, Member: "d"}}, zset) + key, zset = rdb.BZMPop(ctx, 0, "min", 3, "zseta").Val() + require.Equal(t, "zseta", key) + require.Equal(t, []redis.Z{{Score: 2, Member: "b"}, {Score: 3, Member: "c"}}, zset) + require.Equal(t, redis.Nil, rdb.BZMPop(ctx, time.Millisecond*1000, "max", 10, "zseta", "zsetb").Err()) + + rd := srv.NewClient() + defer func() { require.NoError(t, rd.Close()) }() + ch := make(chan *redis.ZSliceWithKeyCmd) + go func() { + ch <- rd.BZMPop(ctx, 0, "min", 10, "zseta") + }() + require.Eventually(t, func() bool { + cnt, _ := strconv.Atoi(util.FindInfoEntry(rdb, "blocked_clients")) + return cnt == 1 + }, 5*time.Second, 100*time.Millisecond) + rdb.ZAdd(ctx, "zseta", redis.Z{Score: 1, Member: "a"}, redis.Z{Score: 2, Member: "b"}) + r := <-ch + key, zset = r.Val() + require.Equal(t, "zseta", key) + require.Equal(t, []redis.Z{{Score: 1, Member: "a"}, {Score: 2, Member: "b"}}, zset) + }) + t.Run(fmt.Sprintf("ZRANGESTORE basics - %s", encoding), func(t *testing.T) { rdb.Del(ctx, "zsrc") rdb.Del(ctx, "zdst") @@ -1295,7 +1384,7 @@ func TestZset(t *testing.T) { rdb := srv.NewClient() defer func() { require.NoError(t, rdb.Close()) }() - basicTests(t, rdb, ctx, "skiplist") + basicTests(t, rdb, ctx, "skiplist", srv) t.Run("ZUNIONSTORE regression, should not create NaN in scores", func(t *testing.T) { rdb.ZAdd(ctx, "z", redis.Z{Score: math.Inf(-1), Member: "neginf"})