Skip to content

Commit

Permalink
Extract a rand sample helper and support negative sample count in "Se…
Browse files Browse the repository at this point in the history
…t" (apache#2113)
  • Loading branch information
mapleFU authored Feb 26, 2024
1 parent 7571034 commit 80ca6b0
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 116 deletions.
54 changes: 21 additions & 33 deletions src/types/redis_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "db_util.h"
#include "parse_util.h"
#include "sample_helper.h"

namespace redis {

Expand Down Expand Up @@ -389,43 +390,30 @@ rocksdb::Status Hash::RandField(const Slice &user_key, int64_t command_count, st
rocksdb::Status s = GetMetadata(ns_key, &metadata);
if (!s.ok()) return s;

uint64_t size = metadata.size;
std::vector<FieldValue> samples;
// TODO: Getting all values in Hash might be heavy, consider lazy-loading these values later
if (count == 0) return rocksdb::Status::OK();
s = GetAll(user_key, &samples, type);
if (!s.ok()) return s;
auto append_field_with_index = [field_values, &samples, type](uint64_t index) {
if (type == HashFetchType::kAll) {
field_values->emplace_back(samples[index].field, samples[index].value);
} else {
field_values->emplace_back(samples[index].field, "");
}
};
field_values->reserve(std::min(size, count));
if (!unique || count == 1) {
// Case 1: Negative count, randomly select elements or without parameter
std::mt19937 gen(std::random_device{}());
std::uniform_int_distribution<uint64_t> dis(0, size - 1);
for (uint64_t i = 0; i < count; i++) {
uint64_t index = dis(gen);
append_field_with_index(index);
}
} else if (size <= count) {
// Case 2: Requested count is greater than or equal to the number of elements inside the hash
for (uint64_t i = 0; i < size; i++) {
append_field_with_index(i);
}
} else {
// Case 3: Requested count is less than the number of elements inside the hash
std::vector<uint64_t> indices(size);
std::iota(indices.begin(), indices.end(), 0);
std::mt19937 gen(std::random_device{}());
std::shuffle(indices.begin(), indices.end(), gen); // use Fisher-Yates shuffle algorithm to randomize the order
for (uint64_t i = 0; i < count; i++) {
uint64_t index = indices[i];
append_field_with_index(index);
s = ExtractRandMemberFromSet<FieldValue>(
unique, count,
[this, user_key, type](std::vector<FieldValue> *elements) { return this->GetAll(user_key, elements, type); },
field_values);
if (!s.ok()) {
return s;
}
switch (type) {
case HashFetchType::kAll:
break;
case HashFetchType::kOnlyKey: {
// GetAll should only fetching the key, checking all the values is empty
for (const FieldValue &value : *field_values) {
DCHECK(value.value.empty());
}
break;
}
case HashFetchType::kOnlyValue:
// Unreachable.
DCHECK(false);
break;
}
return rocksdb::Status::OK();
}
Expand Down
72 changes: 29 additions & 43 deletions src/types/redis_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
#include <map>
#include <memory>
#include <optional>
#include <random>

#include "db_util.h"
#include "sample_helper.h"

namespace redis {

Expand Down Expand Up @@ -197,9 +197,14 @@ rocksdb::Status Set::MIsMember(const Slice &user_key, const std::vector<Slice> &
}

rocksdb::Status Set::Take(const Slice &user_key, std::vector<std::string> *members, int count, bool pop) {
int n = 0;
members->clear();
if (count <= 0) return rocksdb::Status::OK();
bool unique = true;
if (count == 0) return rocksdb::Status::OK();
if (count < 0) {
DCHECK(!pop);
count = -count;
unique = false;
}

std::string ns_key = AppendNamespacePrefix(user_key);

Expand All @@ -210,49 +215,30 @@ rocksdb::Status Set::Take(const Slice &user_key, std::vector<std::string> *membe
rocksdb::Status s = GetMetadata(ns_key, &metadata);
if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s;

auto batch = storage_->GetWriteBatchBase();
WriteBatchLogData log_data(kRedisSet);
batch->PutLogData(log_data.Encode());

std::string prefix = InternalKey(ns_key, "", metadata.version, storage_->IsSlotIdEncoded()).Encode();
std::string next_version_prefix = InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode();

rocksdb::ReadOptions read_options = storage_->DefaultScanOptions();
LatestSnapShot ss(storage_);
read_options.snapshot = ss.GetSnapShot();
rocksdb::Slice upper_bound(next_version_prefix);
read_options.iterate_upper_bound = &upper_bound;

std::vector<std::string> iter_keys;
iter_keys.reserve(count);
std::random_device rd;
std::mt19937 gen(rd());
auto iter = util::UniqueIterator(storage_, read_options);
for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix); iter->Next()) {
++n;
if (n <= count) {
iter_keys.push_back(iter->key().ToString());
} else { // n > count
std::uniform_int_distribution<> distrib(0, n - 1);
int random = distrib(gen); // [0,n-1]
if (random < count) {
iter_keys[random] = iter->key().ToString();
}
}
ObserverOrUniquePtr<rocksdb::WriteBatchBase> batch = storage_->GetWriteBatchBase();
if (pop) {
WriteBatchLogData log_data(kRedisSet);
batch->PutLogData(log_data.Encode());
}
for (Slice key : iter_keys) {
InternalKey ikey(key, storage_->IsSlotIdEncoded());
members->emplace_back(ikey.GetSubKey().ToString());
if (pop) {
batch->Delete(key);
}
members->clear();
s = ExtractRandMemberFromSet<std::string>(
unique, count, [this, user_key](std::vector<std::string> *samples) { return this->Members(user_key, samples); },
members);
if (!s.ok()) {
return s;
}
if (pop && !iter_keys.empty()) {
metadata.size -= iter_keys.size();
std::string bytes;
metadata.Encode(&bytes);
batch->Put(metadata_cf_handle_, ns_key, bytes);
// Avoid to write an empty op-log if just random select some members.
if (!pop) return rocksdb::Status::OK();
// Avoid to write an empty op-log if the set is empty.
if (members->empty()) return rocksdb::Status::OK();
for (std::string &user_sub_key : *members) {
std::string sub_key = InternalKey(ns_key, user_sub_key, metadata.version, storage_->IsSlotIdEncoded()).Encode();
batch->Delete(sub_key);
}
metadata.size -= members->size();
std::string bytes;
metadata.Encode(&bytes);
batch->Put(metadata_cf_handle_, ns_key, bytes);
return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
}

Expand Down
37 changes: 7 additions & 30 deletions src/types/redis_zset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
#include <map>
#include <memory>
#include <optional>
#include <random>
#include <set>

#include "db_util.h"
#include "sample_helper.h"

namespace redis {

Expand Down Expand Up @@ -900,35 +900,12 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key, int64_t command_count,
if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s;
if (metadata.size == 0) return rocksdb::Status::OK();

std::vector<MemberScore> samples;
s = GetAllMemberScores(user_key, &samples);
if (!s.ok() || samples.empty()) return s;

uint64_t size = samples.size();
member_scores->reserve(std::min(size, count));

if (!unique || count == 1) {
std::mt19937 gen(std::random_device{}());
std::uniform_int_distribution<uint64_t> dist(0, size - 1);
for (uint64_t i = 0; i < count; i++) {
uint64_t index = dist(gen);
member_scores->emplace_back(samples[index]);
}
} else if (size <= count) {
for (auto &sample : samples) {
member_scores->push_back(std::move(sample));
}
} else {
// first shuffle the samples
std::mt19937 gen(std::random_device{}());
std::shuffle(samples.begin(), samples.end(), gen);
// then pick the first `count` ones.
for (uint64_t i = 0; i < count; i++) {
member_scores->emplace_back(std::move(samples[i]));
}
}

return rocksdb::Status::OK();
return ExtractRandMemberFromSet<MemberScore>(
unique, count,
[this, user_key](std::vector<MemberScore> *scores) -> rocksdb::Status {
return this->GetAllMemberScores(user_key, scores);
},
member_scores);
}

rocksdb::Status ZSet::Diff(const std::vector<Slice> &keys, MemberScores *members) {
Expand Down
67 changes: 67 additions & 0 deletions src/types/sample_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include <rocksdb/status.h>

#include <random>
#include <vector>

/// ExtractRandMemberFromSet is a helper function to extract random elements from a kvrocks structure.
///
/// The complexity of the function is O(N) where N is the number of elements inside the structure.
template <typename ElementType, typename GetAllMemberFnType>
rocksdb::Status ExtractRandMemberFromSet(bool unique, size_t count, const GetAllMemberFnType &get_all_member_fn,
std::vector<ElementType> *elements) {
elements->clear();
std::vector<ElementType> samples;
rocksdb::Status s = get_all_member_fn(&samples);
if (!s.ok() || samples.empty()) return s;

size_t all_element_size = samples.size();
DCHECK_GE(all_element_size, 1U);
elements->reserve(std::min(all_element_size, count));

if (!unique || count == 1) {
// Case 1: Negative count, randomly select elements or without parameter
std::mt19937 gen(std::random_device{}());
std::uniform_int_distribution<uint64_t> dist(0, all_element_size - 1);
for (uint64_t i = 0; i < count; i++) {
uint64_t index = dist(gen);
elements->emplace_back(samples[index]);
}
} else if (all_element_size <= count) {
// Case 2: Requested count is greater than or equal to the number of elements inside the structure
for (auto &sample : samples) {
elements->push_back(std::move(sample));
}
} else {
// Case 3: Requested count is less than the number of elements inside the structure
std::mt19937 gen(std::random_device{}());
// use Fisher-Yates shuffle algorithm to randomize the order
std::shuffle(samples.begin(), samples.end(), gen);
// then pick the first `count` ones.
for (uint64_t i = 0; i < count; i++) {
elements->emplace_back(std::move(samples[i]));
}
}
return rocksdb::Status::OK();
}
15 changes: 5 additions & 10 deletions tests/cppunit/types/set_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class RedisSetTest : public TestBase {
fields_ = {"set-key-1", "set-key-2", "set-key-3", "set-key-4"};
}

void TearDown() override { [[maybe_unused]] auto s = set_->Del(key_); }

std::unique_ptr<redis::Set> set_;
};

Expand All @@ -48,7 +50,6 @@ TEST_F(RedisSetTest, AddAndRemove) {
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Card(key_, &ret);
EXPECT_TRUE(s.ok() && ret == 0);
s = set_->Del(key_);
}

TEST_F(RedisSetTest, AddAndRemoveRepeated) {
Expand All @@ -65,8 +66,6 @@ TEST_F(RedisSetTest, AddAndRemoveRepeated) {
EXPECT_TRUE(s.ok() && (remembers.size() - 1) == ret);
set_->Card(key_, &card);
EXPECT_EQ(card, allmembers.size() - 1 - ret);

s = set_->Del(key_);
}

TEST_F(RedisSetTest, Members) {
Expand All @@ -82,7 +81,6 @@ TEST_F(RedisSetTest, Members) {
}
s = set_->Remove(key_, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Del(key_);
}

TEST_F(RedisSetTest, IsMember) {
Expand All @@ -98,7 +96,6 @@ TEST_F(RedisSetTest, IsMember) {
EXPECT_TRUE(s.ok() && !flag);
s = set_->Remove(key_, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Del(key_);
}

TEST_F(RedisSetTest, MIsMember) {
Expand All @@ -118,7 +115,6 @@ TEST_F(RedisSetTest, MIsMember) {
for (size_t i = 1; i < fields_.size(); i++) {
EXPECT_TRUE(exists[i] == 1);
}
s = set_->Del(key_);
}

TEST_F(RedisSetTest, Move) {
Expand All @@ -139,7 +135,6 @@ TEST_F(RedisSetTest, Move) {
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Remove(dst, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Del(key_);
s = set_->Del(dst);
}

Expand All @@ -157,7 +152,6 @@ TEST_F(RedisSetTest, TakeWithPop) {
s = set_->Take(key_, &members, 1, true);
EXPECT_TRUE(s.ok());
EXPECT_TRUE(s.ok() && members.size() == 0);
s = set_->Del(key_);
}

TEST_F(RedisSetTest, Diff) {
Expand Down Expand Up @@ -261,7 +255,6 @@ TEST_F(RedisSetTest, Overwrite) {
set_->Overwrite(key_, {"a"});
set_->Card(key_, &ret);
EXPECT_EQ(ret, 1);
s = set_->Del(key_);
}

TEST_F(RedisSetTest, TakeWithoutPop) {
Expand All @@ -275,7 +268,9 @@ TEST_F(RedisSetTest, TakeWithoutPop) {
s = set_->Take(key_, &members, int(fields_.size() - 1), false);
EXPECT_TRUE(s.ok());
EXPECT_EQ(members.size(), fields_.size() - 1);
s = set_->Take(key_, &members, -int(fields_.size() - 1), false);
EXPECT_TRUE(s.ok());
EXPECT_EQ(members.size(), fields_.size() - 1);
s = set_->Remove(key_, fields_, &ret);
EXPECT_TRUE(s.ok() && fields_.size() == ret);
s = set_->Del(key_);
}

0 comments on commit 80ca6b0

Please sign in to comment.