Skip to content

Commit

Permalink
Add support of new command: ZRANDMEMBER (#2016)
Browse files Browse the repository at this point in the history
  • Loading branch information
JxLi0921 authored Jan 15, 2024
1 parent a1cbd1a commit dad7494
Show file tree
Hide file tree
Showing 5 changed files with 379 additions and 1 deletion.
65 changes: 64 additions & 1 deletion src/commands/cmd_zset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
#include "commands/blocking_commander.h"
#include "commands/scan_base.h"
#include "error_constants.h"
#include "parse_util.h"
#include "rocksdb/env.h"
#include "server/redis_reply.h"
#include "server/server.h"
#include "string_util.h"
#include "types/redis_zset.h"

namespace redis {
Expand Down Expand Up @@ -1357,6 +1360,65 @@ class CommandZScan : public CommandSubkeyScanBase {
}
};

class CommandZRandMember : public Commander {
public:
CommandZRandMember() = default;

Status Parse(const std::vector<std::string> &args) override {
if (args.size() > 4) {
return {Status::RedisParseErr, errWrongNumOfArguments};
}

if (args.size() >= 3) {
no_parameters_ = false;
auto parse_result = ParseInt<int64_t>(args[2], 10);
if (!parse_result) {
return {Status::RedisParseErr, errValueNotInteger};
}
count_ = *parse_result;
}

if (args.size() == 4) {
if (util::ToLower(args[3]) == "withscores") {
with_scores_ = true;
} else {
return {Status::RedisParseErr, errInvalidSyntax};
}
}

return Commander::Parse(args);
}

Status Execute(Server *srv, Connection *conn, std::string *output) override {
redis::ZSet zset_db(srv->storage, conn->GetNamespace());
std::vector<MemberScore> member_scores;
auto s = zset_db.RandMember(args_[1], count_, &member_scores);

if (!s.ok() && !s.IsNotFound()) {
return {Status::RedisExecErr, s.ToString()};
}

std::vector<std::string> result_entries;
result_entries.reserve(member_scores.size());

for (const auto &[member, score] : member_scores) {
result_entries.emplace_back(member);
if (with_scores_) result_entries.emplace_back(util::Float2String(score));
}

if (no_parameters_)
*output = s.IsNotFound() ? redis::NilString() : redis::BulkString(result_entries[0]);
else
*output = redis::MultiBulkString(result_entries, false);
return Status::OK();
}

private:
int64_t count_ = 1;
bool with_scores_ = false;
bool no_parameters_ = true;
};

REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandZAdd>("zadd", -4, "write", 1, 1, 1),
MakeCmdAttr<CommandZCard>("zcard", 2, "read-only", 1, 1, 1),
MakeCmdAttr<CommandZCount>("zcount", 4, "read-only", 1, 1, 1),
Expand Down Expand Up @@ -1388,6 +1450,7 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandZAdd>("zadd", -4, "write", 1, 1, 1),
MakeCmdAttr<CommandZMScore>("zmscore", -3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandZScan>("zscan", -3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandZUnionStore>("zunionstore", -4, "write", CommandZUnionStore::Range),
MakeCmdAttr<CommandZUnion>("zunion", -3, "read-only", CommandZUnion::Range), )
MakeCmdAttr<CommandZUnion>("zunion", -3, "read-only", CommandZUnion::Range),
MakeCmdAttr<CommandZRandMember>("zrandmember", -2, "read-only", 1, 1, 1))

} // namespace redis
79 changes: 79 additions & 0 deletions src/types/redis_zset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <map>
#include <memory>
#include <optional>
#include <random>
#include <set>

#include "db_util.h"
Expand Down Expand Up @@ -851,4 +852,82 @@ rocksdb::Status ZSet::MGet(const Slice &user_key, const std::vector<Slice> &memb
return rocksdb::Status::OK();
}

rocksdb::Status ZSet::GetAllMemberScores(const Slice &user_key, std::vector<MemberScore> *member_scores) {
member_scores->clear();
std::string ns_key = AppendNamespacePrefix(user_key);
ZSetMetadata metadata(false);
rocksdb::Status s = GetMetadata(ns_key, &metadata);
if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s;

std::string prefix_key = InternalKey(ns_key, "", metadata.version, storage_->IsSlotIdEncoded()).Encode();
std::string next_version_prefix_key =
InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode();

rocksdb::ReadOptions read_options = storage_->DefaultScanOptions();
LatestSnapShot ss(storage_);
read_options.snapshot = ss.GetSnapShot();

rocksdb::Slice upper_bound(next_version_prefix_key);
rocksdb::Slice lower_bound(prefix_key);
read_options.iterate_upper_bound = &upper_bound;
read_options.iterate_lower_bound = &lower_bound;

auto iter = util::UniqueIterator(storage_, read_options, score_cf_handle_);

for (iter->Seek(prefix_key); iter->Valid() && iter->key().starts_with(prefix_key); iter->Next()) {
InternalKey ikey(iter->key(), storage_->IsSlotIdEncoded());
Slice score_key = ikey.GetSubKey();
double score = NAN;
GetDouble(&score_key, &score);
member_scores->emplace_back(MemberScore{score_key.ToString(), score});
}

return rocksdb::Status::OK();
}

rocksdb::Status ZSet::RandMember(const Slice &user_key, int64_t command_count,
std::vector<MemberScore> *member_scores) {
if (command_count == 0) {
return rocksdb::Status::OK();
}

uint64_t count = command_count > 0 ? static_cast<uint64_t>(command_count) : static_cast<uint64_t>(-command_count);
bool unique = (command_count >= 0);

std::string ns_key = AppendNamespacePrefix(user_key);
ZSetMetadata metadata(false);
rocksdb::Status s = GetMetadata(ns_key, &metadata);
if (!s.ok() || metadata.size == 0) return s;

std::vector<MemberScore> samples;
s = GetAllMemberScores(user_key, &samples);
if (!s.ok() || samples.empty()) return s;

auto size = static_cast<uint64_t>(samples.size());
member_scores->reserve(std::min(size, count));

if (!unique || count == 1) {
std::mt19937 gen(std::random_device{}());
std::uniform_int_distribution<uint64_t> dist(0, size - 1);
for (uint64_t i = 0; i < count; i++) {
uint64_t index = dist(gen);
member_scores->emplace_back(samples[index]);
}
} else if (size <= count) {
for (auto &sample : samples) {
member_scores->push_back(sample);
}
} else {
// first shuffle the samples
std::shuffle(samples.begin(), samples.end(), std::random_device{});

// then pick the first `count` ones.
for (uint64_t i = 0; i < count; i++) {
member_scores->emplace_back(samples[i]);
}
}

return rocksdb::Status::OK();
}

} // namespace redis
2 changes: 2 additions & 0 deletions src/types/redis_zset.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class ZSet : public SubKeyScanner {
uint64_t *removed_cnt);
rocksdb::Status RangeByLex(const Slice &user_key, const RangeLexSpec &spec, MemberScores *mscores,
uint64_t *removed_cnt);
rocksdb::Status GetAllMemberScores(const Slice &user_key, std::vector<MemberScore> *member_scores);
rocksdb::Status RandMember(const Slice &user_key, int64_t command_count, std::vector<MemberScore> *member_scores);

private:
rocksdb::ColumnFamilyHandle *score_cf_handle_;
Expand Down
102 changes: 102 additions & 0 deletions tests/cppunit/types/zset_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,105 @@ TEST_F(RedisZSetTest, Rank) {
}
auto s = zset_->Del(key_);
}

TEST_F(RedisZSetTest, RandMember) {
uint64_t ret = 0;
{
std::vector<MemberScore> in_mscores;
in_mscores.reserve(fields_.size());
for (size_t i = 0; i < fields_.size(); i++) {
in_mscores.emplace_back(MemberScore{fields_[i].ToString(), scores_[i]});
}
zset_->Add(key_, ZAddFlags::Default(), &in_mscores, &ret);
EXPECT_EQ(static_cast<int>(fields_.size()), ret);
}

std::unordered_map<std::string, double> member_map;
for (size_t i = 0; i < fields_.size(); i++) {
member_map[fields_[i].ToString()] = scores_[i];
}

// count = 0
{
std::vector<MemberScore> mscores;
rocksdb::Status s = zset_->RandMember(key_, 0, &mscores);
EXPECT_EQ(0, mscores.size());
EXPECT_TRUE(s.ok());
}

// count = 1/-1
for (int64_t count : {1, -1}) {
std::vector<MemberScore> mscores;
rocksdb::Status s = zset_->RandMember(key_, count, &mscores);
EXPECT_EQ(1, mscores.size());
EXPECT_TRUE(s.ok());
EXPECT_NE(member_map.find(mscores[0].member), member_map.end());
}

auto no_duplicate_members = [](const std::vector<MemberScore> &mscores) {
std::unordered_set<std::string> member_set;
for (const auto &mscore : mscores) {
if (member_set.find(mscore.member) != member_set.end()) {
return false;
}
member_set.insert(mscore.member);
}
return true;
};

auto no_non_exist_members = [&member_map](const std::vector<MemberScore> &mscores) {
for (const auto &mscore : mscores) {
const auto find_res = member_map.find(mscore.member);
if (find_res == member_map.end() || find_res->second != mscore.score) {
return false;
}
}
return true;
};

// count > 1, but count <= fields_.size()
for (int64_t count : {
static_cast<int64_t>(fields_.size()),
static_cast<int64_t>(fields_.size() / 2),
}) {
std::vector<MemberScore> mscores;
rocksdb::Status s = zset_->RandMember(key_, count, &mscores);
EXPECT_EQ(static_cast<size_t>(count), mscores.size());
EXPECT_TRUE(s.ok());
ASSERT_TRUE(no_non_exist_members(mscores));
ASSERT_TRUE(no_duplicate_members(mscores));
}

// count < -1, but count >= -fields_.size()
for (int64_t count : {
-static_cast<int64_t>(fields_.size()),
-static_cast<int64_t>(fields_.size() / 2),
}) {
std::vector<MemberScore> mscores;
rocksdb::Status s = zset_->RandMember(key_, count, &mscores);
EXPECT_EQ(static_cast<size_t>(-count), mscores.size());
EXPECT_TRUE(s.ok());
ASSERT_TRUE(no_non_exist_members(mscores));
}

// cout < -fields_.size() or count > fields_.size()

for (int64_t count : {
static_cast<int64_t>(fields_.size() + 10),
-static_cast<int64_t>(fields_.size() + 10),
}) {
std::vector<MemberScore> mscores;
rocksdb::Status s = zset_->RandMember(key_, count, &mscores);
EXPECT_TRUE(s.ok());
ASSERT_TRUE(no_non_exist_members(mscores));
if (count > 0) {
EXPECT_EQ(fields_.size(), mscores.size());
ASSERT_TRUE(no_duplicate_members(mscores));
} else {
EXPECT_EQ(static_cast<size_t>(-count), mscores.size());
}
}

auto s = zset_->Del(key_);
EXPECT_TRUE(s.ok());
}
Loading

0 comments on commit dad7494

Please sign in to comment.