Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract a rand sample helper and support negative sample count in "Set" #2113

Merged
merged 6 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 21 additions & 33 deletions src/types/redis_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "db_util.h"
#include "parse_util.h"
#include "sample_helper.h"

namespace redis {

Expand Down Expand Up @@ -389,43 +390,30 @@ rocksdb::Status Hash::RandField(const Slice &user_key, int64_t command_count, st
rocksdb::Status s = GetMetadata(ns_key, &metadata);
if (!s.ok()) return s;

uint64_t size = metadata.size;
std::vector<FieldValue> samples;
// TODO: Getting all values in Hash might be heavy, consider lazy-loading these values later
if (count == 0) return rocksdb::Status::OK();
s = GetAll(user_key, &samples, type);
if (!s.ok()) return s;
auto append_field_with_index = [field_values, &samples, type](uint64_t index) {
if (type == HashFetchType::kAll) {
field_values->emplace_back(samples[index].field, samples[index].value);
} else {
field_values->emplace_back(samples[index].field, "");
}
};
field_values->reserve(std::min(size, count));
if (!unique || count == 1) {
// Case 1: Negative count, randomly select elements or without parameter
std::mt19937 gen(std::random_device{}());
std::uniform_int_distribution<uint64_t> dis(0, size - 1);
for (uint64_t i = 0; i < count; i++) {
uint64_t index = dis(gen);
append_field_with_index(index);
}
} else if (size <= count) {
// Case 2: Requested count is greater than or equal to the number of elements inside the hash
for (uint64_t i = 0; i < size; i++) {
append_field_with_index(i);
}
} else {
// Case 3: Requested count is less than the number of elements inside the hash
std::vector<uint64_t> indices(size);
std::iota(indices.begin(), indices.end(), 0);
std::mt19937 gen(std::random_device{}());
std::shuffle(indices.begin(), indices.end(), gen); // use Fisher-Yates shuffle algorithm to randomize the order
for (uint64_t i = 0; i < count; i++) {
uint64_t index = indices[i];
append_field_with_index(index);
s = ExtractRandMemberFromSet<FieldValue>(
unique, count,
[this, user_key, type](std::vector<FieldValue> *elements) { return this->GetAll(user_key, elements, type); },
field_values);
if (!s.ok()) {
return s;
}
switch (type) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a DCHECK. Previously, the code went through append_field_with_index. However, GetAll is already called with "type", so it's not necessary to clean up or set the value again.

case HashFetchType::kAll:
break;
case HashFetchType::kOnlyKey: {
// GetAll should only fetch the keys here, so check that all the values are empty
for (const FieldValue &value : *field_values) {
DCHECK(value.value.empty());
}
break;
}
case HashFetchType::kOnlyValue:
// Unreachable.
DCHECK(false);
break;
}
return rocksdb::Status::OK();
}
Expand Down
72 changes: 29 additions & 43 deletions src/types/redis_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
#include <map>
#include <memory>
#include <optional>
#include <random>

#include "db_util.h"
#include "sample_helper.h"

namespace redis {

Expand Down Expand Up @@ -197,9 +197,14 @@ rocksdb::Status Set::MIsMember(const Slice &user_key, const std::vector<Slice> &
}

rocksdb::Status Set::Take(const Slice &user_key, std::vector<std::string> *members, int count, bool pop) {
int n = 0;
members->clear();
if (count <= 0) return rocksdb::Status::OK();
bool unique = true;
if (count == 0) return rocksdb::Status::OK();
if (count < 0) {
DCHECK(!pop);
count = -count;
unique = false;
}

std::string ns_key = AppendNamespacePrefix(user_key);

Expand All @@ -210,49 +215,30 @@ rocksdb::Status Set::Take(const Slice &user_key, std::vector<std::string> *membe
rocksdb::Status s = GetMetadata(ns_key, &metadata);
if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s;

auto batch = storage_->GetWriteBatchBase();
WriteBatchLogData log_data(kRedisSet);
batch->PutLogData(log_data.Encode());

std::string prefix = InternalKey(ns_key, "", metadata.version, storage_->IsSlotIdEncoded()).Encode();
std::string next_version_prefix = InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode();

rocksdb::ReadOptions read_options = storage_->DefaultScanOptions();
LatestSnapShot ss(storage_);
read_options.snapshot = ss.GetSnapShot();
rocksdb::Slice upper_bound(next_version_prefix);
read_options.iterate_upper_bound = &upper_bound;

std::vector<std::string> iter_keys;
iter_keys.reserve(count);
std::random_device rd;
std::mt19937 gen(rd());
auto iter = util::UniqueIterator(storage_, read_options);
for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix); iter->Next()) {
++n;
if (n <= count) {
iter_keys.push_back(iter->key().ToString());
} else { // n > count
std::uniform_int_distribution<> distrib(0, n - 1);
int random = distrib(gen); // [0,n-1]
if (random < count) {
iter_keys[random] = iter->key().ToString();
}
}
ObserverOrUniquePtr<rocksdb::WriteBatchBase> batch = storage_->GetWriteBatchBase();
if (pop) {
WriteBatchLogData log_data(kRedisSet);
batch->PutLogData(log_data.Encode());
}
for (Slice key : iter_keys) {
InternalKey ikey(key, storage_->IsSlotIdEncoded());
members->emplace_back(ikey.GetSubKey().ToString());
if (pop) {
batch->Delete(key);
}
members->clear();
s = ExtractRandMemberFromSet<std::string>(
unique, count, [this, user_key](std::vector<std::string> *samples) { return this->Members(user_key, samples); },
members);
if (!s.ok()) {
return s;
}
if (pop && !iter_keys.empty()) {
metadata.size -= iter_keys.size();
std::string bytes;
metadata.Encode(&bytes);
batch->Put(metadata_cf_handle_, ns_key, bytes);
// Avoid writing an empty op-log if we just randomly select some members.
if (!pop) return rocksdb::Status::OK();
// Avoid writing an empty op-log if the set is empty.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know whether we can do this optimization; would it affect replication?

if (members->empty()) return rocksdb::Status::OK();
for (std::string &user_sub_key : *members) {
std::string sub_key = InternalKey(ns_key, user_sub_key, metadata.version, storage_->IsSlotIdEncoded()).Encode();
batch->Delete(sub_key);
}
metadata.size -= members->size();
std::string bytes;
metadata.Encode(&bytes);
batch->Put(metadata_cf_handle_, ns_key, bytes);
return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
}

Expand Down
37 changes: 7 additions & 30 deletions src/types/redis_zset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
#include <map>
#include <memory>
#include <optional>
#include <random>
#include <set>

#include "db_util.h"
#include "sample_helper.h"

namespace redis {

Expand Down Expand Up @@ -900,35 +900,12 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key, int64_t command_count,
if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s;
if (metadata.size == 0) return rocksdb::Status::OK();

std::vector<MemberScore> samples;
s = GetAllMemberScores(user_key, &samples);
if (!s.ok() || samples.empty()) return s;

uint64_t size = samples.size();
member_scores->reserve(std::min(size, count));

if (!unique || count == 1) {
std::mt19937 gen(std::random_device{}());
std::uniform_int_distribution<uint64_t> dist(0, size - 1);
for (uint64_t i = 0; i < count; i++) {
uint64_t index = dist(gen);
member_scores->emplace_back(samples[index]);
}
} else if (size <= count) {
for (auto &sample : samples) {
member_scores->push_back(std::move(sample));
}
} else {
// first shuffle the samples
std::mt19937 gen(std::random_device{}());
std::shuffle(samples.begin(), samples.end(), gen);
// then pick the first `count` ones.
for (uint64_t i = 0; i < count; i++) {
member_scores->emplace_back(std::move(samples[i]));
}
}

return rocksdb::Status::OK();
return ExtractRandMemberFromSet<MemberScore>(
unique, count,
[this, user_key](std::vector<MemberScore> *scores) -> rocksdb::Status {
return this->GetAllMemberScores(user_key, scores);
},
member_scores);
}

rocksdb::Status ZSet::Diff(const std::vector<Slice> &keys, MemberScores *members) {
Expand Down
67 changes: 67 additions & 0 deletions src/types/sample_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include <rocksdb/status.h>

#include <algorithm>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

/// ExtractRandMemberFromSet is a helper function to extract random elements from a kvrocks structure.
///
/// All members are first loaded through `get_all_member_fn` and then sampled in memory:
///  - `unique == false` (or `count == 1`): `count` elements are drawn with replacement, so the
///    result may contain duplicates and may hold more entries than the structure itself.
///  - `unique == true` and `count` >= number of loaded elements: every loaded element is
///    returned, in the order `get_all_member_fn` produced them.
///  - `unique == true` and `count` < number of loaded elements: the loaded elements are shuffled
///    and the first `count` of them are returned.
///
/// The complexity of the function is O(N + count) where N is the number of elements inside the
/// structure (loading and shuffling are O(N), sampling with replacement is O(count)).
///
/// @param unique whether the returned elements must come from distinct positions.
/// @param count number of elements requested (already converted to a non-negative value).
/// @param get_all_member_fn callable invoked as `get_all_member_fn(&samples)` to load every member;
///        its result is stored in a rocksdb::Status.
/// @param elements output vector; it is cleared before anything else happens.
/// @return the status of `get_all_member_fn` when it fails or loads nothing, OK otherwise.
template <typename ElementType, typename GetAllMemberFnType>
rocksdb::Status ExtractRandMemberFromSet(bool unique, size_t count, const GetAllMemberFnType &get_all_member_fn,
                                         std::vector<ElementType> *elements) {
  elements->clear();
  std::vector<ElementType> samples;
  rocksdb::Status s = get_all_member_fn(&samples);
  if (!s.ok() || samples.empty()) return s;
  // Nothing was requested: skip the random generator setup and the shuffle entirely.
  // The members are still loaded above so that a failing lookup keeps being reported.
  if (count == 0) return rocksdb::Status::OK();

  size_t all_element_size = samples.size();
  DCHECK_GE(all_element_size, 1U);

  if (!unique || count == 1) {
    // Case 1: Negative count, randomly select elements or without parameter.
    // Exactly `count` elements are appended here, even when `count` is larger than the
    // structure, so reserve `count` rather than min(size, count).
    elements->reserve(count);
    std::mt19937 gen(std::random_device{}());
    std::uniform_int_distribution<size_t> dist(0, all_element_size - 1);
    for (size_t i = 0; i < count; i++) {
      // Copy instead of move: the same index may be drawn again.
      elements->emplace_back(samples[dist(gen)]);
    }
  } else if (all_element_size <= count) {
    // Case 2: Requested count is greater than or equal to the number of elements inside the structure.
    // Hand over the whole buffer at once instead of moving element by element.
    *elements = std::move(samples);
  } else {
    // Case 3: Requested count is less than the number of elements inside the structure
    elements->reserve(count);
    std::mt19937 gen(std::random_device{}());
    // use Fisher-Yates shuffle algorithm to randomize the order
    std::shuffle(samples.begin(), samples.end(), gen);
    // then pick the first `count` ones.
    for (size_t i = 0; i < count; i++) {
      elements->emplace_back(std::move(samples[i]));
    }
  }
  return rocksdb::Status::OK();
}
15 changes: 5 additions & 10 deletions tests/cppunit/types/set_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class RedisSetTest : public TestBase {
fields_ = {"set-key-1", "set-key-2", "set-key-3", "set-key-4"};
}

void TearDown() override { [[maybe_unused]] auto s = set_->Del(key_); }

std::unique_ptr<redis::Set> set_;
};

Expand All @@ -48,7 +50,6 @@ TEST_F(RedisSetTest, AddAndRemove) {
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Card(key_, &ret);
EXPECT_TRUE(s.ok() && ret == 0);
s = set_->Del(key_);
}

TEST_F(RedisSetTest, AddAndRemoveRepeated) {
Expand All @@ -65,8 +66,6 @@ TEST_F(RedisSetTest, AddAndRemoveRepeated) {
EXPECT_TRUE(s.ok() && (remembers.size() - 1) == ret);
set_->Card(key_, &card);
EXPECT_EQ(card, allmembers.size() - 1 - ret);

s = set_->Del(key_);
}

TEST_F(RedisSetTest, Members) {
Expand All @@ -82,7 +81,6 @@ TEST_F(RedisSetTest, Members) {
}
s = set_->Remove(key_, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Del(key_);
}

TEST_F(RedisSetTest, IsMember) {
Expand All @@ -98,7 +96,6 @@ TEST_F(RedisSetTest, IsMember) {
EXPECT_TRUE(s.ok() && !flag);
s = set_->Remove(key_, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Del(key_);
}

TEST_F(RedisSetTest, MIsMember) {
Expand All @@ -118,7 +115,6 @@ TEST_F(RedisSetTest, MIsMember) {
for (size_t i = 1; i < fields_.size(); i++) {
EXPECT_TRUE(exists[i] == 1);
}
s = set_->Del(key_);
}

TEST_F(RedisSetTest, Move) {
Expand All @@ -139,7 +135,6 @@ TEST_F(RedisSetTest, Move) {
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Remove(dst, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Del(key_);
s = set_->Del(dst);
}

Expand All @@ -157,7 +152,6 @@ TEST_F(RedisSetTest, TakeWithPop) {
s = set_->Take(key_, &members, 1, true);
EXPECT_TRUE(s.ok());
EXPECT_TRUE(s.ok() && members.size() == 0);
s = set_->Del(key_);
}

TEST_F(RedisSetTest, Diff) {
Expand Down Expand Up @@ -261,7 +255,6 @@ TEST_F(RedisSetTest, Overwrite) {
set_->Overwrite(key_, {"a"});
set_->Card(key_, &ret);
EXPECT_EQ(ret, 1);
s = set_->Del(key_);
}

TEST_F(RedisSetTest, TakeWithoutPop) {
Expand All @@ -275,7 +268,9 @@ TEST_F(RedisSetTest, TakeWithoutPop) {
s = set_->Take(key_, &members, int(fields_.size() - 1), false);
EXPECT_TRUE(s.ok());
EXPECT_EQ(members.size(), fields_.size() - 1);
s = set_->Take(key_, &members, -int(fields_.size() - 1), false);
EXPECT_TRUE(s.ok());
EXPECT_EQ(members.size(), fields_.size() - 1);
s = set_->Remove(key_, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Del(key_);
}
Loading