From 0b876cc36ecd1800fcf42f58386a40835c93cad9 Mon Sep 17 00:00:00 2001 From: MaheshMadushan Date: Sat, 13 Jan 2024 14:22:38 +0000 Subject: [PATCH 1/6] ZDIFF initial implementation --- src/commands/cmd_zset.cc | 62 ++++++++++++++++++++++++++++++-- src/types/redis_zset.cc | 34 ++++++++++++++++++ src/types/redis_zset.h | 1 + tests/cppunit/types/zset_test.cc | 51 ++++++++++++++++++++++++++ 4 files changed, 146 insertions(+), 2 deletions(-) diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc index 1fa51ab227c..e3cd072705d 100644 --- a/src/commands/cmd_zset.cc +++ b/src/commands/cmd_zset.cc @@ -1357,6 +1357,63 @@ class CommandZScan : public CommandSubkeyScanBase { } }; +class CommandZDiff : public Commander { + public: + Status Parse(const std::vector &args) override { + auto parse_result = ParseInt(args[1], 10); + if (!parse_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + + numkeys_ = *parse_result; + if (numkeys_ > args.size() - 2) { + return {Status::RedisParseErr, errInvalidSyntax}; + } + + size_t j = 0; + while (j < numkeys_) { + keys_.emplace_back(args[j + 2]); + j++; + } + + if (auto i = 2 + numkeys_; i < args.size()) { + if (util::ToLower(args[i]) == "withscores") { + with_scores_ = true; + } + } + + return Commander::Parse(args); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::ZSet zset_db(srv->storage, conn->GetNamespace()); + + std::vector members_with_scores; + auto s = zset_db.Diff(keys_, &members_with_scores); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + output->append(redis::MultiLen(members_with_scores.size() * (with_scores_ ? 
2 : 1))); + for (const auto &ms : members_with_scores) { + output->append(redis::BulkString(ms.member)); + if (with_scores_) output->append(redis::BulkString(util::Float2String(ms.score))); + } + + return Status::OK(); + } + + static CommandKeyRange Range(const std::vector &args) { + int num_key = *ParseInt(args[1], 10); + return {2, 2 + num_key, 1}; + } + + protected: + size_t numkeys_ {0}; + std::vector keys_; + bool with_scores_ {false}; +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zcard", 2, "read-only", 1, 1, 1), MakeCmdAttr("zcount", 4, "read-only", 1, 1, 1), @@ -1388,6 +1445,7 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zmscore", -3, "read-only", 1, 1, 1), MakeCmdAttr("zscan", -3, "read-only", 1, 1, 1), MakeCmdAttr("zunionstore", -4, "write", CommandZUnionStore::Range), - MakeCmdAttr("zunion", -3, "read-only", CommandZUnion::Range), ) - + MakeCmdAttr("zunion", -3, "read-only", CommandZUnion::Range), + MakeCmdAttr("zunion", -3, "read-only", CommandZUnion::Range), + MakeCmdAttr("zdiff", -3, "read-only", CommandZDiff::Range), ) } // namespace redis diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc index f29765443c4..9abc5ae4edd 100644 --- a/src/types/redis_zset.cc +++ b/src/types/redis_zset.cc @@ -851,4 +851,38 @@ rocksdb::Status ZSet::MGet(const Slice &user_key, const std::vector &memb return rocksdb::Status::OK(); } +rocksdb::Status ZSet::Diff(const std::vector &keys, MemberScores *members) { + std::vector lock_keys; + lock_keys.reserve(keys.size()); + for (const auto key : keys) { + std::string ns_key = AppendNamespacePrefix(key); + lock_keys.emplace_back(std::move(ns_key)); + } + MultiLockGuard guard(storage_->GetLockManager(), lock_keys); + + members->clear(); + MemberScores source_member_scores; + RangeScoreSpec spec; + uint64_t size {0}; + auto s = RangeByScore(keys[0], spec, &source_member_scores, &size); + if (!s.ok()) return s; + + std::map exclude_members {}; 
+ MemberScores target_member_scores {}; + for (size_t i = 1; i < keys.size(); i++) { + uint64_t size {0}; + s = RangeByScore(keys[i], spec, &target_member_scores, &size); + if (!s.ok()) return s; + for (const auto &member_score : target_member_scores) { + exclude_members[member_score.member] = true; + } + } + for (const auto &member_score : source_member_scores) { + if (exclude_members.find(member_score.member) == exclude_members.end()) { + members->push_back(member_score); + } + } + return rocksdb::Status::OK(); +} + } // namespace redis diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h index 3cd81622ece..005e3e68b73 100644 --- a/src/types/redis_zset.h +++ b/src/types/redis_zset.h @@ -116,6 +116,7 @@ class ZSet : public SubKeyScanner { AggregateMethod aggregate_method, uint64_t *saved_cnt); rocksdb::Status Union(const std::vector &keys_weights, AggregateMethod aggregate_method, std::vector *members); + rocksdb::Status Diff(const std::vector &keys, std::vector *members); rocksdb::Status MGet(const Slice &user_key, const std::vector &members, std::map *scores); rocksdb::Status GetMetadata(const Slice &ns_key, ZSetMetadata *metadata); diff --git a/tests/cppunit/types/zset_test.cc b/tests/cppunit/types/zset_test.cc index 230aa4009ef..5987174a8d7 100644 --- a/tests/cppunit/types/zset_test.cc +++ b/tests/cppunit/types/zset_test.cc @@ -433,3 +433,54 @@ TEST_F(RedisZSetTest, Rank) { } auto s = zset_->Del(key_); } + +TEST_F(RedisZSetTest, Diff) { + uint64_t ret = 0; + + std::string k1 = "key1"; + std::vector k1_fields_ = {"a", "b", "c", "d"}; + std::vector k1_scores_ = {-100.1, -100.1, 0, 1.234}; + std::vector k1_mscores; + for (size_t i = 0; i < k1_fields_.size(); i++) { + k1_mscores.emplace_back(MemberScore{k1_fields_[i].ToString(), k1_scores_[i]}); + } + + std::string k2 = "key2"; + std::vector k2_fields_ = {"c"}; + std::vector k2_scores_ = {-150.1}; + std::vector k2_mscores; + for (size_t i = 0; i < k2_fields_.size(); i++) { + 
k2_mscores.emplace_back(MemberScore{k2_fields_[i].ToString(), k2_scores_[i]}); + } + + std::string k3 = "key3"; + std::vector k3_fields_ = {"a", "c", "e"}; + std::vector k3_scores_ = {-1000.1, -100.1, 8000.9}; + std::vector k3_mscores; + for (size_t i = 0; i < k3_fields_.size(); i++) { + k3_mscores.emplace_back(MemberScore{k3_fields_[i].ToString(), k3_scores_[i]}); + } + + auto s = zset_->Add(k1, ZAddFlags::Default(), &k1_mscores, &ret); + EXPECT_EQ(ret, 4); + zset_->Add(k2, ZAddFlags::Default(), &k2_mscores, &ret); + EXPECT_EQ(ret, 1); + zset_->Add(k3, ZAddFlags::Default(), &k3_mscores, &ret); + EXPECT_EQ(ret, 3); + + std::vector mscores; + zset_->Diff({k1, k2, k3}, &mscores); + + EXPECT_EQ(2, mscores.size()); + std::vector expected_mscores = {{"b", -100.1}, {"d", 1.234}}; + int index = 0; + for (auto mscore : expected_mscores) { + EXPECT_EQ(mscore.member, mscores[index].member); + EXPECT_EQ(mscore.score, mscores[index].score); + index++; + } + + s = zset_->Del(k1); + s = zset_->Del(k2); + s = zset_->Del(k3); +} From 9e4d293556cfc11dae67575d4fe32024517190dc Mon Sep 17 00:00:00 2001 From: MaheshMadushan Date: Sun, 14 Jan 2024 07:01:27 +0000 Subject: [PATCH 2/6] Added ZDIFF go basic test cases --- tests/gocase/unit/type/zset/zset_test.go | 39 ++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/gocase/unit/type/zset/zset_test.go b/tests/gocase/unit/type/zset/zset_test.go index 860316b2035..bd425716d96 100644 --- a/tests/gocase/unit/type/zset/zset_test.go +++ b/tests/gocase/unit/type/zset/zset_test.go @@ -1330,6 +1330,45 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s Weights: []float64{math.NaN(), math.NaN()}}, ).Err(), ".*weight.*not.*double.*") }) + + t.Run(fmt.Sprintf("ZDIFF with two sets - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, {Score: 2, Member: "b"}, {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, {Score: 4, Member: "e"}, + }) + 
createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, {Score: 2, Member: "c"}, {Score: 4, Member: "f"}, + }) + cmd := rdb.ZDiff(ctx, "zseta", "zsetb") + require.NoError(t, cmd.Err()) + sort.Strings(cmd.Val()) + require.EqualValues(t, []string{"a", "d", "e"}, cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with three sets - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, {Score: 2, Member: "b"}, {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, {Score: 2, Member: "c"}, {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 3, Member: "c"}, {Score: 3, Member: "d"}, {Score: 4, Member: "e"}, + }) + cmd := rdb.ZDiff(ctx, "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + sort.Strings(cmd.Val()) + require.EqualValues(t, []string{"a"}, cmd.Val()) + }) + + // t.Run("ZDIFFSTORE with three sets - ", func(t *testing.T) { + // require.NoError(t, rdb.ZDiffStore(ctx, "setres", "set1", "set4", "set5").Err()) + // cmd := rdb.SMembers(ctx, "setres") + // require.NoError(t, cmd.Err()) + // sort.Strings(cmd.Val()) + // require.EqualValues(t, []string{"1", "2", "3", "4"}, cmd.Val()) + // }) } } From ec7e1e4aee694d715efdd460415160ad656c0c14 Mon Sep 17 00:00:00 2001 From: raffertyyu Date: Fri, 12 Jan 2024 13:27:31 +0800 Subject: [PATCH 3/6] Add support of new command: ssubscribe and sunsubscribe (#2003) --- src/commands/cmd_pubsub.cc | 60 ++++++- src/commands/cmd_server.cc | 5 +- src/server/redis_connection.cc | 39 +++++ src/server/redis_connection.h | 5 + src/server/server.cc | 61 +++++++ src/server/server.h | 7 + tests/gocase/unit/pubsub/pubsubshard_test.go | 164 +++++++++++++++++++ 7 files changed, 333 insertions(+), 8 deletions(-) create mode 100644 tests/gocase/unit/pubsub/pubsubshard_test.go diff --git a/src/commands/cmd_pubsub.cc b/src/commands/cmd_pubsub.cc index 
45272eef2ca..6ec61eea5d3 100644 --- a/src/commands/cmd_pubsub.cc +++ b/src/commands/cmd_pubsub.cc @@ -138,6 +138,44 @@ class CommandPUnSubscribe : public Commander { } }; +class CommandSSubscribe : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + uint16_t slot = 0; + if (srv->GetConfig()->cluster_enabled) { + slot = GetSlotIdFromKey(args_[1]); + for (unsigned int i = 2; i < args_.size(); i++) { + if (GetSlotIdFromKey(args_[i]) != slot) { + return {Status::RedisExecErr, "CROSSSLOT Keys in request don't hash to the same slot"}; + } + } + } + + for (unsigned int i = 1; i < args_.size(); i++) { + conn->SSubscribeChannel(args_[i], slot); + SubscribeCommandReply(output, "ssubscribe", args_[i], conn->SSubscriptionsCount()); + } + return Status::OK(); + } +}; + +class CommandSUnSubscribe : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + if (args_.size() == 1) { + conn->SUnsubscribeAll([output](const std::string &sub_name, int num) { + SubscribeCommandReply(output, "sunsubscribe", sub_name, num); + }); + } else { + for (size_t i = 1; i < args_.size(); i++) { + conn->SUnsubscribeChannel(args_[i], srv->GetConfig()->cluster_enabled ? 
GetSlotIdFromKey(args_[i]) : 0); + SubscribeCommandReply(output, "sunsubscribe", args_[i], conn->SSubscriptionsCount()); + } + } + return Status::OK(); + } +}; + class CommandPubSub : public Commander { public: Status Parse(const std::vector &args) override { @@ -146,14 +184,14 @@ class CommandPubSub : public Commander { return Status::OK(); } - if ((subcommand_ == "numsub") && args.size() >= 2) { + if ((subcommand_ == "numsub" || subcommand_ == "shardnumsub") && args.size() >= 2) { if (args.size() > 2) { channels_ = std::vector(args.begin() + 2, args.end()); } return Status::OK(); } - if ((subcommand_ == "channels") && args.size() <= 3) { + if ((subcommand_ == "channels" || subcommand_ == "shardchannels") && args.size() <= 3) { if (args.size() == 3) { pattern_ = args[2]; } @@ -169,9 +207,13 @@ class CommandPubSub : public Commander { return Status::OK(); } - if (subcommand_ == "numsub") { + if (subcommand_ == "numsub" || subcommand_ == "shardnumsub") { std::vector channel_subscribe_nums; - srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums); + if (subcommand_ == "numsub") { + srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums); + } else { + srv->ListSChannelSubscribeNum(channels_, &channel_subscribe_nums); + } output->append(redis::MultiLen(channel_subscribe_nums.size() * 2)); for (const auto &chan_subscribe_num : channel_subscribe_nums) { @@ -182,9 +224,13 @@ class CommandPubSub : public Commander { return Status::OK(); } - if (subcommand_ == "channels") { + if (subcommand_ == "channels" || subcommand_ == "shardchannels") { std::vector channels; - srv->GetChannelsByPattern(pattern_, &channels); + if (subcommand_ == "channels") { + srv->GetChannelsByPattern(pattern_, &channels); + } else { + srv->GetSChannelsByPattern(pattern_, &channels); + } *output = redis::MultiBulkString(channels); return Status::OK(); } @@ -205,6 +251,8 @@ REDIS_REGISTER_COMMANDS( MakeCmdAttr("unsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0), 
MakeCmdAttr("psubscribe", -2, "read-only pub-sub no-multi no-script", 0, 0, 0), MakeCmdAttr("punsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0), + MakeCmdAttr("ssubscribe", -2, "read-only pub-sub no-multi no-script", 0, 0, 0), + MakeCmdAttr("sunsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0), MakeCmdAttr("pubsub", -2, "read-only pub-sub no-script", 0, 0, 0), ) } // namespace redis diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index a41094d1bd8..d94d81e73ec 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -1160,7 +1160,7 @@ class CommandAnalyze : public Commander { public: Status Parse(const std::vector &args) override { if (args.size() <= 1) return {Status::RedisExecErr, errInvalidSyntax}; - for (int i = 1; i < args.size(); ++i) { + for (unsigned int i = 1; i < args.size(); ++i) { command_args_.push_back(args[i]); } return Status::OK(); @@ -1178,7 +1178,8 @@ class CommandAnalyze : public Commander { cmd->SetArgs(command_args_); int arity = cmd->GetAttributes()->arity; - if ((arity > 0 && command_args_.size() != arity) || (arity < 0 && command_args_.size() < -arity)) { + if ((arity > 0 && static_cast(command_args_.size()) != arity) || + (arity < 0 && static_cast(command_args_.size()) < -arity)) { *output = redis::Error("ERR wrong number of arguments"); return {Status::RedisExecErr, errWrongNumOfArguments}; } diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index ae80e950434..d6e0b5f6749 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -261,6 +261,45 @@ void Connection::PUnsubscribeAll(const UnsubscribeCallback &reply) { int Connection::PSubscriptionsCount() { return static_cast(subscribe_patterns_.size()); } +void Connection::SSubscribeChannel(const std::string &channel, uint16_t slot) { + for (const auto &chan : subscribe_shard_channels_) { + if (channel == chan) return; + } + + 
subscribe_shard_channels_.emplace_back(channel); + owner_->srv->SSubscribeChannel(channel, this, slot); +} + +void Connection::SUnsubscribeChannel(const std::string &channel, uint16_t slot) { + for (auto iter = subscribe_shard_channels_.begin(); iter != subscribe_shard_channels_.end(); iter++) { + if (*iter == channel) { + subscribe_shard_channels_.erase(iter); + owner_->srv->SUnsubscribeChannel(channel, this, slot); + return; + } + } +} + +void Connection::SUnsubscribeAll(const UnsubscribeCallback &reply) { + if (subscribe_shard_channels_.empty()) { + if (reply) reply("", 0); + return; + } + + int removed = 0; + for (const auto &chan : subscribe_shard_channels_) { + owner_->srv->SUnsubscribeChannel(chan, this, + owner_->srv->GetConfig()->cluster_enabled ? GetSlotIdFromKey(chan) : 0); + removed++; + if (reply) { + reply(chan, static_cast(subscribe_shard_channels_.size() - removed)); + } + } + subscribe_shard_channels_.clear(); +} + +int Connection::SSubscriptionsCount() { return static_cast(subscribe_shard_channels_.size()); } + bool Connection::IsProfilingEnabled(const std::string &cmd) { auto config = srv_->GetConfig(); if (config->profiling_sample_ratio == 0) return false; diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h index 25b522d848a..34fbcbae9fa 100644 --- a/src/server/redis_connection.h +++ b/src/server/redis_connection.h @@ -74,6 +74,10 @@ class Connection : public EvbufCallbackBase { void PUnsubscribeChannel(const std::string &pattern); void PUnsubscribeAll(const UnsubscribeCallback &reply = nullptr); int PSubscriptionsCount(); + void SSubscribeChannel(const std::string &channel, uint16_t slot); + void SUnsubscribeChannel(const std::string &channel, uint16_t slot); + void SUnsubscribeAll(const UnsubscribeCallback &reply = nullptr); + int SSubscriptionsCount(); uint64_t GetAge() const; uint64_t GetIdleTime() const; @@ -159,6 +163,7 @@ class Connection : public EvbufCallbackBase { std::vector subscribe_channels_; std::vector 
subscribe_patterns_; + std::vector subscribe_shard_channels_; Server *srv_; bool in_exec_ = false; diff --git a/src/server/server.cc b/src/server/server.cc index f8f2fb94c22..efe721b27ba 100644 --- a/src/server/server.cc +++ b/src/server/server.cc @@ -78,6 +78,9 @@ Server::Server(engine::Storage *storage, Config *config) // Init cluster cluster = std::make_unique(this, config_->binds, config_->port); + // init shard pub/sub channels + pubsub_shard_channels_.resize(config->cluster_enabled ? HASH_SLOTS_SIZE : 1); + for (int i = 0; i < config->workers; i++) { auto worker = std::make_unique(this, config); // multiple workers can't listen to the same unix socket, so @@ -497,6 +500,64 @@ void Server::PUnsubscribeChannel(const std::string &pattern, redis::Connection * } } +void Server::SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) { + assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0); + std::lock_guard guard(pubsub_shard_channels_mu_); + + auto conn_ctx = ConnContext(conn->Owner(), conn->GetFD()); + if (auto iter = pubsub_shard_channels_[slot].find(channel); iter == pubsub_shard_channels_[slot].end()) { + pubsub_shard_channels_[slot].emplace(channel, std::list{conn_ctx}); + } else { + iter->second.emplace_back(conn_ctx); + } +} + +void Server::SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) { + assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0); + std::lock_guard guard(pubsub_shard_channels_mu_); + + auto iter = pubsub_shard_channels_[slot].find(channel); + if (iter == pubsub_shard_channels_[slot].end()) { + return; + } + + for (const auto &conn_ctx : iter->second) { + if (conn->GetFD() == conn_ctx.fd && conn->Owner() == conn_ctx.owner) { + iter->second.remove(conn_ctx); + if (iter->second.empty()) { + pubsub_shard_channels_[slot].erase(iter); + } + break; + } + } +} + +void Server::GetSChannelsByPattern(const std::string &pattern, std::vector 
*channels) { + std::lock_guard guard(pubsub_shard_channels_mu_); + + for (const auto &shard_channels : pubsub_shard_channels_) { + for (const auto &iter : shard_channels) { + if (pattern.empty() || util::StringMatch(pattern, iter.first, 0)) { + channels->emplace_back(iter.first); + } + } + } +} + +void Server::ListSChannelSubscribeNum(const std::vector &channels, + std::vector *channel_subscribe_nums) { + std::lock_guard guard(pubsub_shard_channels_mu_); + + for (const auto &chan : channels) { + uint16_t slot = config_->cluster_enabled ? GetSlotIdFromKey(chan) : 0; + if (auto iter = pubsub_shard_channels_[slot].find(chan); iter != pubsub_shard_channels_[slot].end()) { + channel_subscribe_nums->emplace_back(ChannelSubscribeNum{iter->first, iter->second.size()}); + } else { + channel_subscribe_nums->emplace_back(ChannelSubscribeNum{chan, 0}); + } + } +} + void Server::BlockOnKey(const std::string &key, redis::Connection *conn) { std::lock_guard guard(blocking_keys_mu_); diff --git a/src/server/server.h b/src/server/server.h index 2acd0f5dbf1..a86eedf1cd8 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -201,6 +201,11 @@ class Server { void PSubscribeChannel(const std::string &pattern, redis::Connection *conn); void PUnsubscribeChannel(const std::string &pattern, redis::Connection *conn); size_t GetPubSubPatternSize() const { return pubsub_patterns_.size(); } + void SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot); + void SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot); + void GetSChannelsByPattern(const std::string &pattern, std::vector *channels); + void ListSChannelSubscribeNum(const std::vector &channels, + std::vector *channel_subscribe_nums); void BlockOnKey(const std::string &key, redis::Connection *conn); void UnblockOnKey(const std::string &key, redis::Connection *conn); @@ -351,6 +356,8 @@ class Server { std::map> pubsub_channels_; std::map> pubsub_patterns_; std::mutex 
pubsub_channels_mu_; + std::vector>> pubsub_shard_channels_; + std::mutex pubsub_shard_channels_mu_; std::map> blocking_keys_; std::mutex blocking_keys_mu_; diff --git a/tests/gocase/unit/pubsub/pubsubshard_test.go b/tests/gocase/unit/pubsub/pubsubshard_test.go new file mode 100644 index 00000000000..9e8b04cf79d --- /dev/null +++ b/tests/gocase/unit/pubsub/pubsubshard_test.go @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package pubsub + +import ( + "context" + "fmt" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestPubSubShard(t *testing.T) { + ctx := context.Background() + + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + csrv := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer csrv.Close() + crdb := csrv.NewClient() + defer func() { require.NoError(t, crdb.Close()) }() + + nodeID := "YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY" + require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODEID", nodeID).Err()) + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383", nodeID, csrv.Host(), csrv.Port()) + require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + rdbs := []*redis.Client{rdb, crdb} + + t.Run("SSUBSCRIBE PING", func(t *testing.T) { + pubsub := rdb.SSubscribe(ctx, "somechannel") + receiveType(t, pubsub, &redis.Subscription{}) + require.NoError(t, pubsub.Ping(ctx)) + require.NoError(t, pubsub.Ping(ctx)) + require.NoError(t, pubsub.SUnsubscribe(ctx, "somechannel")) + require.Equal(t, "PONG", rdb.Ping(ctx).Val()) + receiveType(t, pubsub, &redis.Pong{}) + receiveType(t, pubsub, &redis.Pong{}) + }) + + t.Run("SSUBSCRIBE/SUNSUBSCRIBE basic", func(t *testing.T) { + for _, c := range rdbs { + pubsub := c.SSubscribe(ctx, "singlechannel") + defer pubsub.Close() + + msg := receiveType(t, pubsub, &redis.Subscription{}) + require.EqualValues(t, 1, msg.Count) + require.EqualValues(t, "singlechannel", msg.Channel) + require.EqualValues(t, "ssubscribe", msg.Kind) + + err := pubsub.SSubscribe(ctx, "multichannel1{tag1}", "multichannel2{tag1}", "multichannel1{tag1}") + require.Nil(t, err) + require.EqualValues(t, 2, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, 
&redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + + err = pubsub.SSubscribe(ctx, "multichannel3{tag1}", "multichannel4{tag2}") + require.Nil(t, err) + if c == rdb { + require.EqualValues(t, 4, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 5, receiveType(t, pubsub, &redis.Subscription{}).Count) + } else { + // note: when cluster enabled, shard channels in single command must belong to the same slot + // reference: https://redis.io/commands/ssubscribe + _, err = pubsub.Receive(ctx) + require.EqualError(t, err, "ERR CROSSSLOT Keys in request don't hash to the same slot") + } + + err = pubsub.SUnsubscribe(ctx, "multichannel3{tag1}", "multichannel4{tag2}", "multichannel5{tag2}") + require.Nil(t, err) + if c == rdb { + require.EqualValues(t, 4, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + } else { + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + } + + err = pubsub.SUnsubscribe(ctx) + require.Nil(t, err) + msg = receiveType(t, pubsub, &redis.Subscription{}) + require.EqualValues(t, 2, msg.Count) + require.EqualValues(t, "sunsubscribe", msg.Kind) + require.EqualValues(t, 1, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 0, receiveType(t, pubsub, &redis.Subscription{}).Count) + } + }) + + t.Run("SSUBSCRIBE/SUNSUBSCRIBE with empty channel", func(t *testing.T) { + for _, c := range rdbs { + pubsub := c.SSubscribe(ctx) + defer pubsub.Close() + + err := pubsub.SUnsubscribe(ctx, "foo", "bar") + require.Nil(t, err) + require.EqualValues(t, 0, receiveType(t, pubsub, &redis.Subscription{}).Count) 
+ require.EqualValues(t, 0, receiveType(t, pubsub, &redis.Subscription{}).Count) + } + }) + + t.Run("SHARDNUMSUB returns numbers, not strings", func(t *testing.T) { + require.EqualValues(t, map[string]int64{ + "abc": 0, + "def": 0, + }, rdb.PubSubShardNumSub(ctx, "abc", "def").Val()) + }) + + t.Run("PUBSUB SHARDNUMSUB/SHARDCHANNELS", func(t *testing.T) { + for _, c := range rdbs { + pubsub := c.SSubscribe(ctx, "singlechannel") + defer pubsub.Close() + receiveType(t, pubsub, &redis.Subscription{}) + + err := pubsub.SSubscribe(ctx, "multichannel1{tag1}", "multichannel2{tag1}", "multichannel3{tag1}") + require.Nil(t, err) + receiveType(t, pubsub, &redis.Subscription{}) + receiveType(t, pubsub, &redis.Subscription{}) + receiveType(t, pubsub, &redis.Subscription{}) + + pubsub1 := c.SSubscribe(ctx, "multichannel1{tag1}") + defer pubsub1.Close() + + sc := c.PubSubShardChannels(ctx, "") + require.EqualValues(t, len(sc.Val()), 4) + sc = c.PubSubShardChannels(ctx, "multi*") + require.EqualValues(t, len(sc.Val()), 3) + + sn := c.PubSubShardNumSub(ctx) + require.EqualValues(t, len(sn.Val()), 0) + sn = c.PubSubShardNumSub(ctx, "singlechannel", "multichannel1{tag1}", "multichannel2{tag1}", "multichannel3{tag1}") + for i, k := range sn.Val() { + if i == "multichannel1{tag1}" { + require.EqualValues(t, k, 2) + } else { + require.EqualValues(t, k, 1) + } + } + } + }) +} From 814043468b8325ff80dab4fa0d616b6b0402517b Mon Sep 17 00:00:00 2001 From: hulk Date: Fri, 12 Jan 2024 18:16:51 +0800 Subject: [PATCH 4/6] Implement an unify key-value iterator for Kvrocks (#2004) Currently, we need to iterate all keys in the database in different places like the cluster migration and kvrocks2redis, but don't have an iterator for this purpose. It's very error-prone to implement this in different places since Kvrocks may add a new column family in the future, and we must be careful to iterate all keys in all column families. 
This would be a burden for maintenance, so we want to implement an iterator for iterating keys. ```C++ DBIter iter(storage, read_option); for (iter.Seek(); iter.Valid(); iter.Next()) { if (iter.Type() == kRedisString || iter.Type() == kRedisJSON) { // the string/json type didn't have subkeys continue; } auto subkey_iter = iter.GetSubKeyIterator(); for (subkey_iter.Seek(); subkey_iter.Valid(); subkey_iter.Next()) { // handle its subkey and value here } } ``` When using this iterator, it will iterate the metadata column family first and check its type; if it's not a string or JSON, it will then iterate the corresponding column family to get subkeys. That is, if we have a key foo with type hash, then the iterator will iterate foo and foo:field1, foo:field2, and so on. This solution brings these benefits: - The code looks more intuitive - We can reuse this iterator if we want to iterate keys only This closes #1989 --- src/storage/iterator.cc | 166 +++++++++++++++ src/storage/iterator.h | 82 ++++++++ tests/cppunit/iterator_test.cc | 366 +++++++++++++++++++++++++++++++++ 3 files changed, 614 insertions(+) create mode 100644 src/storage/iterator.cc create mode 100644 src/storage/iterator.h create mode 100644 tests/cppunit/iterator_test.cc diff --git a/src/storage/iterator.cc b/src/storage/iterator.cc new file mode 100644 index 00000000000..12238ceafc2 --- /dev/null +++ b/src/storage/iterator.cc @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "iterator.h" + +#include + +#include "db_util.h" + +namespace engine { +DBIterator::DBIterator(Storage* storage, rocksdb::ReadOptions read_options, int slot) + : storage_(storage), read_options_(std::move(read_options)), slot_(slot) { + metadata_cf_handle_ = storage_->GetCFHandle(kMetadataColumnFamilyName); + metadata_iter_ = util::UniqueIterator(storage_->NewIterator(read_options_, metadata_cf_handle_)); +} + +void DBIterator::Next() { + if (!Valid()) return; + + metadata_iter_->Next(); + nextUntilValid(); +} + +void DBIterator::nextUntilValid() { + // slot_ != -1 means we would like to iterate all keys in the slot + // so we can skip the afterwards keys if the slot id doesn't match + if (slot_ != -1 && metadata_iter_->Valid()) { + auto [_, user_key] = ExtractNamespaceKey(metadata_iter_->key(), storage_->IsSlotIdEncoded()); + // Release the iterator if the slot id doesn't match + if (GetSlotIdFromKey(user_key.ToString()) != slot_) { + Reset(); + return; + } + } + + while (metadata_iter_->Valid()) { + Metadata metadata(kRedisNone, false); + // Skip the metadata if it's expired + if (metadata.Decode(metadata_iter_->value()).ok() && !metadata.Expired()) { + metadata_ = metadata; + break; + } + metadata_iter_->Next(); + } +} + +bool DBIterator::Valid() const { return metadata_iter_ && metadata_iter_->Valid(); } + +Slice DBIterator::Key() const { return Valid() ? 
metadata_iter_->key() : Slice(); } + +std::tuple DBIterator::UserKey() const { + if (!Valid()) { + return {}; + } + return ExtractNamespaceKey(metadata_iter_->key(), slot_ != -1); +} + +Slice DBIterator::Value() const { return Valid() ? metadata_iter_->value() : Slice(); } + +RedisType DBIterator::Type() const { return Valid() ? metadata_.Type() : kRedisNone; } + +void DBIterator::Reset() { + if (metadata_iter_) metadata_iter_.reset(); +} + +void DBIterator::Seek(const std::string& target) { + if (!metadata_iter_) return; + + // Iterate with the slot id but storage didn't enable slot id encoding + if (slot_ != -1 && !storage_->IsSlotIdEncoded()) { + Reset(); + return; + } + std::string prefix = target; + if (slot_ != -1) { + // Use the slot id as the prefix if it's specified + prefix = ComposeSlotKeyPrefix(kDefaultNamespace, slot_) + target; + } + + metadata_iter_->Seek(prefix); + nextUntilValid(); +} + +std::unique_ptr DBIterator::GetSubKeyIterator() const { + if (!Valid()) { + return nullptr; + } + + // The string/json type doesn't have sub keys + RedisType type = metadata_.Type(); + if (type == kRedisNone || type == kRedisString || type == kRedisJson) { + return nullptr; + } + + auto prefix = InternalKey(Key(), "", metadata_.version, storage_->IsSlotIdEncoded()).Encode(); + return std::make_unique(storage_, read_options_, type, std::move(prefix)); +} + +SubKeyIterator::SubKeyIterator(Storage* storage, rocksdb::ReadOptions read_options, RedisType type, std::string prefix) + : storage_(storage), read_options_(std::move(read_options)), type_(type), prefix_(std::move(prefix)) { + if (type_ == kRedisStream) { + cf_handle_ = storage_->GetCFHandle(kStreamColumnFamilyName); + } else { + cf_handle_ = storage_->GetCFHandle(kSubkeyColumnFamilyName); + } + iter_ = util::UniqueIterator(storage_->NewIterator(read_options_, cf_handle_)); +} + +void SubKeyIterator::Next() { + if (!Valid()) return; + + iter_->Next(); + + if (!Valid()) return; + + if 
(!iter_->key().starts_with(prefix_)) { + Reset(); + } +} + +bool SubKeyIterator::Valid() const { return iter_ && iter_->Valid(); } + +Slice SubKeyIterator::Key() const { return Valid() ? iter_->key() : Slice(); } + +Slice SubKeyIterator::UserKey() const { + if (!Valid()) return {}; + + const InternalKey internal_key(iter_->key(), storage_->IsSlotIdEncoded()); + return internal_key.GetSubKey(); +} + +Slice SubKeyIterator::Value() const { return Valid() ? iter_->value() : Slice(); } + +void SubKeyIterator::Seek() { + if (!iter_) return; + + iter_->Seek(prefix_); + if (!iter_->Valid()) return; + // For the subkey iterator, it MUST contain the prefix key itself + if (!iter_->key().starts_with(prefix_)) { + Reset(); + } +} + +void SubKeyIterator::Reset() { + if (iter_) iter_.reset(); +} + +} // namespace engine diff --git a/src/storage/iterator.h b/src/storage/iterator.h new file mode 100644 index 00000000000..40b93bc3799 --- /dev/null +++ b/src/storage/iterator.h @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ +#pragma once + +#include +#include + +#include "storage.h" + +namespace engine { + +class SubKeyIterator { + public: + explicit SubKeyIterator(Storage *storage, rocksdb::ReadOptions read_options, RedisType type, std::string prefix); + ~SubKeyIterator() = default; + bool Valid() const; + void Seek(); + void Next(); + // return the raw key in rocksdb + Slice Key() const; + // return the user key without prefix + Slice UserKey() const; + Slice Value() const; + void Reset(); + + private: + Storage *storage_; + rocksdb::ReadOptions read_options_; + RedisType type_; + std::string prefix_; + std::unique_ptr iter_; + rocksdb::ColumnFamilyHandle *cf_handle_ = nullptr; +}; + +class DBIterator { + public: + explicit DBIterator(Storage *storage, rocksdb::ReadOptions read_options, int slot = -1); + ~DBIterator() = default; + + bool Valid() const; + void Seek(const std::string &target = ""); + void Next(); + // return the raw key in rocksdb + Slice Key() const; + // return the namespace and user key without prefix + std::tuple UserKey() const; + Slice Value() const; + RedisType Type() const; + void Reset(); + std::unique_ptr GetSubKeyIterator() const; + + private: + void nextUntilValid(); + + Storage *storage_; + rocksdb::ReadOptions read_options_; + int slot_ = -1; + Metadata metadata_ = Metadata(kRedisNone, false); + + rocksdb::ColumnFamilyHandle *metadata_cf_handle_ = nullptr; + std::unique_ptr metadata_iter_; + std::unique_ptr subkey_iter_; +}; + +} // namespace engine diff --git a/tests/cppunit/iterator_test.cc b/tests/cppunit/iterator_test.cc new file mode 100644 index 00000000000..4bbd24089ea --- /dev/null +++ b/tests/cppunit/iterator_test.cc @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "test_base.h" +#include "types/redis_string.h" + +class IteratorTest : public TestBase { + protected: + explicit IteratorTest() = default; + ~IteratorTest() override = default; + + void SetUp() override { + { // string + redis::String string(storage_, "test_ns0"); + string.Set("a", "1"); + string.Set("b", "2"); + string.Set("c", "3"); + // Make sure the key "c" is expired + auto s = string.Expire("c", 1); + ASSERT_TRUE(s.ok()); + string.Set("d", "4"); + } + + { // hash + uint64_t ret = 0; + redis::Hash hash(storage_, "test_ns1"); + hash.MSet("hash-1", {{"f0", "v0"}, {"f1", "v1"}, {"f2", "v2"}, {"f3", "v3"}}, false, &ret); + } + + { // set + uint64_t ret = 0; + redis::Set set(storage_, "test_ns2"); + set.Add("set-1", {"e0", "e1", "e2"}, &ret); + } + + { // sorted set + uint64_t ret = 0; + redis::ZSet zset(storage_, "test_ns3"); + auto mscores = std::vector{{"z0", 0}, {"z1", 1}, {"z2", 2}}; + zset.Add("zset-1", ZAddFlags(), &mscores, &ret); + } + + { // list + uint64_t ret = 0; + redis::List list(storage_, "test_ns4"); + list.Push("list-1", {"l0", "l1", "l2"}, false, &ret); + } + + { // stream + redis::Stream stream(storage_, "test_ns5"); + redis::StreamEntryID ret; + redis::StreamAddOptions options; + options.next_id_strategy = std::make_unique(); + 
stream.Add("stream-1", options, {"x0"}, &ret); + stream.Add("stream-1", options, {"x1"}, &ret); + stream.Add("stream-1", options, {"x2"}, &ret); + // TODO(@git-hulk): add stream group after it's finished + } + + { // bitmap + redis::Bitmap bitmap(storage_, "test_ns6"); + bool ret = false; + bitmap.SetBit("bitmap-1", 0, true, &ret); + bitmap.SetBit("bitmap-1", 8 * 1024, true, &ret); + bitmap.SetBit("bitmap-1", 2 * 8 * 1024, true, &ret); + } + + { // json + redis::Json json(storage_, "test_ns7"); + json.Set("json-1", "$", "{\"a\": 1, \"b\": 2}"); + json.Set("json-2", "$", "{\"a\": 1, \"b\": 2}"); + json.Set("json-3", "$", "{\"a\": 1, \"b\": 2}"); + json.Set("json-4", "$", "{\"a\": 1, \"b\": 2}"); + auto s = json.Expire("json-4", 1); + ASSERT_TRUE(s.ok()); + } + + { + // sorted integer + redis::Sortedint sortedint(storage_, "test_ns8"); + uint64_t ret = 0; + sortedint.Add("sortedint-1", {1, 2, 3}, &ret); + } + } +}; + +TEST_F(IteratorTest, AllKeys) { + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + std::vector live_keys = {"a", "b", "d", "hash-1", "set-1", "zset-1", "list-1", + "stream-1", "bitmap-1", "json-1", "json-2", "json-3", "sortedint-1"}; + std::reverse(live_keys.begin(), live_keys.end()); + for (iter.Seek(); iter.Valid(); iter.Next()) { + ASSERT_TRUE(!live_keys.empty()); + auto [_, user_key] = iter.UserKey(); + ASSERT_EQ(live_keys.back(), user_key.ToString()); + live_keys.pop_back(); + } + ASSERT_TRUE(live_keys.empty()); +} + +TEST_F(IteratorTest, BasicString) { + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + + std::vector expected_keys = {"a", "b", "d"}; + std::reverse(expected_keys.begin(), expected_keys.end()); + auto prefix = ComposeNamespaceKey("test_ns0", "", storage_->IsSlotIdEncoded()); + for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { + if (expected_keys.empty()) { + FAIL() << "Unexpected key: " << iter.Key().ToString(); + } + ASSERT_EQ(kRedisString, iter.Type()); + auto [ns, key] = 
iter.UserKey(); + ASSERT_EQ("test_ns0", ns.ToString()); + ASSERT_EQ(expected_keys.back(), key.ToString()); + expected_keys.pop_back(); + // Make sure there is no subkey iterator + ASSERT_TRUE(!iter.GetSubKeyIterator()); + } + // Make sure all keys are iterated except the expired one: "c" + ASSERT_TRUE(expected_keys.empty()); +} + +TEST_F(IteratorTest, BasicHash) { + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + auto prefix = ComposeNamespaceKey("test_ns1", "", storage_->IsSlotIdEncoded()); + for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { + ASSERT_EQ(kRedisHash, iter.Type()); + auto [ns, key] = iter.UserKey(); + ASSERT_EQ("test_ns1", ns.ToString()); + + auto subkey_iter = iter.GetSubKeyIterator(); + ASSERT_TRUE(subkey_iter); + std::vector expected_keys = {"f0", "f1", "f2", "f3"}; + std::reverse(expected_keys.begin(), expected_keys.end()); + for (subkey_iter->Seek(); subkey_iter->Valid(); subkey_iter->Next()) { + if (expected_keys.empty()) { + FAIL() << "Unexpected key: " << subkey_iter->UserKey().ToString(); + } + ASSERT_EQ(expected_keys.back(), subkey_iter->UserKey().ToString()); + expected_keys.pop_back(); + } + ASSERT_TRUE(expected_keys.empty()); + } +} + +TEST_F(IteratorTest, BasicSet) { + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + auto prefix = ComposeNamespaceKey("test_ns2", "", storage_->IsSlotIdEncoded()); + for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { + ASSERT_EQ(kRedisSet, iter.Type()); + auto [ns, key] = iter.UserKey(); + ASSERT_EQ("test_ns2", ns.ToString()); + + auto subkey_iter = iter.GetSubKeyIterator(); + ASSERT_TRUE(subkey_iter); + std::vector expected_keys = {"e0", "e1", "e2"}; + std::reverse(expected_keys.begin(), expected_keys.end()); + for (subkey_iter->Seek(); subkey_iter->Valid(); subkey_iter->Next()) { + if (expected_keys.empty()) { + FAIL() << "Unexpected key: " << subkey_iter->UserKey().ToString(); + } + 
ASSERT_EQ(expected_keys.back(), subkey_iter->UserKey().ToString()); + expected_keys.pop_back(); + } + ASSERT_TRUE(expected_keys.empty()); + } +} + +TEST_F(IteratorTest, BasicZSet) { + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + auto prefix = ComposeNamespaceKey("test_ns3", "", storage_->IsSlotIdEncoded()); + for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { + ASSERT_EQ(kRedisZSet, iter.Type()); + auto [ns, key] = iter.UserKey(); + ASSERT_EQ("test_ns3", ns.ToString()); + + auto subkey_iter = iter.GetSubKeyIterator(); + ASSERT_TRUE(subkey_iter); + std::vector expected_members = {"z0", "z1", "z2"}; + std::reverse(expected_members.begin(), expected_members.end()); + for (subkey_iter->Seek(); subkey_iter->Valid(); subkey_iter->Next()) { + if (expected_members.empty()) { + FAIL() << "Unexpected key: " << subkey_iter->UserKey().ToString(); + } + ASSERT_EQ(expected_members.back(), subkey_iter->UserKey().ToString()); + expected_members.pop_back(); + } + ASSERT_TRUE(expected_members.empty()); + } +} + +TEST_F(IteratorTest, BasicList) { + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + auto prefix = ComposeNamespaceKey("test_ns4", "", storage_->IsSlotIdEncoded()); + for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { + ASSERT_EQ(kRedisList, iter.Type()); + auto [ns, key] = iter.UserKey(); + ASSERT_EQ("test_ns4", ns.ToString()); + + auto subkey_iter = iter.GetSubKeyIterator(); + ASSERT_TRUE(subkey_iter); + std::vector expected_values = {"l0", "l1", "l2"}; + std::reverse(expected_values.begin(), expected_values.end()); + for (subkey_iter->Seek(); subkey_iter->Valid(); subkey_iter->Next()) { + if (expected_values.empty()) { + FAIL() << "Unexpected value: " << subkey_iter->Value().ToString(); + } + ASSERT_EQ(expected_values.back(), subkey_iter->Value().ToString()); + expected_values.pop_back(); + } + ASSERT_TRUE(expected_values.empty()); + } +} + +TEST_F(IteratorTest, 
BasicStream) { + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + auto prefix = ComposeNamespaceKey("test_ns5", "", storage_->IsSlotIdEncoded()); + for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { + ASSERT_EQ(kRedisStream, iter.Type()); + auto [ns, key] = iter.UserKey(); + ASSERT_EQ("test_ns5", ns.ToString()); + + auto subkey_iter = iter.GetSubKeyIterator(); + ASSERT_TRUE(subkey_iter); + std::vector expected_values = {"x0", "x1", "x2"}; + std::reverse(expected_values.begin(), expected_values.end()); + for (subkey_iter->Seek(); subkey_iter->Valid(); subkey_iter->Next()) { + if (expected_values.empty()) { + FAIL() << "Unexpected value: " << subkey_iter->Value().ToString(); + } + std::vector elems; + auto s = redis::DecodeRawStreamEntryValue(subkey_iter->Value().ToString(), &elems); + ASSERT_TRUE(s.IsOK() && !elems.empty()); + ASSERT_EQ(expected_values.back(), elems[0]); + expected_values.pop_back(); + } + ASSERT_TRUE(expected_values.empty()); + } +} + +TEST_F(IteratorTest, BasicBitmap) { + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + auto prefix = ComposeNamespaceKey("test_ns6", "", storage_->IsSlotIdEncoded()); + for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { + ASSERT_EQ(kRedisBitmap, iter.Type()); + auto [ns, key] = iter.UserKey(); + ASSERT_EQ("test_ns6", ns.ToString()); + + auto subkey_iter = iter.GetSubKeyIterator(); + ASSERT_TRUE(subkey_iter); + std::vector expected_values = {"\x1", "\x1", "\x1"}; + for (subkey_iter->Seek(); subkey_iter->Valid(); subkey_iter->Next()) { + if (expected_values.empty()) { + FAIL() << "Unexpected value: " << subkey_iter->Value().ToString(); + } + ASSERT_EQ(expected_values.back(), subkey_iter->Value().ToString()); + expected_values.pop_back(); + } + ASSERT_TRUE(expected_values.empty()); + } +} + +TEST_F(IteratorTest, BasicJSON) { + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + + std::vector expected_keys = 
{"json-1", "json-2", "json-3"}; + std::reverse(expected_keys.begin(), expected_keys.end()); + auto prefix = ComposeNamespaceKey("test_ns7", "", storage_->IsSlotIdEncoded()); + for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { + if (expected_keys.empty()) { + FAIL() << "Unexpected key: " << iter.Key().ToString(); + } + ASSERT_EQ(kRedisJson, iter.Type()); + auto [ns, key] = iter.UserKey(); + ASSERT_EQ("test_ns7", ns.ToString()); + ASSERT_EQ(expected_keys.back(), key.ToString()); + expected_keys.pop_back(); + // Make sure there is no subkey iterator + ASSERT_TRUE(!iter.GetSubKeyIterator()); + } + // Make sure all keys are iterated except the expired one: "json-4" + ASSERT_TRUE(expected_keys.empty()); +} + +TEST_F(IteratorTest, BasicSortedInt) { + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + + auto prefix = ComposeNamespaceKey("test_ns8", "", storage_->IsSlotIdEncoded()); + for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { + ASSERT_EQ(kRedisSortedint, iter.Type()); + auto [ns, key] = iter.UserKey(); + ASSERT_EQ("test_ns8", ns.ToString()); + + auto subkey_iter = iter.GetSubKeyIterator(); + ASSERT_TRUE(subkey_iter); + std::vector expected_keys = {1, 2, 3}; + std::reverse(expected_keys.begin(), expected_keys.end()); + for (subkey_iter->Seek(); subkey_iter->Valid(); subkey_iter->Next()) { + auto value = DecodeFixed64(subkey_iter->UserKey().data()); + if (expected_keys.empty()) { + FAIL() << "Unexpected value: " << value; + } + ASSERT_EQ(expected_keys.back(), value); + expected_keys.pop_back(); + } + } +} + +class SlotIteratorTest : public TestBase { + protected: + explicit SlotIteratorTest() = default; + ~SlotIteratorTest() override = default; + void SetUp() override { storage_->GetConfig()->slot_id_encoded = true; } +}; + +TEST_F(SlotIteratorTest, LiveKeys) { + redis::String string(storage_, kDefaultNamespace); + std::vector keys = {"{x}a", "{x}b", "{y}c", "{y}d", "{x}e"}; + for 
(const auto &key : keys) { + string.Set(key, "1"); + } + + std::set same_slot_keys; + auto slot_id = GetSlotIdFromKey(keys[0]); + for (const auto &key : keys) { + if (GetSlotIdFromKey(key) == slot_id) { + same_slot_keys.insert(key); + } + } + engine::DBIterator iter(storage_, rocksdb::ReadOptions(), slot_id); + int count = 0; + for (iter.Seek(); iter.Valid(); iter.Next()) { + auto [_, user_key] = iter.UserKey(); + ASSERT_EQ(slot_id, GetSlotIdFromKey(user_key.ToString())); + count++; + } + ASSERT_EQ(count, same_slot_keys.size()); +} From c02c0db6637304f14cd8000bc1a96e57e3c253b8 Mon Sep 17 00:00:00 2001 From: Aleks Lozovyuk Date: Sat, 13 Jan 2024 09:28:11 +0200 Subject: [PATCH 5/6] Bump rocksdb to 8.10.0 (#2005) --- cmake/rocksdb.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/rocksdb.cmake b/cmake/rocksdb.cmake index 8890627d2a4..bd05eb35323 100644 --- a/cmake/rocksdb.cmake +++ b/cmake/rocksdb.cmake @@ -26,8 +26,8 @@ endif() include(cmake/utils.cmake) FetchContent_DeclareGitHubWithMirror(rocksdb - facebook/rocksdb v8.9.1 - MD5=88f8d12c3d9ba10ddade370e0f12a010 + facebook/rocksdb v8.10.0 + MD5=ed06e98fae30c29cceacbfd45a316f06 ) FetchContent_GetProperties(jemalloc) From 329a9081035c80dd46ffae66709f59ee35e7c291 Mon Sep 17 00:00:00 2001 From: MaheshMadushan Date: Sun, 14 Jan 2024 16:38:53 +0000 Subject: [PATCH 6/6] Add ZDIFFSTORE implementation --- src/commands/cmd_zset.cc | 60 ++++++++++++++++++++---- src/types/redis_zset.cc | 8 ++++ src/types/redis_zset.h | 3 +- tests/gocase/unit/type/zset/zset_test.go | 48 ++++++++++++++++--- 4 files changed, 103 insertions(+), 16 deletions(-) diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc index e3cd072705d..dd536f42c45 100644 --- a/src/commands/cmd_zset.cc +++ b/src/commands/cmd_zset.cc @@ -1361,14 +1361,10 @@ class CommandZDiff : public Commander { public: Status Parse(const std::vector &args) override { auto parse_result = ParseInt(args[1], 10); - if (!parse_result) { - return 
{Status::RedisParseErr, errValueNotInteger}; - } + if (!parse_result) return {Status::RedisParseErr, errValueNotInteger}; numkeys_ = *parse_result; - if (numkeys_ > args.size() - 2) { - return {Status::RedisParseErr, errInvalidSyntax}; - } + if (numkeys_ > args.size() - 2) return {Status::RedisParseErr, errInvalidSyntax}; size_t j = 0; while (j < numkeys_) { @@ -1414,6 +1410,54 @@ class CommandZDiff : public Commander { bool with_scores_ {false}; }; +class CommandZDiffStore : public Commander { + public: + Status Parse(const std::vector &args) override { + auto parse_result = ParseInt(args[2], 10); + if (!parse_result) return {Status::RedisParseErr, errValueNotInteger}; + + numkeys_ = *parse_result; + if (numkeys_ > args.size() - 3) return {Status::RedisParseErr, errInvalidSyntax}; + + size_t j = 0; + while (j < numkeys_) { + keys_.emplace_back(args[j + 3]); + j++; + } + + if (auto i = 2 + numkeys_; i < args.size()) { + if (util::ToLower(args[i]) == "withscores") { + with_scores_ = true; + } + } + + return Commander::Parse(args); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::ZSet zset_db(srv->storage, conn->GetNamespace()); + + uint64_t stored_count; + LOG(INFO) << args_[1]; + auto s = zset_db.DiffStore(args_[1], keys_, &stored_count); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + *output = redis::Integer(stored_count); + return Status::OK(); + } + + static CommandKeyRange Range(const std::vector &args) { + int num_key = *ParseInt(args[1], 10); + return {3, 2 + num_key, 1}; + } + + protected: + size_t numkeys_ {0}; + std::vector keys_; + bool with_scores_ {false}; +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zcard", 2, "read-only", 1, 1, 1), MakeCmdAttr("zcount", 4, "read-only", 1, 1, 1), @@ -1446,6 +1490,6 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zscan", -3, "read-only", 1, 1, 1), MakeCmdAttr("zunionstore", -4, 
"write", CommandZUnionStore::Range), MakeCmdAttr("zunion", -3, "read-only", CommandZUnion::Range), - MakeCmdAttr("zunion", -3, "read-only", CommandZUnion::Range), - MakeCmdAttr("zdiff", -3, "read-only", CommandZDiff::Range), ) + MakeCmdAttr("zdiff", -3, "read-only", CommandZDiff::Range), + MakeCmdAttr("zdiffstore", -3, "read-only", CommandZDiffStore::Range), ) } // namespace redis diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc index 9abc5ae4edd..bba66364079 100644 --- a/src/types/redis_zset.cc +++ b/src/types/redis_zset.cc @@ -885,4 +885,12 @@ rocksdb::Status ZSet::Diff(const std::vector &keys, MemberScores *members return rocksdb::Status::OK(); } +rocksdb::Status ZSet::DiffStore(const Slice &dst, const std::vector &keys, uint64_t *stored_count) { + MemberScores mscores; + auto s = Diff(keys, &mscores); + if (!s.ok()) return s; + *stored_count = mscores.size(); + return Overwrite(dst, mscores); +} + } // namespace redis diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h index 005e3e68b73..92fbb4e6598 100644 --- a/src/types/redis_zset.h +++ b/src/types/redis_zset.h @@ -116,7 +116,8 @@ class ZSet : public SubKeyScanner { AggregateMethod aggregate_method, uint64_t *saved_cnt); rocksdb::Status Union(const std::vector &keys_weights, AggregateMethod aggregate_method, std::vector *members); - rocksdb::Status Diff(const std::vector &keys, std::vector *members); + rocksdb::Status Diff(const std::vector &keys, MemberScores *members); + rocksdb::Status DiffStore(const Slice &dst, const std::vector &keys, uint64_t *stored_count); rocksdb::Status MGet(const Slice &user_key, const std::vector &members, std::map *scores); rocksdb::Status GetMetadata(const Slice &ns_key, ZSetMetadata *metadata); diff --git a/tests/gocase/unit/type/zset/zset_test.go b/tests/gocase/unit/type/zset/zset_test.go index bd425716d96..ed71c9f1a6b 100644 --- a/tests/gocase/unit/type/zset/zset_test.go +++ b/tests/gocase/unit/type/zset/zset_test.go @@ -1333,11 +1333,16 @@ func 
basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s t.Run(fmt.Sprintf("ZDIFF with two sets - %s", encoding), func(t *testing.T) { createZset(rdb, ctx, "zseta", []redis.Z{ - {Score: 1, Member: "a"}, {Score: 2, Member: "b"}, {Score: 3, Member: "c"}, - {Score: 3, Member: "d"}, {Score: 4, Member: "e"}, + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, }) createZset(rdb, ctx, "zsetb", []redis.Z{ - {Score: 1, Member: "b"}, {Score: 2, Member: "c"}, {Score: 4, Member: "f"}, + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, }) cmd := rdb.ZDiff(ctx, "zseta", "zsetb") require.NoError(t, cmd.Err()) @@ -1347,14 +1352,21 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s t.Run(fmt.Sprintf("ZDIFF with three sets - %s", encoding), func(t *testing.T) { createZset(rdb, ctx, "zseta", []redis.Z{ - {Score: 1, Member: "a"}, {Score: 2, Member: "b"}, {Score: 3, Member: "c"}, - {Score: 3, Member: "d"}, {Score: 4, Member: "e"}, + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, }) createZset(rdb, ctx, "zsetb", []redis.Z{ - {Score: 1, Member: "b"}, {Score: 2, Member: "c"}, {Score: 4, Member: "f"}, + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, }) createZset(rdb, ctx, "zsetc", []redis.Z{ - {Score: 3, Member: "c"}, {Score: 3, Member: "d"}, {Score: 4, Member: "e"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 5, Member: "e"}, }) cmd := rdb.ZDiff(ctx, "zseta", "zsetb", "zsetc") require.NoError(t, cmd.Err()) @@ -1362,6 +1374,28 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s require.EqualValues(t, []string{"a"}, cmd.Val()) }) + t.Run(fmt.Sprintf("ZDIFF with three sets with scores - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", 
[]redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 4, Member: "c"}, + {Score: 5, Member: "e"}, + }) + cmd := rdb.ZDiffWithScores(ctx, "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), cmd.Val()) + }) + // t.Run("ZDIFFSTORE with three sets - ", func(t *testing.T) { // require.NoError(t, rdb.ZDiffStore(ctx, "setres", "set1", "set4", "set5").Err()) // cmd := rdb.SMembers(ctx, "setres")