diff --git a/src/types/redis_hash.cc b/src/types/redis_hash.cc index d7a1ad8b7bc..dcb1978e599 100644 --- a/src/types/redis_hash.cc +++ b/src/types/redis_hash.cc @@ -30,6 +30,7 @@ #include "db_util.h" #include "parse_util.h" +#include "sample_helper.h" namespace redis { @@ -389,43 +390,30 @@ rocksdb::Status Hash::RandField(const Slice &user_key, int64_t command_count, st rocksdb::Status s = GetMetadata(ns_key, &metadata); if (!s.ok()) return s; - uint64_t size = metadata.size; std::vector samples; // TODO: Getting all values in Hash might be heavy, consider lazy-loading these values later if (count == 0) return rocksdb::Status::OK(); - s = GetAll(user_key, &samples, type); - if (!s.ok()) return s; - auto append_field_with_index = [field_values, &samples, type](uint64_t index) { - if (type == HashFetchType::kAll) { - field_values->emplace_back(samples[index].field, samples[index].value); - } else { - field_values->emplace_back(samples[index].field, ""); - } - }; - field_values->reserve(std::min(size, count)); - if (!unique || count == 1) { - // Case 1: Negative count, randomly select elements or without parameter - std::mt19937 gen(std::random_device{}()); - std::uniform_int_distribution dis(0, size - 1); - for (uint64_t i = 0; i < count; i++) { - uint64_t index = dis(gen); - append_field_with_index(index); - } - } else if (size <= count) { - // Case 2: Requested count is greater than or equal to the number of elements inside the hash - for (uint64_t i = 0; i < size; i++) { - append_field_with_index(i); - } - } else { - // Case 3: Requested count is less than the number of elements inside the hash - std::vector indices(size); - std::iota(indices.begin(), indices.end(), 0); - std::mt19937 gen(std::random_device{}()); - std::shuffle(indices.begin(), indices.end(), gen); // use Fisher-Yates shuffle algorithm to randomize the order - for (uint64_t i = 0; i < count; i++) { - uint64_t index = indices[i]; - append_field_with_index(index); + s = ExtractRandMemberFromSet( + unique, count, + [this, user_key, type](std::vector *elements) { return this->GetAll(user_key, elements, type); }, + field_values); + if (!s.ok()) { + return s; + } + switch (type) { + case HashFetchType::kAll: + break; + case HashFetchType::kOnlyKey: { + // GetAll should only fetching the key, checking all the values is empty + for (const FieldValue &value : *field_values) { + DCHECK(value.value.empty()); + } + break; } + case HashFetchType::kOnlyValue: + // Unreachable. + DCHECK(false); + break; } return rocksdb::Status::OK(); } diff --git a/src/types/redis_set.cc b/src/types/redis_set.cc index 98677e27560..35403b88ac8 100644 --- a/src/types/redis_set.cc +++ b/src/types/redis_set.cc @@ -23,9 +23,9 @@ #include #include #include -#include #include "db_util.h" +#include "sample_helper.h" namespace redis { @@ -197,9 +197,14 @@ rocksdb::Status Set::MIsMember(const Slice &user_key, const std::vector & } rocksdb::Status Set::Take(const Slice &user_key, std::vector *members, int count, bool pop) { - int n = 0; members->clear(); - if (count <= 0) return rocksdb::Status::OK(); + bool unique = true; + if (count == 0) return rocksdb::Status::OK(); + if (count < 0) { + DCHECK(!pop); + count = -count; + unique = false; + } std::string ns_key = AppendNamespacePrefix(user_key); @@ -210,49 +215,30 @@ rocksdb::Status Set::Take(const Slice &user_key, std::vector *membe rocksdb::Status s = GetMetadata(ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; - auto batch = storage_->GetWriteBatchBase(); - WriteBatchLogData log_data(kRedisSet); - batch->PutLogData(log_data.Encode()); - - std::string prefix = InternalKey(ns_key, "", metadata.version, storage_->IsSlotIdEncoded()).Encode(); - std::string next_version_prefix = InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode(); - - rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); - read_options.snapshot = ss.GetSnapShot(); - rocksdb::Slice upper_bound(next_version_prefix); - read_options.iterate_upper_bound = &upper_bound; - - std::vector iter_keys; - iter_keys.reserve(count); - std::random_device rd; - std::mt19937 gen(rd()); - auto iter = util::UniqueIterator(storage_, read_options); - for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix); iter->Next()) { - ++n; - if (n <= count) { - iter_keys.push_back(iter->key().ToString()); - } else { // n > count - std::uniform_int_distribution<> distrib(0, n - 1); - int random = distrib(gen); // [0,n-1] - if (random < count) { - iter_keys[random] = iter->key().ToString(); - } - } + ObserverOrUniquePtr batch = storage_->GetWriteBatchBase(); + if (pop) { + WriteBatchLogData log_data(kRedisSet); + batch->PutLogData(log_data.Encode()); } - for (Slice key : iter_keys) { - InternalKey ikey(key, storage_->IsSlotIdEncoded()); - members->emplace_back(ikey.GetSubKey().ToString()); - if (pop) { - batch->Delete(key); - } + members->clear(); + s = ExtractRandMemberFromSet( + unique, count, [this, user_key](std::vector *samples) { return this->Members(user_key, samples); }, + members); + if (!s.ok()) { + return s; } - if (pop && !iter_keys.empty()) { - metadata.size -= iter_keys.size(); - std::string bytes; - metadata.Encode(&bytes); - batch->Put(metadata_cf_handle_, ns_key, bytes); + // Avoid to write an empty op-log if just random select some members. + if (!pop) return rocksdb::Status::OK(); + // Avoid to write an empty op-log if the set is empty. + if (members->empty()) return rocksdb::Status::OK(); + for (std::string &user_sub_key : *members) { + std::string sub_key = InternalKey(ns_key, user_sub_key, metadata.version, storage_->IsSlotIdEncoded()).Encode(); + batch->Delete(sub_key); } + metadata.size -= members->size(); + std::string bytes; + metadata.Encode(&bytes); + batch->Put(metadata_cf_handle_, ns_key, bytes); return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); } diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc index 328182b5518..1dd2feb5f51 100644 --- a/src/types/redis_zset.cc +++ b/src/types/redis_zset.cc @@ -25,10 +25,10 @@ #include #include #include -#include #include #include "db_util.h" +#include "sample_helper.h" namespace redis { @@ -900,35 +900,12 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key, int64_t command_count, if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; if (metadata.size == 0) return rocksdb::Status::OK(); - std::vector samples; - s = GetAllMemberScores(user_key, &samples); - if (!s.ok() || samples.empty()) return s; - - uint64_t size = samples.size(); - member_scores->reserve(std::min(size, count)); - - if (!unique || count == 1) { - std::mt19937 gen(std::random_device{}()); - std::uniform_int_distribution dist(0, size - 1); - for (uint64_t i = 0; i < count; i++) { - uint64_t index = dist(gen); - member_scores->emplace_back(samples[index]); - } - } else if (size <= count) { - for (auto &sample : samples) { - member_scores->push_back(std::move(sample)); - } - } else { - // first shuffle the samples - std::mt19937 gen(std::random_device{}()); - std::shuffle(samples.begin(), samples.end(), gen); - // then pick the first `count` ones. - for (uint64_t i = 0; i < count; i++) { - member_scores->emplace_back(std::move(samples[i])); - } - } - - return rocksdb::Status::OK(); + return ExtractRandMemberFromSet( + unique, count, + [this, user_key](std::vector *scores) -> rocksdb::Status { + return this->GetAllMemberScores(user_key, scores); + }, + member_scores); } rocksdb::Status ZSet::Diff(const std::vector &keys, MemberScores *members) { diff --git a/src/types/sample_helper.h b/src/types/sample_helper.h new file mode 100644 index 00000000000..1f1618063f5 --- /dev/null +++ b/src/types/sample_helper.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include +#include + +/// ExtractRandMemberFromSet is a helper function to extract random elements from a kvrocks structure. +/// +/// The complexity of the function is O(N) where N is the number of elements inside the structure. +template +rocksdb::Status ExtractRandMemberFromSet(bool unique, size_t count, const GetAllMemberFnType &get_all_member_fn, + std::vector *elements) { + elements->clear(); + std::vector samples; + rocksdb::Status s = get_all_member_fn(&samples); + if (!s.ok() || samples.empty()) return s; + + size_t all_element_size = samples.size(); + DCHECK_GE(all_element_size, 1U); + elements->reserve(std::min(all_element_size, count)); + + if (!unique || count == 1) { + // Case 1: Negative count, randomly select elements or without parameter + std::mt19937 gen(std::random_device{}()); + std::uniform_int_distribution dist(0, all_element_size - 1); + for (uint64_t i = 0; i < count; i++) { + uint64_t index = dist(gen); + elements->emplace_back(samples[index]); + } + } else if (all_element_size <= count) { + // Case 2: Requested count is greater than or equal to the number of elements inside the structure + for (auto &sample : samples) { + elements->push_back(std::move(sample)); + } + } else { + // Case 3: Requested count is less than the number of elements inside the structure + std::mt19937 gen(std::random_device{}()); + // use Fisher-Yates shuffle algorithm to randomize the order + std::shuffle(samples.begin(), samples.end(), gen); + // then pick the first `count` ones. + for (uint64_t i = 0; i < count; i++) { + elements->emplace_back(std::move(samples[i])); + } + } + return rocksdb::Status::OK(); +} diff --git a/tests/cppunit/types/set_test.cc b/tests/cppunit/types/set_test.cc index 5e5c774b1cf..de94a611b68 100644 --- a/tests/cppunit/types/set_test.cc +++ b/tests/cppunit/types/set_test.cc @@ -35,6 +35,8 @@ class RedisSetTest : public TestBase { fields_ = {"set-key-1", "set-key-2", "set-key-3", "set-key-4"}; } + void TearDown() override { [[maybe_unused]] auto s = set_->Del(key_); } + std::unique_ptr set_; }; @@ -48,7 +50,6 @@ TEST_F(RedisSetTest, AddAndRemove) { EXPECT_TRUE(s.ok() && fields_.size() == ret); s = set_->Card(key_, &ret); EXPECT_TRUE(s.ok() && ret == 0); - s = set_->Del(key_); } TEST_F(RedisSetTest, AddAndRemoveRepeated) { @@ -65,8 +66,6 @@ TEST_F(RedisSetTest, AddAndRemoveRepeated) { EXPECT_TRUE(s.ok() && (remembers.size() - 1) == ret); set_->Card(key_, &card); EXPECT_EQ(card, allmembers.size() - 1 - ret); - - s = set_->Del(key_); } TEST_F(RedisSetTest, Members) { @@ -82,7 +81,6 @@ TEST_F(RedisSetTest, Members) { } s = set_->Remove(key_, fields_, &ret); EXPECT_TRUE(s.ok() && fields_.size() == ret); - s = set_->Del(key_); } TEST_F(RedisSetTest, IsMember) { @@ -98,7 +96,6 @@ TEST_F(RedisSetTest, IsMember) { EXPECT_TRUE(s.ok() && !flag); s = set_->Remove(key_, fields_, &ret); EXPECT_TRUE(s.ok() && fields_.size() == ret); - s = set_->Del(key_); } TEST_F(RedisSetTest, MIsMember) { @@ -118,7 +115,6 @@ TEST_F(RedisSetTest, MIsMember) { for (size_t i = 1; i < fields_.size(); i++) { EXPECT_TRUE(exists[i] == 1); } - s = set_->Del(key_); } TEST_F(RedisSetTest, Move) { @@ -139,7 +135,6 @@ TEST_F(RedisSetTest, Move) { EXPECT_TRUE(s.ok() && fields_.size() == ret); s = set_->Remove(dst, fields_, &ret); EXPECT_TRUE(s.ok() && fields_.size() == ret); - s = set_->Del(key_); s = set_->Del(dst); } @@ -157,7 +152,6 @@ TEST_F(RedisSetTest, TakeWithPop) { s = set_->Take(key_, &members, 1, true); EXPECT_TRUE(s.ok()); EXPECT_TRUE(s.ok() && members.size() == 0); - s = set_->Del(key_); } TEST_F(RedisSetTest, Diff) { @@ -261,7 +255,6 @@ TEST_F(RedisSetTest, Overwrite) { set_->Overwrite(key_, {"a"}); set_->Card(key_, &ret); EXPECT_EQ(ret, 1); - s = set_->Del(key_); } TEST_F(RedisSetTest, TakeWithoutPop) { @@ -275,7 +268,9 @@ TEST_F(RedisSetTest, TakeWithoutPop) { s = set_->Take(key_, &members, int(fields_.size() - 1), false); EXPECT_TRUE(s.ok()); EXPECT_EQ(members.size(), fields_.size() - 1); + s = set_->Take(key_, &members, -int(fields_.size() - 1), false); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(members.size(), fields_.size() - 1); s = set_->Remove(key_, fields_, &ret); EXPECT_TRUE(s.ok() && fields_.size() == ret); - s = set_->Del(key_); }