diff --git a/src/commands/cmd_hash.cc b/src/commands/cmd_hash.cc index c6e2ddf6f33..1e0947e79c2 100644 --- a/src/commands/cmd_hash.cc +++ b/src/commands/cmd_hash.cc @@ -377,6 +377,56 @@ class CommandHScan : public CommandSubkeyScanBase { } }; +class CommandHRandField : public Commander { + public: + Status Parse(const std::vector &args) override { + if (args.size() >= 3) { + no_parameters_ = false; + auto parse_result = ParseInt(args[2], 10); + if (!parse_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + command_count_ = *parse_result; + + if (args.size() > 4 || (args.size() == 4 && !util::EqualICase(args[3], "withvalues"))) { + return {Status::RedisParseErr, errInvalidSyntax}; + } else if (args.size() == 4) { + withvalues_ = true; + } + } + return Commander::Parse(args); + } + + Status Execute(Server *svr, Connection *conn, std::string *output) override { + redis::Hash hash_db(svr->storage, conn->GetNamespace()); + std::vector field_values; + + auto s = hash_db.RandField(args_[1], command_count_, &field_values, + withvalues_ ? HashFetchType::kAll : HashFetchType::kOnlyKey); + if (!s.ok() && !s.IsNotFound()) { + return {Status::RedisExecErr, s.ToString()}; + } + + std::vector result_entries; + result_entries.reserve(field_values.size()); + for (const auto &p : field_values) { + result_entries.emplace_back(p.field); + if (withvalues_) result_entries.emplace_back(p.value); + } + + if (no_parameters_) + *output = s.IsNotFound() ? redis::NilString() : redis::BulkString(result_entries[0]); + else + *output = redis::MultiBulkString(result_entries, false); + return Status::OK(); + } + + private: + bool withvalues_ = false; + int64_t command_count_ = 1; + bool no_parameters_ = true; +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("hget", 3, "read-only", 1, 1, 1), MakeCmdAttr("hincrby", 4, "write", 1, 1, 1), MakeCmdAttr("hincrbyfloat", 4, "write", 1, 1, 1), @@ -392,6 +442,7 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("hget", 3, "read-only", 1, 1, 1 MakeCmdAttr("hvals", 2, "read-only", 1, 1, 1), MakeCmdAttr("hgetall", 2, "read-only", 1, 1, 1), MakeCmdAttr("hscan", -3, "read-only", 1, 1, 1), - MakeCmdAttr("hrangebylex", -4, "read-only", 1, 1, 1), ) + MakeCmdAttr("hrangebylex", -4, "read-only", 1, 1, 1), + MakeCmdAttr("hrandfield", -2, "read-only", 1, 1, 1), ) } // namespace redis diff --git a/src/types/redis_hash.cc b/src/types/redis_hash.cc index 2946f47edcc..b61f8219d4d 100644 --- a/src/types/redis_hash.cc +++ b/src/types/redis_hash.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include "db_util.h" @@ -384,4 +385,57 @@ rocksdb::Status Hash::Scan(const Slice &user_key, const std::string &cursor, uin return SubKeyScanner::Scan(kRedisHash, user_key, cursor, limit, field_prefix, fields, values); } +rocksdb::Status Hash::RandField(const Slice &user_key, int64_t command_count, std::vector *field_values, + HashFetchType type) { + uint64_t count = (command_count >= 0) ? static_cast(command_count) : static_cast(-command_count); + bool unique = (command_count >= 0); + + std::string ns_key; + AppendNamespacePrefix(user_key, &ns_key); + HashMetadata metadata(/*generate_version=*/false); + rocksdb::Status s = GetMetadata(ns_key, &metadata); + if (!s.ok()) return s; + + uint64_t size = metadata.size; + std::vector samples; + // TODO: Getting all values in Hash might be heavy, consider lazy-loading these values later + if (count == 0) return rocksdb::Status::OK(); + s = GetAll(user_key, &samples, type); + if (!s.ok()) return s; + auto append_field_with_index = [field_values, &samples, type](uint64_t index) { + if (type == HashFetchType::kAll) { + field_values->emplace_back(samples[index].field, samples[index].value); + } else { + field_values->emplace_back(samples[index].field, ""); + } + }; + field_values->reserve(std::min(size, count)); + if (!unique || count == 1) { + // Case 1: Negative count, randomly select elements or without parameter + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, size - 1); + for (uint64_t i = 0; i < count; i++) { + uint64_t index = dis(gen); + append_field_with_index(index); + } + } else if (size <= count) { + // Case 2: Requested count is greater than or equal to the number of elements inside the hash + for (uint64_t i = 0; i < size; i++) { + append_field_with_index(i); + } + } else { + // Case 3: Requested count is less than the number of elements inside the hash + std::vector indices(size); + std::iota(indices.begin(), indices.end(), 0); + std::shuffle(indices.begin(), indices.end(), + std::random_device{}); // use Fisher-Yates shuffle algorithm to randomize the order + for (uint64_t i = 0; i < count; i++) { + uint64_t index = indices[i]; + append_field_with_index(index); + } + } + return rocksdb::Status::OK(); +} + } // namespace redis diff --git a/src/types/redis_hash.h b/src/types/redis_hash.h index ee17583478b..fc004ed840d 100644 --- a/src/types/redis_hash.h +++ b/src/types/redis_hash.h @@ -61,6 +61,8 @@ class Hash : public SubKeyScanner { rocksdb::Status Scan(const Slice &user_key, const std::string &cursor, uint64_t limit, const std::string &field_prefix, std::vector *fields, std::vector *values = nullptr); + rocksdb::Status RandField(const Slice &user_key, int64_t command_count, std::vector *field_values, + HashFetchType type = HashFetchType::kOnlyKey); private: rocksdb::Status GetMetadata(const Slice &ns_key, HashMetadata *metadata); diff --git a/tests/cppunit/types/hash_test.cc b/tests/cppunit/types/hash_test.cc index 61745029270..216acc7245e 100644 --- a/tests/cppunit/types/hash_test.cc +++ b/tests/cppunit/types/hash_test.cc @@ -331,3 +331,34 @@ TEST_F(RedisHashTest, HRangeByLexNonExistingKey) { EXPECT_TRUE(s.ok()); EXPECT_EQ(result.size(), 0); } + +TEST_F(RedisHashTest, HRandField) { + uint64_t ret = 0; + for (size_t i = 0; i < fields_.size(); i++) { + auto s = hash_->Set(key_, fields_[i], values_[i], &ret); + EXPECT_TRUE(s.ok() && ret == 1); + } + auto size = static_cast(fields_.size()); + std::vector fvs; + // Case 1: Negative count, randomly select elements + fvs.clear(); + auto s = hash_->RandField(key_, -(size + 10), &fvs); + EXPECT_TRUE(s.ok() && fvs.size() == (fields_.size() + 10)); + + // Case 2: Requested count is greater than or equal to the number of elements inside the hash + fvs.clear(); + s = hash_->RandField(key_, size + 1, &fvs); + EXPECT_TRUE(s.ok() && fvs.size() == fields_.size()); + + // Case 3: Requested count is less than the number of elements inside the hash + fvs.clear(); + s = hash_->RandField(key_, size - 1, &fvs); + EXPECT_TRUE(s.ok() && fvs.size() == fields_.size() - 1); + + // hrandfield key 0 + fvs.clear(); + s = hash_->RandField(key_, 0, &fvs); + EXPECT_TRUE(s.ok() && fvs.size() == 0); + + hash_->Del(key_); +} diff --git a/tests/gocase/unit/type/hash/hash_test.go b/tests/gocase/unit/type/hash/hash_test.go index 8489527370a..bf93d268860 100644 --- a/tests/gocase/unit/type/hash/hash_test.go +++ b/tests/gocase/unit/type/hash/hash_test.go @@ -783,6 +783,55 @@ func TestHash(t *testing.T) { require.Equal(t, []interface{}{"field1", "some-value", "field2", ""}, rdb.Do(ctx, "HrangeByLex", testKey, "[a", "[z").Val()) }) + + t.Run("HRandField count is positive", func(t *testing.T) { + testKey := "test-hash-1" + require.NoError(t, rdb.Del(ctx, testKey).Err()) + require.NoError(t, rdb.HSet(ctx, testKey, "key1", "value1", "key2", "value2", "key3", "value3").Err()) + result, err := rdb.HRandField(ctx, testKey, 5).Result() + require.NoError(t, err) + require.Len(t, result, 3) + require.Equal(t, []string{"key1", "key2", "key3"}, result) + result, err = rdb.HRandField(ctx, testKey, 2).Result() + require.NoError(t, err) + require.Len(t, result, 2) + require.Contains(t, []string{"key1", "key2", "key3"}, result[0]) + require.Contains(t, []string{"key1", "key2", "key3"}, result[1]) + result, err = rdb.HRandField(ctx, testKey, 0).Result() + require.NoError(t, err) + require.Len(t, result, 0) + result, err = rdb.HRandField(ctx, "nonexistent-key", 1).Result() + require.NoError(t, err) + require.Len(t, result, 0) + var rv [][]interface{} + resultWithValues, err := rdb.HRandFieldWithValues(ctx, testKey, 5).Result() + require.NoError(t, err) + require.Len(t, resultWithValues, 3) + for _, kv := range resultWithValues { + keys := []interface{}{kv.Key, kv.Value} + rv = append(rv, keys) + } + require.Equal(t, [][]interface{}{ + {"key1", "value1"}, + {"key2", "value2"}, + {"key3", "value3"}, + }, rv) + // TODO: Add test to verify randomness of the selected random fields + }) + + t.Run("HRandField count is negative", func(t *testing.T) { + testKey := "test-hash-1" + require.NoError(t, rdb.Del(ctx, testKey).Err()) + require.NoError(t, rdb.HSet(ctx, testKey, "key1", "value1", "key2", "value2", "key3", "value3").Err()) + result, err := rdb.HRandField(ctx, testKey, -4).Result() + require.NoError(t, err) + require.Len(t, result, 4) + resultWithValues, err := rdb.HRandFieldWithValues(ctx, testKey, -12).Result() + require.NoError(t, err) + require.Len(t, resultWithValues, 12) + // TODO: Add test to verify randomness of the selected random fields + }) + } }