From 6238f8de1063cd83e83b3fddc95a947f399b6bc3 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Mon, 8 Apr 2024 14:46:13 +0800 Subject: [PATCH 01/23] feat: support for the sort command[draft] --- src/commands/cmd_sort.cc | 358 +++++++++++++++ src/commands/commander.h | 6 + tests/gocase/unit/sort/sort_test.go | 657 ++++++++++++++++++++++++++++ 3 files changed, 1021 insertions(+) create mode 100644 src/commands/cmd_sort.cc create mode 100644 tests/gocase/unit/sort/sort_test.go diff --git a/src/commands/cmd_sort.cc b/src/commands/cmd_sort.cc new file mode 100644 index 00000000000..bb8fadc2a1a --- /dev/null +++ b/src/commands/cmd_sort.cc @@ -0,0 +1,358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include +#include + +#include "command_parser.h" +#include "commander.h" +#include "server/server.h" +#include "storage/redis_db.h" +#include "types/redis_hash.h" +#include "types/redis_list.h" +#include "types/redis_set.h" +#include "types/redis_string.h" +#include "types/redis_zset.h" + +namespace redis { + +class CommandSort : public Commander { + public: + Status Parse(const std::vector &args) override { + CommandParser parser(args, 2); + while (parser.Good()) { + if (parser.EatEqICase("BY")) { + if (parser.Remains() < 1) { + return parser.InvalidSyntax(); + } + sortby_ = GET_OR_RET(parser.TakeStr()); + + if (sortby_.find('*') == std::string::npos) { + dontsort_ = true; + } else { + // TODO Check + /* If BY is specified with a real pattern, we can't accept it in cluster mode, + * unless we can make sure the keys formed by the pattern are in the same slot + * as the key to sort. */ + + /* If BY is specified with a real pattern, we can't accept + * it if no full ACL key access is applied for this command. */ + } + } else if (parser.EatEqICase("LIMIT")) { + if (parser.Remains() < 2) { + return parser.InvalidSyntax(); + } + offset_ = GET_OR_RET(parser.TakeInt()); + count_ = GET_OR_RET(parser.TakeInt()); + } else if (parser.EatEqICase("GET") && parser.Remains() >= 1) { // 有嵌套 + if (parser.Remains() < 1) { + return parser.InvalidSyntax(); + } + // TODO Check + /* If GET is specified with a real pattern, we can't accept it in cluster mode, + * unless we can make sure the keys formed by the pattern are in the same slot + * as the key to sort. */ + + getpatterns_.push_back(GET_OR_RET(parser.TakeStr())); + } else if (parser.EatEqICase("ASC")) { + desc_ = false; + } else if (parser.EatEqICase("DESC")) { + desc_ = true; + } else if (parser.EatEqICase("ALPHA")) { + alpha_ = true; + } else if (parser.EatEqICase("STORE")) { + if (parser.Remains() < 1) { + return parser.InvalidSyntax(); + } + storekey_ = GET_OR_RET(parser.TakeStr()); + } else { + return parser.InvalidSyntax(); + } + } + + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + // Get Key Type + redis::Database redis(srv->storage, conn->GetNamespace()); + RedisType type = kRedisNone; + auto s = redis.Type(args_[1], &type); + if (s.ok()) { + if (type >= RedisTypeNames.size()) { + return {Status::RedisExecErr, "Invalid type"}; + } else if (type != RedisType::kRedisList && type != RedisType::kRedisSet && type != RedisType::kRedisZSet) { + *output = Error("WRONGTYPE Operation against a key holding the wrong kind of value"); + return Status::OK(); + } + } else { + return {Status::RedisExecErr, s.ToString()}; + } + + /* When sorting a set with no sort specified, we must sort the output + * so the result is consistent across scripting and replication. + * + * The other types (list, sorted set) will retain their native order + * even if no sort order is requested, so they remain stable across + * scripting and replication. */ + + // TODO c->flags & CLIENT_SCRIPT ??? + // if (dontsort_ && type == RedisType::kRedisZSet && (!storekey_.empty() || c->flags & CLIENT_SCRIPT)) + if (dontsort_ && type == RedisType::kRedisZSet && (!storekey_.empty())) { + /* Force ALPHA sorting */ + dontsort_ = false; + alpha_ = true; + sortby_ = ""; + } + + // Obtain the length of the object to sort. + uint64_t vec_count = 0; + auto list_db = redis::List(srv->storage, conn->GetNamespace()); + auto set_db = redis::Set(srv->storage, conn->GetNamespace()); + auto zset_db = redis::ZSet(srv->storage, conn->GetNamespace()); + + switch (type) { + case RedisType::kRedisList: { + s = list_db.Size(args_[1], &vec_count); + if (!s.ok() && !s.IsNotFound()) { + return {Status::RedisExecErr, s.ToString()}; + } + + break; + } + + case RedisType::kRedisSet: { + s = set_db.Card(args_[1], &vec_count); + if (!s.ok() && !s.IsNotFound()) { + return {Status::RedisExecErr, s.ToString()}; + } + + break; + } + + case RedisType::kRedisZSet: { + s = zset_db.Card(args_[1], &vec_count); + if (!s.ok() && !s.IsNotFound()) { + return {Status::RedisExecErr, s.ToString()}; + } + break; + } + + default: + vec_count = 0; + return {Status::RedisExecErr, "Bad SORT type"}; + } + + long vectorlen = (long)vec_count; + + // Adjust the offset and count of the limit + long offset = offset_ >= vectorlen ? 0 : std::clamp(offset_, 0L, vectorlen - 1); + long count = offset_ >= vectorlen ? 0 : std::clamp(count_, -1L, vectorlen - offset); + if (count == -1L) count = vectorlen - offset; + + // Get the elements that need to be sorted + std::vector str_vec; + if (count != 0) { + if (type == RedisType::kRedisList && dontsort_) { + if (desc_) { + list_db.Range(args_[1], -count - offset, -1 - offset, &str_vec); + std::reverse(str_vec.begin(), str_vec.end()); + } else { + list_db.Range(args_[1], offset, offset + count - 1, &str_vec); + } + } else if (type == RedisType::kRedisList) { + list_db.Range(args_[1], 0, -1, &str_vec); + } else if (type == RedisType::kRedisSet) { + set_db.Members(args_[1], &str_vec); + if (dontsort_) { + str_vec = std::vector(str_vec.begin() + offset, str_vec.begin() + offset + count); + } + } else if (type == RedisType::kRedisZSet && dontsort_) { + std::vector member_scores; + RangeRankSpec spec; + spec.start = (int)offset; + spec.stop = (int)(offset + count - 1); + spec.reversed = desc_; + zset_db.RangeByRank(args_[1], spec, &member_scores, nullptr); + for (size_t i = 0; i < member_scores.size(); ++i) { + str_vec.emplace_back(member_scores[i].member); + } + } else if (type == RedisType::kRedisZSet) { + std::vector member_scores; + zset_db.GetAllMemberScores(args_[1], &member_scores); + for (size_t i = 0; i < member_scores.size(); ++i) { + str_vec.emplace_back(member_scores[i].member); + } + } else { + return {Status::RedisExecErr, "Unknown type"}; + } + } + + std::vector sort_vec(str_vec.size()); + for (size_t i = 0; i < str_vec.size(); ++i) { + sort_vec[i].obj = str_vec[i]; + } + + // Sort by BY, ALPHA, ASC/DESC + if (!dontsort_) { + for (size_t i = 0; i < sort_vec.size(); ++i) { + std::string byval; + if (!sortby_.empty()) { + byval = lookupKeyByPattern(srv, conn, sortby_, str_vec[i]); + if (byval.empty()) continue; + } else { + byval = str_vec[i]; + } + + if (alpha_) { + if (!sortby_.empty()) { + sort_vec[i].v = byval; + } + } else { + try { + sort_vec[i].v = std::stod(byval); + } catch (const std::exception &e) { + *output = redis::Error("One or more scores can't be converted into double"); + return Status::OK(); + } + } + } + + std::sort(sort_vec.begin(), sort_vec.end(), + [this](const RedisSortObject &a, const RedisSortObject &b) { return sortCompare(a, b); }); + + // Gets the element specified by Limit + if (offset != 0 || count != vectorlen) { + sort_vec = std::vector(sort_vec.begin() + offset, sort_vec.begin() + offset + count); + } + } + + // Get the output and perform storage + std::vector output_vec; + + for (size_t i = 0; i < sort_vec.size(); ++i) { + if (getpatterns_.empty()) { + output_vec.emplace_back(sort_vec[i].obj); + } + for (const std::string &pattern : getpatterns_) { + std::string val = lookupKeyByPattern(srv, conn, pattern, sort_vec[i].obj); + if (val.empty()) { + output_vec.emplace_back(conn->NilString()); + } else { + output_vec.emplace_back(val); + } + } + } + + if (storekey_.empty()) { + *output = ArrayOfBulkStrings(output_vec); + } else { + std::vector elems(output_vec.begin(), output_vec.end()); + list_db.Trim(storekey_, 0, -1); + uint64_t new_size = 0; + list_db.Push(storekey_, elems, false, &new_size); + *output = Integer(new_size); + } + + return Status::OK(); + } + + private: + struct RedisSortObject { + std::string obj; + std::variant v; + }; + + bool sortCompare(const RedisSortObject &a, const RedisSortObject &b) const { + if (!alpha_) { + double score_a = std::get(a.v); + double score_b = std::get(b.v); + return !desc_ ? score_a < score_b : score_a > score_b; + } else { + if (!sortby_.empty()) { + std::string cmp_a = std::get(a.v); + std::string cmp_b = std::get(b.v); + return !desc_ ? cmp_a < cmp_b : cmp_a > cmp_b; + } else { + return !desc_ ? a.obj < b.obj : a.obj > b.obj; + } + } + } + + static std::string lookupKeyByPattern(Server *srv, Connection *conn, const std::string &pattern, + const std::string &subst) { + if (pattern == "#") { + return subst; + } + + auto match_pos = pattern.find('*'); + if (match_pos == std::string::npos) { + return ""; + } + + // hash field + std::string field; + auto arrow_pos = pattern.find("->", match_pos + 1); + if (arrow_pos != std::string::npos && arrow_pos + 2 < pattern.size()) { + field = pattern.substr(arrow_pos + 2); + } + + std::string key = pattern.substr(0, match_pos + 1); + key.replace(match_pos, 1, subst); + + std::string value; + if (!field.empty()) { + auto hash_db = redis::Hash(srv->storage, conn->GetNamespace()); + RedisType type = RedisType::kRedisNone; + if (auto s = hash_db.Type(key, &type); !s.ok() || type >= RedisTypeNames.size()) { + return ""; + } + + hash_db.Get(key, field, &value); + } else { + auto string_db = redis::String(srv->storage, conn->GetNamespace()); + RedisType type = RedisType::kRedisNone; + if (auto s = string_db.Type(key, &type); !s.ok() || type >= RedisTypeNames.size()) { + return ""; + } + string_db.Get(key, &value); + } + return value; + } + + std::string sortby_; // BY + bool dontsort_ = false; // DONT SORT + long offset_ = 0; // LIMIT OFFSET + long count_ = -1; // LIMIT COUNT + std::vector getpatterns_; // GET + bool desc_ = false; // ASC/DESC + bool alpha_ = false; // ALPHA + std::string storekey_; // STORE +}; + +class CommandSortRO : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { return Status::OK(); } +}; + +REDIS_REGISTER_COMMANDS(MakeCmdAttr("sort", -2, "write deny-oom movable-keys", 1, 1, 1), + MakeCmdAttr("sortro", -2, "read-only movable-keys", 1, 1, 1)) + +} // namespace redis \ No newline at end of file diff --git a/src/commands/commander.h b/src/commands/commander.h index 1982d8979dc..9bfef93bd1a 100644 --- a/src/commands/commander.h +++ b/src/commands/commander.h @@ -65,6 +65,8 @@ enum CommandFlags : uint64_t { kCmdROScript = 1ULL << 10, // "ro-script" flag for read-only script commands kCmdCluster = 1ULL << 11, // "cluster" flag kCmdNoDBSizeCheck = 1ULL << 12, // "no-dbsize-check" flag + kCmdDenyOom = 1ULL << 13, // "deny-oom" flag + kCmdMovableKeys = 1ULL << 14, // "movable-keys" flag }; class Commander { @@ -185,6 +187,10 @@ inline uint64_t ParseCommandFlags(const std::string &description, const std::str flags |= kCmdCluster; else if (flag == "no-dbsize-check") flags |= kCmdNoDBSizeCheck; + else if (flag == "deny-oom") + flags |= kCmdDenyOom; + else if (flag == "movable-keys") + flags |= kCmdMovableKeys; else { std::cout << fmt::format("Encountered non-existent flag '{}' in command {} in command attribute parsing", flag, cmd_name) diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go new file mode 100644 index 00000000000..153cfb40aa1 --- /dev/null +++ b/tests/gocase/unit/sort/sort_test.go @@ -0,0 +1,657 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package sort + +import ( + "context" + "github.com/redis/go-redis/v9" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/stretchr/testify/require" +) + +func TestSortParser(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("SORT Parser", func(t *testing.T) { + rdb.RPush(ctx, "bad-case-key", 5, 4, 3, 2, 1) + + _, err := rdb.Do(ctx, "Sort").Result() + require.EqualError(t, err, "ERR wrong number of arguments") + + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "BadArg").Result() + require.EqualError(t, err, "ERR syntax error") + + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "LIMIT").Result() + require.EqualError(t, err, "ERR syntax error") + + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "LIMIT", 1).Result() + require.EqualError(t, err, "ERR syntax error") + + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "LIMIT", 1, "not-number").Result() + require.EqualError(t, err, "ERR not started as an integer") + + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "STORE").Result() + require.EqualError(t, err, "ERR syntax error") + }) +} + +func TestListSort(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("SORT Basic", func(t *testing.T) { + rdb.LPush(ctx, "today_cost", 30, 1.5, 10, 8) + + sortResult, err := rdb.Sort(ctx, "today_cost", &redis.Sort{}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"30", "10", "8", "1.5"}, sortResult) + }) + + t.Run("SORT ALPHA", func(t *testing.T) { + rdb.LPush(ctx, "website", "www.reddit.com", "www.slashdot.com", "www.infoq.com") + + sortResult, err := rdb.Sort(ctx, "website", &redis.Sort{Alpha: true}).Result() + require.NoError(t, err) + require.Equal(t, []string{"www.infoq.com", "www.reddit.com", "www.slashdot.com"}, sortResult) + + _, err = rdb.Sort(ctx, "website", &redis.Sort{Alpha: false}).Result() + require.EqualError(t, err, "One or more scores can't be converted into double") + }) + + t.Run("SORT LIMIT", func(t *testing.T) { + rdb.RPush(ctx, "rank", 1, 3, 5, 7, 9, 2, 4, 6, 8, 10) + + sortResult, err := rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"10", "9", "8", "7", "6"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 11, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 11}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + }) + + t.Run("SORT BY + GET", func(t *testing.T) { + rdb.LPush(ctx, "uid", 1, 2, 3, 4) + rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") + rdb.MSet(ctx, "user_level_1", 9999, "user_level_2", 10, "user_level_3", 25, "user_level_4", 70) + + sortResult, err := rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"admin", "jack", "peter", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*", Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"9999", "admin", "10", "jack", "25", "peter", "70", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "9999", "admin", "2", "10", "jack", "3", "25", "peter", "4", "70", "mary"}, sortResult) + + // not sorted + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "3", "2", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"3", "2"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"3", "2", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "70", "mary", "3", "25", "peter", "2", "10", "jack", "1", "9999", "admin"}, sortResult) + + // pattern with hash tag + rdb.HMSet(ctx, "user_info_1", "name", "admin", "level", 9999) + rdb.HMSet(ctx, "user_info_2", "name", "jack", "level", 10) + rdb.HMSet(ctx, "user_info_3", "name", "peter", "level", 25) + rdb.HMSet(ctx, "user_info_4", "name", "mary", "level", 70) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level", Get: []string{"user_info_*->name"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + }) + + t.Run("SORT STORE", func(t *testing.T) { + rdb.RPush(ctx, "numbers", 1, 3, 5, 7, 9, 2, 4, 6, 8, 10) + + storedLen, err := rdb.Do(ctx, "Sort", "numbers", "STORE", "sorted-numbers").Result() + require.NoError(t, err) + require.Equal(t, int64(10), storedLen) + + sortResult, err := rdb.LRange(ctx, "sorted-numbers", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + }) +} + +func TestSetSort(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + + t.Run("SORT Basic", func(t *testing.T) { + rdb.SAdd(ctx, "today_cost", 30, 1.5, 10, 8) + + sortResult, err := rdb.Sort(ctx, "today_cost", &redis.Sort{}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"30", "10", "8", "1.5"}, sortResult) + }) + + t.Run("SORT ALPHA", func(t *testing.T) { + rdb.SAdd(ctx, "website", "www.reddit.com", "www.slashdot.com", "www.infoq.com") + + sortResult, err := rdb.Sort(ctx, "website", &redis.Sort{Alpha: true}).Result() + require.NoError(t, err) + require.Equal(t, []string{"www.infoq.com", "www.reddit.com", "www.slashdot.com"}, sortResult) + + _, err = rdb.Sort(ctx, "website", &redis.Sort{Alpha: false}).Result() + require.EqualError(t, err, "One or more scores can't be converted into double") + }) + + t.Run("SORT LIMIT", func(t *testing.T) { + rdb.SAdd(ctx, "rank", 1, 3, 5, 7, 9, 2, 4, 6, 8, 10) + + sortResult, err := rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"10", "9", "8", "7", "6"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 11, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 11}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + }) + + t.Run("SORT BY + GET", func(t *testing.T) { + rdb.SAdd(ctx, "uid", 4, 3, 2, 1) + rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") + rdb.MSet(ctx, "user_level_1", 9999, "user_level_2", 10, "user_level_3", 25, "user_level_4", 70) + + sortResult, err := rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"admin", "jack", "peter", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*", Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"9999", "admin", "10", "jack", "25", "peter", "70", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "9999", "admin", "2", "10", "jack", "3", "25", "peter", "4", "70", "mary"}, sortResult) + + // not sorted + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4"}, sortResult) + + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{ "1", "9999", "admin", "2", "10", "jack", "3", "25", "peter", "4", "70", "mary"}, sortResult) + + // pattern with hash tag + rdb.HMSet(ctx, "user_info_1", "name", "admin", "level", 9999) + rdb.HMSet(ctx, "user_info_2", "name", "jack", "level", 10) + rdb.HMSet(ctx, "user_info_3", "name", "peter", "level", 25) + rdb.HMSet(ctx, "user_info_4", "name", "mary", "level", 70) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level", Get: []string{"user_info_*->name"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + }) + + t.Run("SORT STORE", func(t *testing.T) { + rdb.SAdd(ctx, "numbers", 1, 3, 5, 7, 9, 2, 4, 6, 8, 10) + + storedLen, err := rdb.Do(ctx, "Sort", "numbers", "STORE", "sorted-numbers").Result() + require.NoError(t, err) + require.Equal(t, int64(10), storedLen) + + sortResult, err := rdb.LRange(ctx, "sorted-numbers", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + }) +} + +func TestZSetSort(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("SORT Basic", func(t *testing.T) { + rdb.ZAdd(ctx, "today_cost", redis.Z{Score: 30, Member: "1"}, redis.Z{Score: 1.5, Member: "2"}, redis.Z{Score: 10, Member: "3"}, redis.Z{Score: 8, Member: "4"}) + + sortResult, err := rdb.Sort(ctx, "today_cost", &redis.Sort{}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "3", "2", "1"}, sortResult) + }) + + t.Run("SORT ALPHA", func(t *testing.T) { + rdb.ZAdd(ctx, "website", redis.Z{Score: 1, Member: "www.reddit.com"}, redis.Z{Score: 2, Member: "www.slashdot.com"}, redis.Z{Score: 3, Member: "www.infoq.com"}) + + sortResult, err := rdb.Sort(ctx, "website", &redis.Sort{Alpha: true}).Result() + require.NoError(t, err) + require.Equal(t, []string{"www.infoq.com", "www.reddit.com", "www.slashdot.com"}, sortResult) + + _, err = rdb.Sort(ctx, "website", &redis.Sort{Alpha: false}).Result() + require.EqualError(t, err, "One or more scores can't be converted into double") + }) + + t.Run("SORT LIMIT", func(t *testing.T) { + rdb.ZAdd(ctx, "rank", + redis.Z{Score: 1, Member: "1"}, + redis.Z{Score: 2, Member: "3" }, + redis.Z{Score: 3, Member: "5"}, + redis.Z{Score: 4, Member: "7" }, + redis.Z{Score: 5, Member: "9"}, + redis.Z{Score: 6, Member: "2" }, + redis.Z{Score: 7, Member: "4"}, + redis.Z{Score: 8, Member: "6" }, + redis.Z{Score: 9, Member: "8"}, + redis.Z{Score: 10, Member: "10"}, + ) + + sortResult, err := rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"10", "9", "8", "7", "6"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 11, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 11}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + }) + + t.Run("SORT BY + GET", func(t *testing.T) { + rdb.ZAdd(ctx, "uid", + redis.Z{Score: 1, Member: "4"}, + redis.Z{Score: 2, Member: "3"}, + redis.Z{Score: 3, Member: "2"}, + redis.Z{Score: 4, Member: "1"}) + + rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") + rdb.MSet(ctx, "user_level_1", 9999, "user_level_2", 10, "user_level_3", 25, "user_level_4", 70) + + sortResult, err := rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"admin", "jack", "peter", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*", Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"9999", "admin", "10", "jack", "25", "peter", "70", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "9999", "admin", "2", "10", "jack", "3", "25", "peter", "4", "70", "mary"}, sortResult) + + // not sorted + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "3", "2", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"3", "2"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"3", "2", "1"}, sortResult) + + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{ "4", "70", "mary", "3", "25", "peter", "2", "10", "jack", "1", "9999", "admin"}, sortResult) + + // pattern with hash tag + rdb.HMSet(ctx, "user_info_1", "name", "admin", "level", 9999) + rdb.HMSet(ctx, "user_info_2", "name", "jack", "level", 10) + rdb.HMSet(ctx, "user_info_3", "name", "peter", "level", 25) + rdb.HMSet(ctx, "user_info_4", "name", "mary", "level", 70) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level", Get: []string{"user_info_*->name"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + }) + + t.Run("SORT STORE", func(t *testing.T) { + rdb.ZAdd(ctx, "numbers", + redis.Z{Score: 1, Member: "1"}, + redis.Z{Score: 2, Member: "3" }, + redis.Z{Score: 3, Member: "5"}, + redis.Z{Score: 4, Member: "7" }, + redis.Z{Score: 5, Member: "9"}, + redis.Z{Score: 6, Member: "2" }, + redis.Z{Score: 7, Member: "4"}, + redis.Z{Score: 8, Member: "6" }, + redis.Z{Score: 9, Member: "8"}, + redis.Z{Score: 10, Member: "10"}, + ) + + storedLen, err := rdb.Do(ctx, "Sort", "numbers", "STORE", "sorted-numbers").Result() + require.NoError(t, err) + require.Equal(t, int64(10), storedLen) + + sortResult, err := rdb.LRange(ctx, "sorted-numbers", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + }) +} From 760bba8256c390b447f28b4c27139bbfa0ba397d Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Mon, 8 Apr 2024 15:23:26 +0800 Subject: [PATCH 02/23] fix: delete some comments --- src/commands/cmd_sort.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/commands/cmd_sort.cc b/src/commands/cmd_sort.cc index bb8fadc2a1a..0e64974538b 100644 --- a/src/commands/cmd_sort.cc +++ b/src/commands/cmd_sort.cc @@ -61,7 +61,7 @@ class CommandSort : public Commander { } offset_ = GET_OR_RET(parser.TakeInt()); count_ = GET_OR_RET(parser.TakeInt()); - } else if (parser.EatEqICase("GET") && parser.Remains() >= 1) { // 有嵌套 + } else if (parser.EatEqICase("GET") && parser.Remains() >= 1) { if (parser.Remains() < 1) { return parser.InvalidSyntax(); } From 04f549de6317fdee85c8d9cda4af3e176a600e32 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Fri, 12 Apr 2024 21:43:18 +0800 Subject: [PATCH 03/23] feat: support sort_ro --- src/commands/cmd_sort.cc | 17 ++++++++--------- tests/gocase/unit/sort/sort_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/commands/cmd_sort.cc b/src/commands/cmd_sort.cc index 0e64974538b..541318ed07d 100644 --- a/src/commands/cmd_sort.cc +++ b/src/commands/cmd_sort.cc @@ -33,6 +33,7 @@ namespace redis { +template class CommandSort : public Commander { public: Status Parse(const std::vector &args) override { @@ -59,8 +60,8 @@ class CommandSort : public Commander { if (parser.Remains() < 2) { return parser.InvalidSyntax(); } - offset_ = GET_OR_RET(parser.TakeInt()); - count_ = GET_OR_RET(parser.TakeInt()); + offset_ = GET_OR_RET(parser.template TakeInt()); + count_ = GET_OR_RET(parser.template TakeInt()); } else if (parser.EatEqICase("GET") && parser.Remains() >= 1) { if (parser.Remains() < 1) { return parser.InvalidSyntax(); @@ -78,6 +79,9 @@ class CommandSort : public Commander { } else if (parser.EatEqICase("ALPHA")) { alpha_ = true; } else if (parser.EatEqICase("STORE")) { + if constexpr (ReadOnly) { + return parser.InvalidSyntax(); + } if (parser.Remains() < 1) { return parser.InvalidSyntax(); } @@ -347,12 +351,7 @@ class CommandSort : public Commander { std::string storekey_; // STORE }; -class CommandSortRO : public Commander { - public: - Status Execute(Server *srv, Connection *conn, std::string *output) override { return Status::OK(); } -}; - -REDIS_REGISTER_COMMANDS(MakeCmdAttr("sort", -2, "write deny-oom movable-keys", 1, 1, 1), - MakeCmdAttr("sortro", -2, "read-only movable-keys", 1, 1, 1)) +REDIS_REGISTER_COMMANDS(MakeCmdAttr>("sort", -2, "write deny-oom movable-keys", 1, 1, 1), + MakeCmdAttr>("sort_ro", -2, "read-only movable-keys", 1, 1, 1)) } // namespace redis \ No newline at end of file diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go index 153cfb40aa1..0a8d02e3a30 100644 --- a/tests/gocase/unit/sort/sort_test.go +++ b/tests/gocase/unit/sort/sort_test.go @@ -56,6 +56,9 @@ func TestSortParser(t *testing.T) { _, err = rdb.Do(ctx, "Sort", "bad-case-key", "STORE").Result() require.EqualError(t, err, "ERR syntax error") + + _, err = rdb.Do(ctx, "Sort_RO", "bad-case-key", "STORE", "store_ro_key").Result() + require.EqualError(t, err, "ERR syntax error") }) } @@ -81,6 +84,14 @@ func TestListSort(t *testing.T) { sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() require.NoError(t, err) require.Equal(t, []string{"30", "10", "8", "1.5"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"30", "10", "8", "1.5"}, sortResult) }) t.Run("SORT ALPHA", func(t *testing.T) { @@ -271,6 +282,14 @@ func TestSetSort(t *testing.T) { sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() require.NoError(t, err) require.Equal(t, []string{"30", "10", "8", "1.5"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"30", "10", "8", "1.5"}, sortResult) }) t.Run("SORT ALPHA", func(t *testing.T) { @@ -461,6 +480,14 @@ func TestZSetSort(t *testing.T) { sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() require.NoError(t, err) require.Equal(t, []string{"4", "3", "2", "1"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "3", "2", "1"}, sortResult) }) t.Run("SORT ALPHA", func(t *testing.T) { From 2a77348c187f89dfdb08f684b3669f2c4e74b2b2 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Fri, 12 Apr 2024 22:26:35 +0800 Subject: [PATCH 04/23] style: clang format --- src/commands/cmd_sort.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/commands/cmd_sort.cc b/src/commands/cmd_sort.cc index 541318ed07d..fd8e03191e3 100644 --- a/src/commands/cmd_sort.cc +++ b/src/commands/cmd_sort.cc @@ -33,7 +33,7 @@ namespace redis { -template +template class CommandSort : public Commander { public: Status Parse(const std::vector &args) override { From bcf78d464434fe6cd05781d2e65d80b46bc691c1 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Fri, 12 Apr 2024 23:50:19 +0800 Subject: [PATCH 05/23] style: golangci-lint --- tests/gocase/unit/sort/sort_test.go | 32 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go index 0a8d02e3a30..bccbb88f9a4 100644 --- a/tests/gocase/unit/sort/sort_test.go +++ b/tests/gocase/unit/sort/sort_test.go @@ -21,9 +21,10 @@ package sort import ( "context" - "github.com/redis/go-redis/v9" "testing" + "github.com/redis/go-redis/v9" + "github.com/apache/kvrocks/tests/gocase/util" "github.com/stretchr/testify/require" ) @@ -267,7 +268,6 @@ func TestSetSort(t *testing.T) { rdb := srv.NewClient() defer func() { require.NoError(t, rdb.Close()) }() - t.Run("SORT Basic", func(t *testing.T) { rdb.SAdd(ctx, "today_cost", 30, 1.5, 10, 8) @@ -401,7 +401,6 @@ func TestSetSort(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{"2", "3", "4"}, sortResult) - sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Order: "DESC"}).Result() require.NoError(t, err) require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) @@ -428,7 +427,7 @@ func TestSetSort(t *testing.T) { sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Get: []string{"#", "user_level_*", "user_name_*"}}).Result() require.NoError(t, err) - require.Equal(t, []string{ "1", "9999", "admin", "2", "10", "jack", "3", "25", "peter", "4", "70", "mary"}, sortResult) + require.Equal(t, []string{"1", "9999", "admin", "2", "10", "jack", "3", "25", "peter", "4", "70", "mary"}, sortResult) // pattern with hash tag rdb.HMSet(ctx, "user_info_1", "name", "admin", "level", 9999) @@ -502,15 +501,15 @@ func TestZSetSort(t *testing.T) { }) t.Run("SORT LIMIT", func(t *testing.T) { - rdb.ZAdd(ctx, "rank", + rdb.ZAdd(ctx, "rank", redis.Z{Score: 1, Member: "1"}, - redis.Z{Score: 2, Member: "3" }, + redis.Z{Score: 2, Member: "3"}, redis.Z{Score: 3, Member: "5"}, - redis.Z{Score: 4, Member: "7" }, + redis.Z{Score: 4, Member: "7"}, redis.Z{Score: 5, Member: "9"}, - redis.Z{Score: 6, Member: "2" }, + redis.Z{Score: 6, Member: "2"}, redis.Z{Score: 7, Member: "4"}, - redis.Z{Score: 8, Member: "6" }, + redis.Z{Score: 8, Member: "6"}, redis.Z{Score: 9, Member: "8"}, redis.Z{Score: 10, Member: "10"}, ) @@ -566,7 +565,7 @@ func TestZSetSort(t *testing.T) { redis.Z{Score: 2, Member: "3"}, redis.Z{Score: 3, Member: "2"}, redis.Z{Score: 4, Member: "1"}) - + rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") rdb.MSet(ctx, "user_level_1", 9999, "user_level_2", 10, "user_level_3", 25, "user_level_4", 70) @@ -615,7 +614,6 @@ func TestZSetSort(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{"3", "2", "1"}, sortResult) - sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Order: "DESC"}).Result() require.NoError(t, err) require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) @@ -642,7 +640,7 @@ func TestZSetSort(t *testing.T) { sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Get: []string{"#", "user_level_*", "user_name_*"}}).Result() require.NoError(t, err) - require.Equal(t, []string{ "4", "70", "mary", "3", "25", "peter", "2", "10", "jack", "1", "9999", "admin"}, sortResult) + require.Equal(t, []string{"4", "70", "mary", "3", "25", "peter", "2", "10", "jack", "1", "9999", "admin"}, sortResult) // pattern with hash tag rdb.HMSet(ctx, "user_info_1", "name", "admin", "level", 9999) @@ -660,15 +658,15 @@ func TestZSetSort(t *testing.T) { }) t.Run("SORT STORE", func(t *testing.T) { - rdb.ZAdd(ctx, "numbers", + rdb.ZAdd(ctx, "numbers", redis.Z{Score: 1, Member: "1"}, - redis.Z{Score: 2, Member: "3" }, + redis.Z{Score: 2, Member: "3"}, redis.Z{Score: 3, Member: "5"}, - redis.Z{Score: 4, Member: "7" }, + redis.Z{Score: 4, Member: "7"}, redis.Z{Score: 5, Member: "9"}, - redis.Z{Score: 6, Member: "2" }, + redis.Z{Score: 6, Member: "2"}, redis.Z{Score: 7, Member: "4"}, - redis.Z{Score: 8, Member: "6" }, + redis.Z{Score: 8, Member: "6"}, redis.Z{Score: 9, Member: "8"}, redis.Z{Score: 10, Member: "10"}, ) From dd4c63823df4bb1633bf5e089ed7bf2e56fc659b Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sat, 13 Apr 2024 16:20:05 +0800 Subject: [PATCH 06/23] fix: sorting a set with no sort specified and TODO --- src/commands/cmd_sort.cc | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/commands/cmd_sort.cc b/src/commands/cmd_sort.cc index fd8e03191e3..8ecab414080 100644 --- a/src/commands/cmd_sort.cc +++ b/src/commands/cmd_sort.cc @@ -48,12 +48,11 @@ class CommandSort : public Commander { if (sortby_.find('*') == std::string::npos) { dontsort_ = true; } else { - // TODO Check - /* If BY is specified with a real pattern, we can't accept it in cluster mode, + /* TODO: + * If BY is specified with a real pattern, we can't accept it in cluster mode, * unless we can make sure the keys formed by the pattern are in the same slot - * as the key to sort. */ - - /* If BY is specified with a real pattern, we can't accept + * as the key to sort. + * If BY is specified with a real pattern, we can't accept * it if no full ACL key access is applied for this command. */ } } else if (parser.EatEqICase("LIMIT")) { @@ -66,11 +65,10 @@ class CommandSort : public Commander { if (parser.Remains() < 1) { return parser.InvalidSyntax(); } - // TODO Check - /* If GET is specified with a real pattern, we can't accept it in cluster mode, + /* TODO: + * If GET is specified with a real pattern, we can't accept it in cluster mode, * unless we can make sure the keys formed by the pattern are in the same slot * as the key to sort. */ - getpatterns_.push_back(GET_OR_RET(parser.TakeStr())); } else if (parser.EatEqICase("ASC")) { desc_ = false; @@ -115,11 +113,11 @@ class CommandSort : public Commander { * * The other types (list, sorted set) will retain their native order * even if no sort order is requested, so they remain stable across - * scripting and replication. */ + * scripting and replication. + * + * TODO: support CLIENT_SCRIPT flag, (!storekey_.empty() || c->flags & CLIENT_SCRIPT)) */ - // TODO c->flags & CLIENT_SCRIPT ??? - // if (dontsort_ && type == RedisType::kRedisZSet && (!storekey_.empty() || c->flags & CLIENT_SCRIPT)) - if (dontsort_ && type == RedisType::kRedisZSet && (!storekey_.empty())) { + if (dontsort_ && type == RedisType::kRedisSet && (!storekey_.empty())) { /* Force ALPHA sorting */ dontsort_ = false; alpha_ = true; @@ -195,14 +193,14 @@ class CommandSort : public Commander { spec.stop = (int)(offset + count - 1); spec.reversed = desc_; zset_db.RangeByRank(args_[1], spec, &member_scores, nullptr); - for (size_t i = 0; i < member_scores.size(); ++i) { - str_vec.emplace_back(member_scores[i].member); + for (auto &member_score : member_scores) { + str_vec.emplace_back(member_score.member); } } else if (type == RedisType::kRedisZSet) { std::vector member_scores; zset_db.GetAllMemberScores(args_[1], &member_scores); - for (size_t i = 0; i < member_scores.size(); ++i) { - str_vec.emplace_back(member_scores[i].member); + for (auto &member_score : member_scores) { + str_vec.emplace_back(member_score.member); } } else { return {Status::RedisExecErr, "Unknown type"}; From db94c8d5e74a6d21e2e7a84f2d3de6225cb34494 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sat, 13 Apr 2024 16:20:24 +0800 Subject: [PATCH 07/23] feat: sorting a set with no sort specified testcase --- tests/gocase/unit/sort/sort_test.go | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go index bccbb88f9a4..f454134618e 100644 --- a/tests/gocase/unit/sort/sort_test.go +++ b/tests/gocase/unit/sort/sort_test.go @@ -257,6 +257,15 @@ func TestListSort(t *testing.T) { sortResult, err := rdb.LRange(ctx, "sorted-numbers", 0, -1).Result() require.NoError(t, err) require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + rdb.LPush(ctx, "no-force-alpha-sort-key", 123, 3, 21) + storedLen, err = rdb.Do(ctx, "Sort", "no-force-alpha-sort-key", "BY", "not-exists-key", "STORE", "no-alpha-sorted").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "no-alpha-sorted", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"21", "3", "123"}, sortResult) }) } @@ -454,6 +463,15 @@ func TestSetSort(t *testing.T) { sortResult, err := rdb.LRange(ctx, "sorted-numbers", 0, -1).Result() require.NoError(t, err) require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + rdb.SAdd(ctx, "force-alpha-sort-key", 123, 3, 21) + storedLen, err = rdb.Do(ctx, "Sort", "force-alpha-sort-key", "BY", "not-exists-key", "STORE", "alpha-sorted").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "alpha-sorted", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"123", "21", "3"}, sortResult) }) } @@ -678,5 +696,19 @@ func TestZSetSort(t *testing.T) { sortResult, err := rdb.LRange(ctx, "sorted-numbers", 0, -1).Result() require.NoError(t, err) require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + rdb.ZAdd(ctx, "numbers", + redis.Z{Score: 1, Member: "123"}, + redis.Z{Score: 2, Member: "3"}, + redis.Z{Score: 3, Member: "21"}, + ) + + storedLen, err = rdb.Do(ctx, "Sort", "no-force-alpha-sort-key", "BY", "not-exists-key", "STORE", "no-alpha-sorted").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "no-alpha-sorted", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"21", "3", "123"}, sortResult) }) } From d70aa4a8fd28b8227464e5add86107ce1ce9d611 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sat, 13 Apr 2024 17:01:53 +0800 Subject: [PATCH 08/23] fix: TestZsetSort --- tests/gocase/unit/sort/sort_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go index f454134618e..3bc360b52ef 100644 --- a/tests/gocase/unit/sort/sort_test.go +++ b/tests/gocase/unit/sort/sort_test.go @@ -697,7 +697,7 @@ func TestZSetSort(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) - rdb.ZAdd(ctx, "numbers", + rdb.ZAdd(ctx, "no-force-alpha-sort-key", redis.Z{Score: 1, Member: "123"}, redis.Z{Score: 2, Member: "3"}, redis.Z{Score: 3, Member: "21"}, From d63bab8a99ebea0826809007a5cd053cb31bf1e3 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sat, 13 Apr 2024 22:23:02 +0800 Subject: [PATCH 09/23] fix: TestZsetSort --- tests/gocase/unit/sort/sort_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go index 3bc360b52ef..3bbd65a2abc 100644 --- a/tests/gocase/unit/sort/sort_test.go +++ b/tests/gocase/unit/sort/sort_test.go @@ -709,6 +709,6 @@ func TestZSetSort(t *testing.T) { sortResult, err = rdb.LRange(ctx, "no-alpha-sorted", 0, -1).Result() require.NoError(t, err) - require.Equal(t, []string{"21", "3", "123"}, sortResult) + require.Equal(t, []string{"123", "3", "21"}, sortResult) }) } From 90d970d364d68ebd273f0135fe76359f1994e36f Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sun, 21 Apr 2024 15:05:26 +0800 Subject: [PATCH 10/23] refactor: move cmd_sort to cmd_key --- src/commands/cmd_key.cc | 109 +++++++++++- src/commands/cmd_sort.cc | 355 --------------------------------------- src/commands/commander.h | 6 - src/storage/redis_db.cc | 209 +++++++++++++++++++++++ src/storage/redis_db.h | 25 +++ 5 files changed, 342 insertions(+), 362 deletions(-) delete mode 100644 src/commands/cmd_sort.cc diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index 589fa1ed1ae..b4bf9e1e65c 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -424,6 +424,111 @@ class CommandCopy : public Commander { bool replace_ = false; }; +template +class CommandSort : public Commander { + public: + Status Parse(const std::vector &args) override { + CommandParser parser(args, 2); + while (parser.Good()) { + if (parser.EatEqICase("BY")) { + if (parser.Remains() < 1) { + return parser.InvalidSyntax(); + } + sort_argument_.sortby = GET_OR_RET(parser.TakeStr()); + + if (sort_argument_.sortby.find('*') == std::string::npos) { + sort_argument_.dontsort = true; + } else { + /* TODO: + * If BY is specified with a real pattern, we can't accept it in cluster mode, + * unless we can make sure the keys formed by the pattern are in the same slot + * as the key to sort. + * If BY is specified with a real pattern, we can't accept + * it if no full ACL key access is applied for this command. */ + } + } else if (parser.EatEqICase("LIMIT")) { + if (parser.Remains() < 2) { + return parser.InvalidSyntax(); + } + sort_argument_.offset = GET_OR_RET(parser.template TakeInt()); + sort_argument_.count = GET_OR_RET(parser.template TakeInt()); + } else if (parser.EatEqICase("GET") && parser.Remains() >= 1) { + if (parser.Remains() < 1) { + return parser.InvalidSyntax(); + } + /* TODO: + * If GET is specified with a real pattern, we can't accept it in cluster mode, + * unless we can make sure the keys formed by the pattern are in the same slot + * as the key to sort. */ + sort_argument_.getpatterns.push_back(GET_OR_RET(parser.TakeStr())); + } else if (parser.EatEqICase("ASC")) { + sort_argument_.desc = false; + } else if (parser.EatEqICase("DESC")) { + sort_argument_.desc = true; + } else if (parser.EatEqICase("ALPHA")) { + sort_argument_.alpha = true; + } else if (parser.EatEqICase("STORE")) { + if constexpr (ReadOnly) { + return parser.InvalidSyntax(); + } + if (parser.Remains() < 1) { + return parser.InvalidSyntax(); + } + sort_argument_.storekey = GET_OR_RET(parser.TakeStr()); + } else { + return parser.InvalidSyntax(); + } + } + + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::Database redis(srv->storage, conn->GetNamespace()); + RedisType type = kRedisNone; + auto s = redis.Type(args_[1], &type); + if (s.ok()) { + if (type >= RedisTypeNames.size()) { + return {Status::RedisExecErr, "Invalid type"}; + } else if (type != RedisType::kRedisList && type != RedisType::kRedisSet && type != RedisType::kRedisZSet) { + *output = Error("WRONGTYPE Operation against a key holding the wrong kind of value"); + return Status::OK(); + } + } else { + return {Status::RedisExecErr, s.ToString()}; + } + + std::vector output_vec; + Database::SortResult res = Database::SortResult::DONE; + s = redis.Sort(type, args_[1], sort_argument_, conn->GetProtocolVersion(), &output_vec, &res); + + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + switch (res) { + case Database::SortResult::UNKNOW_TYPE: + *output = redis::Error("Unkown Type"); + break; + case Database::SortResult::DOUBLE_CONVERT_ERROR: + *output = redis::Error("One or more scores can't be converted into double"); + break; + case Database::SortResult::DONE: + if (sort_argument_.storekey.empty()) { + *output = ArrayOfBulkStrings(output_vec); + } else { + *output = Integer(output_vec.size()); + } + break; + } + + return Status::OK(); + } + + private: + SortArgument sort_argument_; +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("ttl", 2, "read-only", 1, 1, 1), MakeCmdAttr("pttl", 2, "read-only", 1, 1, 1), MakeCmdAttr("type", 2, "read-only", 1, 1, 1), @@ -442,6 +547,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("ttl", 2, "read-only", 1, 1, 1), MakeCmdAttr("unlink", -2, "write no-dbsize-check", 1, -1, 1), MakeCmdAttr("rename", 3, "write", 1, 2, 1), MakeCmdAttr("renamenx", 3, "write", 1, 2, 1), - MakeCmdAttr("copy", -3, "write", 1, 2, 1), ) + MakeCmdAttr("copy", -3, "write", 1, 2, 1), + MakeCmdAttr>("sort", -2, "write", 1, 1, 1), + MakeCmdAttr>("sort_ro", -2, "read-only", 1, 1, 1)) } // namespace redis diff --git a/src/commands/cmd_sort.cc b/src/commands/cmd_sort.cc deleted file mode 100644 index 8ecab414080..00000000000 --- a/src/commands/cmd_sort.cc +++ /dev/null @@ -1,355 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - * - */ - -#include -#include - -#include "command_parser.h" -#include "commander.h" -#include "server/server.h" -#include "storage/redis_db.h" -#include "types/redis_hash.h" -#include "types/redis_list.h" -#include "types/redis_set.h" -#include "types/redis_string.h" -#include "types/redis_zset.h" - -namespace redis { - -template -class CommandSort : public Commander { - public: - Status Parse(const std::vector &args) override { - CommandParser parser(args, 2); - while (parser.Good()) { - if (parser.EatEqICase("BY")) { - if (parser.Remains() < 1) { - return parser.InvalidSyntax(); - } - sortby_ = GET_OR_RET(parser.TakeStr()); - - if (sortby_.find('*') == std::string::npos) { - dontsort_ = true; - } else { - /* TODO: - * If BY is specified with a real pattern, we can't accept it in cluster mode, - * unless we can make sure the keys formed by the pattern are in the same slot - * as the key to sort. - * If BY is specified with a real pattern, we can't accept - * it if no full ACL key access is applied for this command. */ - } - } else if (parser.EatEqICase("LIMIT")) { - if (parser.Remains() < 2) { - return parser.InvalidSyntax(); - } - offset_ = GET_OR_RET(parser.template TakeInt()); - count_ = GET_OR_RET(parser.template TakeInt()); - } else if (parser.EatEqICase("GET") && parser.Remains() >= 1) { - if (parser.Remains() < 1) { - return parser.InvalidSyntax(); - } - /* TODO: - * If GET is specified with a real pattern, we can't accept it in cluster mode, - * unless we can make sure the keys formed by the pattern are in the same slot - * as the key to sort. */ - getpatterns_.push_back(GET_OR_RET(parser.TakeStr())); - } else if (parser.EatEqICase("ASC")) { - desc_ = false; - } else if (parser.EatEqICase("DESC")) { - desc_ = true; - } else if (parser.EatEqICase("ALPHA")) { - alpha_ = true; - } else if (parser.EatEqICase("STORE")) { - if constexpr (ReadOnly) { - return parser.InvalidSyntax(); - } - if (parser.Remains() < 1) { - return parser.InvalidSyntax(); - } - storekey_ = GET_OR_RET(parser.TakeStr()); - } else { - return parser.InvalidSyntax(); - } - } - - return Status::OK(); - } - - Status Execute(Server *srv, Connection *conn, std::string *output) override { - // Get Key Type - redis::Database redis(srv->storage, conn->GetNamespace()); - RedisType type = kRedisNone; - auto s = redis.Type(args_[1], &type); - if (s.ok()) { - if (type >= RedisTypeNames.size()) { - return {Status::RedisExecErr, "Invalid type"}; - } else if (type != RedisType::kRedisList && type != RedisType::kRedisSet && type != RedisType::kRedisZSet) { - *output = Error("WRONGTYPE Operation against a key holding the wrong kind of value"); - return Status::OK(); - } - } else { - return {Status::RedisExecErr, s.ToString()}; - } - - /* When sorting a set with no sort specified, we must sort the output - * so the result is consistent across scripting and replication. - * - * The other types (list, sorted set) will retain their native order - * even if no sort order is requested, so they remain stable across - * scripting and replication. - * - * TODO: support CLIENT_SCRIPT flag, (!storekey_.empty() || c->flags & CLIENT_SCRIPT)) */ - - if (dontsort_ && type == RedisType::kRedisSet && (!storekey_.empty())) { - /* Force ALPHA sorting */ - dontsort_ = false; - alpha_ = true; - sortby_ = ""; - } - - // Obtain the length of the object to sort. - uint64_t vec_count = 0; - auto list_db = redis::List(srv->storage, conn->GetNamespace()); - auto set_db = redis::Set(srv->storage, conn->GetNamespace()); - auto zset_db = redis::ZSet(srv->storage, conn->GetNamespace()); - - switch (type) { - case RedisType::kRedisList: { - s = list_db.Size(args_[1], &vec_count); - if (!s.ok() && !s.IsNotFound()) { - return {Status::RedisExecErr, s.ToString()}; - } - - break; - } - - case RedisType::kRedisSet: { - s = set_db.Card(args_[1], &vec_count); - if (!s.ok() && !s.IsNotFound()) { - return {Status::RedisExecErr, s.ToString()}; - } - - break; - } - - case RedisType::kRedisZSet: { - s = zset_db.Card(args_[1], &vec_count); - if (!s.ok() && !s.IsNotFound()) { - return {Status::RedisExecErr, s.ToString()}; - } - break; - } - - default: - vec_count = 0; - return {Status::RedisExecErr, "Bad SORT type"}; - } - - long vectorlen = (long)vec_count; - - // Adjust the offset and count of the limit - long offset = offset_ >= vectorlen ? 0 : std::clamp(offset_, 0L, vectorlen - 1); - long count = offset_ >= vectorlen ? 0 : std::clamp(count_, -1L, vectorlen - offset); - if (count == -1L) count = vectorlen - offset; - - // Get the elements that need to be sorted - std::vector str_vec; - if (count != 0) { - if (type == RedisType::kRedisList && dontsort_) { - if (desc_) { - list_db.Range(args_[1], -count - offset, -1 - offset, &str_vec); - std::reverse(str_vec.begin(), str_vec.end()); - } else { - list_db.Range(args_[1], offset, offset + count - 1, &str_vec); - } - } else if (type == RedisType::kRedisList) { - list_db.Range(args_[1], 0, -1, &str_vec); - } else if (type == RedisType::kRedisSet) { - set_db.Members(args_[1], &str_vec); - if (dontsort_) { - str_vec = std::vector(str_vec.begin() + offset, str_vec.begin() + offset + count); - } - } else if (type == RedisType::kRedisZSet && dontsort_) { - std::vector member_scores; - RangeRankSpec spec; - spec.start = (int)offset; - spec.stop = (int)(offset + count - 1); - spec.reversed = desc_; - zset_db.RangeByRank(args_[1], spec, &member_scores, nullptr); - for (auto &member_score : member_scores) { - str_vec.emplace_back(member_score.member); - } - } else if (type == RedisType::kRedisZSet) { - std::vector member_scores; - zset_db.GetAllMemberScores(args_[1], &member_scores); - for (auto &member_score : member_scores) { - str_vec.emplace_back(member_score.member); - } - } else { - return {Status::RedisExecErr, "Unknown type"}; - } - } - - std::vector sort_vec(str_vec.size()); - for (size_t i = 0; i < str_vec.size(); ++i) { - sort_vec[i].obj = str_vec[i]; - } - - // Sort by BY, ALPHA, ASC/DESC - if (!dontsort_) { - for (size_t i = 0; i < sort_vec.size(); ++i) { - std::string byval; - if (!sortby_.empty()) { - byval = lookupKeyByPattern(srv, conn, sortby_, str_vec[i]); - if (byval.empty()) continue; - } else { - byval = str_vec[i]; - } - - if (alpha_) { - if (!sortby_.empty()) { - sort_vec[i].v = byval; - } - } else { - try { - sort_vec[i].v = std::stod(byval); - } catch (const std::exception &e) { - *output = redis::Error("One or more scores can't be converted into double"); - return Status::OK(); - } - } - } - - std::sort(sort_vec.begin(), sort_vec.end(), - [this](const RedisSortObject &a, const RedisSortObject &b) { return sortCompare(a, b); }); - - // Gets the element specified by Limit - if (offset != 0 || count != vectorlen) { - sort_vec = std::vector(sort_vec.begin() + offset, sort_vec.begin() + offset + count); - } - } - - // Get the output and perform storage - std::vector output_vec; - - for (size_t i = 0; i < sort_vec.size(); ++i) { - if (getpatterns_.empty()) { - output_vec.emplace_back(sort_vec[i].obj); - } - for (const std::string &pattern : getpatterns_) { - std::string val = lookupKeyByPattern(srv, conn, pattern, sort_vec[i].obj); - if (val.empty()) { - output_vec.emplace_back(conn->NilString()); - } else { - output_vec.emplace_back(val); - } - } - } - - if (storekey_.empty()) { - *output = ArrayOfBulkStrings(output_vec); - } else { - std::vector elems(output_vec.begin(), output_vec.end()); - list_db.Trim(storekey_, 0, -1); - uint64_t new_size = 0; - list_db.Push(storekey_, elems, false, &new_size); - *output = Integer(new_size); - } - - return Status::OK(); - } - - private: - struct RedisSortObject { - std::string obj; - std::variant v; - }; - - bool sortCompare(const RedisSortObject &a, const RedisSortObject &b) const { - if (!alpha_) { - double score_a = std::get(a.v); - double score_b = std::get(b.v); - return !desc_ ? score_a < score_b : score_a > score_b; - } else { - if (!sortby_.empty()) { - std::string cmp_a = std::get(a.v); - std::string cmp_b = std::get(b.v); - return !desc_ ? cmp_a < cmp_b : cmp_a > cmp_b; - } else { - return !desc_ ? a.obj < b.obj : a.obj > b.obj; - } - } - } - - static std::string lookupKeyByPattern(Server *srv, Connection *conn, const std::string &pattern, - const std::string &subst) { - if (pattern == "#") { - return subst; - } - - auto match_pos = pattern.find('*'); - if (match_pos == std::string::npos) { - return ""; - } - - // hash field - std::string field; - auto arrow_pos = pattern.find("->", match_pos + 1); - if (arrow_pos != std::string::npos && arrow_pos + 2 < pattern.size()) { - field = pattern.substr(arrow_pos + 2); - } - - std::string key = pattern.substr(0, match_pos + 1); - key.replace(match_pos, 1, subst); - - std::string value; - if (!field.empty()) { - auto hash_db = redis::Hash(srv->storage, conn->GetNamespace()); - RedisType type = RedisType::kRedisNone; - if (auto s = hash_db.Type(key, &type); !s.ok() || type >= RedisTypeNames.size()) { - return ""; - } - - hash_db.Get(key, field, &value); - } else { - auto string_db = redis::String(srv->storage, conn->GetNamespace()); - RedisType type = RedisType::kRedisNone; - if (auto s = string_db.Type(key, &type); !s.ok() || type >= RedisTypeNames.size()) { - return ""; - } - string_db.Get(key, &value); - } - return value; - } - - std::string sortby_; // BY - bool dontsort_ = false; // DONT SORT - long offset_ = 0; // LIMIT OFFSET - long count_ = -1; // LIMIT COUNT - std::vector getpatterns_; // GET - bool desc_ = false; // ASC/DESC - bool alpha_ = false; // ALPHA - std::string storekey_; // STORE -}; - -REDIS_REGISTER_COMMANDS(MakeCmdAttr>("sort", -2, "write deny-oom movable-keys", 1, 1, 1), - MakeCmdAttr>("sort_ro", -2, "read-only movable-keys", 1, 1, 1)) - -} // namespace redis \ No newline at end of file diff --git a/src/commands/commander.h b/src/commands/commander.h index 9bfef93bd1a..1982d8979dc 100644 --- a/src/commands/commander.h +++ b/src/commands/commander.h @@ -65,8 +65,6 @@ enum CommandFlags : uint64_t { kCmdROScript = 1ULL << 10, // "ro-script" flag for read-only script commands kCmdCluster = 1ULL << 11, // "cluster" flag kCmdNoDBSizeCheck = 1ULL << 12, // "no-dbsize-check" flag - kCmdDenyOom = 1ULL << 13, // "deny-oom" flag - kCmdMovableKeys = 1ULL << 14, // "movable-keys" flag }; class Commander { @@ -187,10 +185,6 @@ inline uint64_t ParseCommandFlags(const std::string &description, const std::str flags |= kCmdCluster; else if (flag == "no-dbsize-check") flags |= kCmdNoDBSizeCheck; - else if (flag == "deny-oom") - flags |= kCmdDenyOom; - else if (flag == "movable-keys") - flags |= kCmdMovableKeys; else { std::cout << fmt::format("Encountered non-existent flag '{}' in command {} in command attribute parsing", flag, cmd_name) diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index fab6562b7c5..bd7b8b54cbf 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -35,6 +35,11 @@ #include "storage/redis_metadata.h" #include "storage/storage.h" #include "time_util.h" +#include "types/redis_hash.h" +#include "types/redis_list.h" +#include "types/redis_set.h" +#include "types/redis_string.h" +#include "types/redis_zset.h" namespace redis { @@ -775,4 +780,208 @@ rocksdb::Status Database::Copy(const std::string &key, const std::string &new_ke return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); } +std::string Database::lookupKeyByPattern(const std::string &pattern, const std::string &subst) { + if (pattern == "#") { + return subst; + } + + auto match_pos = pattern.find('*'); + if (match_pos == std::string::npos) { + return ""; + } + + // hash field + std::string field; + auto arrow_pos = pattern.find("->", match_pos + 1); + if (arrow_pos != std::string::npos && arrow_pos + 2 < pattern.size()) { + field = pattern.substr(arrow_pos + 2); + } + + std::string key = pattern.substr(0, match_pos + 1); + key.replace(match_pos, 1, subst); + + std::string value; + if (!field.empty()) { + auto hash_db = redis::Hash(storage_, namespace_); + RedisType type = RedisType::kRedisNone; + if (auto s = hash_db.Type(key, &type); !s.ok() || type >= RedisTypeNames.size()) { + return ""; + } + + hash_db.Get(key, field, &value); + } else { + auto string_db = redis::String(storage_, namespace_); + RedisType type = RedisType::kRedisNone; + if (auto s = string_db.Type(key, &type); !s.ok() || type >= RedisTypeNames.size()) { + return ""; + } + string_db.Get(key, &value); + } + return value; +} + +rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, SortArgument &args, const RESP &version, + std::vector *output_vec, SortResult *res) { + /* When sorting a set with no sort specified, we must sort the output + * so the result is consistent across scripting and replication. + * + * The other types (list, sorted set) will retain their native order + * even if no sort order is requested, so they remain stable across + * scripting and replication. + * + * TODO: support CLIENT_SCRIPT flag, (!storekey_.empty() || c->flags & CLIENT_SCRIPT)) */ + if (args.dontsort && type == RedisType::kRedisSet && (!args.storekey.empty())) { + /* Force ALPHA sorting */ + args.dontsort = false; + args.alpha = true; + args.sortby = ""; + } + + // Obtain the length of the object to sort. + const std::string ns_key = AppendNamespacePrefix(key); + Metadata metadata(type, false); + auto s = GetMetadata(GetOptions{}, {type}, ns_key, &metadata); + if (!s.ok()) { + return s; + } + + int vectorlen = (int)metadata.size; + + // Adjust the offset and count of the limit + int offset = args.offset >= vectorlen ? 0 : std::clamp(args.offset, 0, vectorlen - 1); + int count = args.offset >= vectorlen ? 0 : std::clamp(args.count, -1, vectorlen - offset); + if (count == -1) count = vectorlen - offset; + + // Get the elements that need to be sorted + std::vector str_vec; + if (count != 0) { + if (type == RedisType::kRedisList) { + auto list_db = redis::List(storage_, namespace_); + + if (args.dontsort) { + if (args.desc) { + list_db.Range(key, -count - offset, -1 - offset, &str_vec); + std::reverse(str_vec.begin(), str_vec.end()); + } else { + list_db.Range(key, offset, offset + count - 1, &str_vec); + } + } else { + list_db.Range(key, 0, -1, &str_vec); + } + } else if (type == RedisType::kRedisSet) { + auto set_db = redis::Set(storage_, namespace_); + set_db.Members(key, &str_vec); + + if (args.dontsort) { + str_vec = std::vector(str_vec.begin() + offset, str_vec.begin() + offset + count); + } + } else if (type == RedisType::kRedisZSet) { + auto zset_db = redis::ZSet(storage_, namespace_); + std::vector member_scores; + + if (args.dontsort) { + RangeRankSpec spec; + spec.start = offset; + spec.stop = offset + count - 1; + spec.reversed = args.desc; + zset_db.RangeByRank(key, spec, &member_scores, nullptr); + + for (auto &member_score : member_scores) { + str_vec.emplace_back(member_score.member); + } + } else { + zset_db.GetAllMemberScores(key, &member_scores); + + for (auto &member_score : member_scores) { + str_vec.emplace_back(member_score.member); + } + } + } else { + *res = SortResult::UNKNOW_TYPE; + return s; + } + } + + std::vector sort_vec(str_vec.size()); + for (size_t i = 0; i < str_vec.size(); ++i) { + sort_vec[i].obj = str_vec[i]; + } + + // Sort by BY, ALPHA, ASC/DESC + if (!args.dontsort) { + for (size_t i = 0; i < sort_vec.size(); ++i) { + std::string byval; + if (!args.sortby.empty()) { + byval = lookupKeyByPattern(args.sortby, str_vec[i]); + if (byval.empty()) continue; + } else { + byval = str_vec[i]; + } + + if (args.alpha) { + if (!args.sortby.empty()) { + sort_vec[i].v = byval; + } + } else { + try { + sort_vec[i].v = std::stod(byval); + } catch (const std::exception &e) { + *res = SortResult::DOUBLE_CONVERT_ERROR; + return rocksdb::Status::OK(); + } + } + } + + std::sort(sort_vec.begin(), sort_vec.end(), + [args](const RedisSortObject &a, const RedisSortObject &b) { return SortCompare(a, b, args); }); + + // Gets the element specified by Limit + if (offset != 0 || count != vectorlen) { + sort_vec = std::vector(sort_vec.begin() + offset, sort_vec.begin() + offset + count); + } + } + + // Get the output + for (auto &elem : sort_vec) { + if (args.getpatterns.empty()) { + output_vec->emplace_back(elem.obj); + } + for (const std::string &pattern : args.getpatterns) { + std::string val = lookupKeyByPattern(pattern, elem.obj); + if (val.empty()) { + output_vec->emplace_back(redis::NilString(version)); + } else { + output_vec->emplace_back(val); + } + } + } + + // Perform storage + if (!args.storekey.empty()) { + redis::List list_db(storage_, namespace_); + std::vector elems(output_vec->begin(), output_vec->end()); + list_db.Trim(args.storekey, 0, -1); + uint64_t new_size = 0; + list_db.Push(args.storekey, elems, false, &new_size); + } + + return rocksdb::Status::OK(); +} + +bool SortCompare(const RedisSortObject &a, const RedisSortObject &b, const SortArgument &args) { + if (!args.alpha) { + double score_a = std::get(a.v); + double score_b = std::get(b.v); + return !args.desc ? score_a < score_b : score_a > score_b; + } else { + if (!args.sortby.empty()) { + std::string cmp_a = std::get(a.v); + std::string cmp_b = std::get(b.v); + return !args.desc ? cmp_a < cmp_b : cmp_a > cmp_b; + } else { + return !args.desc ? a.obj < b.obj : a.obj > b.obj; + } + } +} + } // namespace redis diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index 31de41dc668..8b0f0e04bb9 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -23,13 +23,33 @@ #include #include #include +#include #include #include "redis_metadata.h" +#include "server/redis_reply.h" #include "storage.h" namespace redis { +struct SortArgument { + std::string sortby; // BY + bool dontsort = false; // DONT SORT + int offset = 0; // LIMIT OFFSET + int count = -1; // LIMIT COUNT + std::vector getpatterns; // GET + bool desc = false; // ASC/DESC + bool alpha = false; // ALPHA + std::string storekey; // STORE +}; + +struct RedisSortObject { + std::string obj; + std::variant v; +}; + +bool SortCompare(const RedisSortObject &a, const RedisSortObject &b, const SortArgument &args); + /// Database is a wrapper of underlying storage engine, it provides /// some common operations for redis commands. class Database { @@ -106,6 +126,9 @@ class Database { enum class CopyResult { KEY_NOT_EXIST, KEY_ALREADY_EXIST, DONE }; [[nodiscard]] rocksdb::Status Copy(const std::string &key, const std::string &new_key, bool nx, bool delete_old, CopyResult *res); + enum class SortResult { UNKNOW_TYPE, DOUBLE_CONVERT_ERROR, DONE }; + [[nodiscard]] rocksdb::Status Sort(const RedisType &type, const std::string &key, SortArgument &args, + const RESP &version, std::vector *output_vec, SortResult *res); protected: engine::Storage *storage_; @@ -118,6 +141,8 @@ class Database { // Already internal keys [[nodiscard]] rocksdb::Status existsInternal(const std::vector &keys, int *ret); [[nodiscard]] rocksdb::Status typeInternal(const Slice &key, RedisType *type); + // Sort helper + std::string lookupKeyByPattern(const std::string &pattern, const std::string &subst); }; class LatestSnapShot { public: From 6ce9dfc883a77daeb9a0597442ac61c90448eaf9 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sun, 21 Apr 2024 15:11:59 +0800 Subject: [PATCH 11/23] fix: wrong typo "unknown" --- src/commands/cmd_key.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index b4bf9e1e65c..43887d808ef 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -508,7 +508,7 @@ class CommandSort : public Commander { switch (res) { case Database::SortResult::UNKNOW_TYPE: - *output = redis::Error("Unkown Type"); + *output = redis::Error("Unknown Type"); break; case Database::SortResult::DOUBLE_CONVERT_ERROR: *output = redis::Error("One or more scores can't be converted into double"); From bbd254932e5cea3f1d24fe23e84b58f08c071bd5 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sun, 21 Apr 2024 15:28:52 +0800 Subject: [PATCH 12/23] fix: SortResult --- src/commands/cmd_key.cc | 2 +- src/storage/redis_db.cc | 2 +- src/storage/redis_db.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index 43887d808ef..f4fcf750800 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -507,7 +507,7 @@ class CommandSort : public Commander { } switch (res) { - case Database::SortResult::UNKNOW_TYPE: + case Database::SortResult::UNKNOWN_TYPE: *output = redis::Error("Unknown Type"); break; case Database::SortResult::DOUBLE_CONVERT_ERROR: diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index d78ea1ecb71..8ead0739ea4 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -899,7 +899,7 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, So } } } else { - *res = SortResult::UNKNOW_TYPE; + *res = SortResult::UNKNOWN_TYPE; return s; } } diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index 77cd29880c4..05a617490ed 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -127,7 +127,7 @@ class Database { enum class CopyResult { KEY_NOT_EXIST, KEY_ALREADY_EXIST, DONE }; [[nodiscard]] rocksdb::Status Copy(const std::string &key, const std::string &new_key, bool nx, bool delete_old, CopyResult *res); - enum class SortResult { UNKNOW_TYPE, DOUBLE_CONVERT_ERROR, DONE }; + enum class SortResult { UNKNOWN_TYPE, DOUBLE_CONVERT_ERROR, DONE }; [[nodiscard]] rocksdb::Status Sort(const RedisType &type, const std::string &key, SortArgument &args, const RESP &version, std::vector *output_vec, SortResult *res); From 84d4a02e36950dc743d984c7296b2e824696444a Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Mon, 22 Apr 2024 15:13:52 +0800 Subject: [PATCH 13/23] docs: add doc strings for SortCompare --- src/storage/redis_db.h | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index 05a617490ed..77057b30aed 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -48,6 +48,17 @@ struct RedisSortObject { std::variant v; }; +/// SortCompare is a helper function that enables `RedisSortObject` to be sorted based on `SortArument`. +/// +/// It can assist in implementing the third parameter `Compare comp` required by `std::sort` +/// +/// \param args The basis used to compare two RedisSortObjects. +/// If `args.alpha` is false, `RedisSortObject.v` will be taken as double for comparison +/// If `args.alpha` is true and `args.sortby` is not empty, `RedisSortObject.v` will be taken as string for comparison +/// If `args.alpha` is true and `args.sortby` is empty, the comparison is by `RedisSortObject.obj`. +/// +/// \return If `desc` is false, returns true when `a < b`, otherwise returns true when `a > b` + bool SortCompare(const RedisSortObject &a, const RedisSortObject &b, const SortArgument &args); /// Database is a wrapper of underlying storage engine, it provides From cdba915939faab9d1510aa40269208efe48483b8 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Mon, 22 Apr 2024 16:14:28 +0800 Subject: [PATCH 14/23] docs: fix typo --- src/storage/redis_db.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index 77057b30aed..694642b395b 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -48,7 +48,7 @@ struct RedisSortObject { std::variant v; }; -/// SortCompare is a helper function that enables `RedisSortObject` to be sorted based on `SortArument`. +/// SortCompare is a helper function that enables `RedisSortObject` to be sorted based on `SortArgument`. /// /// It can assist in implementing the third parameter `Compare comp` required by `std::sort` /// From 6b8880db9f350f5a351eec49061ba2083d6c6730 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Wed, 24 Apr 2024 17:54:42 +0800 Subject: [PATCH 15/23] refactor: refactor the code based on review suggestions --- src/commands/cmd_key.cc | 39 +++++++++++---------- src/storage/redis_db.cc | 53 +++++++++++------------------ src/storage/redis_db.h | 44 +++++++++++++++--------- tests/gocase/unit/sort/sort_test.go | 8 ++--- 4 files changed, 72 insertions(+), 72 deletions(-) diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index f4fcf750800..b94a10a5dd9 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -431,9 +431,6 @@ class CommandSort : public Commander { CommandParser parser(args, 2); while (parser.Good()) { if (parser.EatEqICase("BY")) { - if (parser.Remains() < 1) { - return parser.InvalidSyntax(); - } sort_argument_.sortby = GET_OR_RET(parser.TakeStr()); if (sort_argument_.sortby.find('*') == std::string::npos) { @@ -447,15 +444,9 @@ class CommandSort : public Commander { * it if no full ACL key access is applied for this command. */ } } else if (parser.EatEqICase("LIMIT")) { - if (parser.Remains() < 2) { - return parser.InvalidSyntax(); - } sort_argument_.offset = GET_OR_RET(parser.template TakeInt()); sort_argument_.count = GET_OR_RET(parser.template TakeInt()); - } else if (parser.EatEqICase("GET") && parser.Remains() >= 1) { - if (parser.Remains() < 1) { - return parser.InvalidSyntax(); - } + } else if (parser.EatEqICase("GET")) { /* TODO: * If GET is specified with a real pattern, we can't accept it in cluster mode, * unless we can make sure the keys formed by the pattern are in the same slot @@ -469,10 +460,7 @@ class CommandSort : public Commander { sort_argument_.alpha = true; } else if (parser.EatEqICase("STORE")) { if constexpr (ReadOnly) { - return parser.InvalidSyntax(); - } - if (parser.Remains() < 1) { - return parser.InvalidSyntax(); + return {Status::RedisParseErr, "SORT_RO is read-only and does not support the STORE parameter"}; } sort_argument_.storekey = GET_OR_RET(parser.TakeStr()); } else { @@ -498,9 +486,24 @@ class CommandSort : public Commander { return {Status::RedisExecErr, s.ToString()}; } - std::vector output_vec; + /* When sorting a set with no sort specified, we must sort the output + * so the result is consistent across scripting and replication. + * + * The other types (list, sorted set) will retain their native order + * even if no sort order is requested, so they remain stable across + * scripting and replication. + * + * TODO: support CLIENT_SCRIPT flag, (!storekey_.empty() || c->flags & CLIENT_SCRIPT)) */ + if (sort_argument_.dontsort && type == RedisType::kRedisSet && (!sort_argument_.storekey.empty())) { + /* Force ALPHA sorting */ + sort_argument_.dontsort = false; + sort_argument_.alpha = true; + sort_argument_.sortby = ""; + } + + std::vector sorted_elems; Database::SortResult res = Database::SortResult::DONE; - s = redis.Sort(type, args_[1], sort_argument_, conn->GetProtocolVersion(), &output_vec, &res); + s = redis.Sort(type, args_[1], sort_argument_, conn->GetProtocolVersion(), &sorted_elems, &res); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; @@ -515,9 +518,9 @@ class CommandSort : public Commander { break; case Database::SortResult::DONE: if (sort_argument_.storekey.empty()) { - *output = ArrayOfBulkStrings(output_vec); + *output = ArrayOfBulkStrings(sorted_elems); } else { - *output = Integer(output_vec.size()); + *output = Integer(sorted_elems.size()); } break; } diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index 8ead0739ea4..e6a6c4f5663 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -822,23 +822,8 @@ std::string Database::lookupKeyByPattern(const std::string &pattern, const std:: return value; } -rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, SortArgument &args, const RESP &version, - std::vector *output_vec, SortResult *res) { - /* When sorting a set with no sort specified, we must sort the output - * so the result is consistent across scripting and replication. - * - * The other types (list, sorted set) will retain their native order - * even if no sort order is requested, so they remain stable across - * scripting and replication. - * - * TODO: support CLIENT_SCRIPT flag, (!storekey_.empty() || c->flags & CLIENT_SCRIPT)) */ - if (args.dontsort && type == RedisType::kRedisSet && (!args.storekey.empty())) { - /* Force ALPHA sorting */ - args.dontsort = false; - args.alpha = true; - args.sortby = ""; - } - +rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, const SortArgument &args, + const RESP &version, std::vector *elems, SortResult *res) { // Obtain the length of the object to sort. const std::string ns_key = AppendNamespacePrefix(key); Metadata metadata(type, false); @@ -847,7 +832,7 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, So return s; } - int vectorlen = (int)metadata.size; + int vectorlen = static_cast(metadata.size); // Adjust the offset and count of the limit int offset = args.offset >= vectorlen ? 0 : std::clamp(args.offset, 0, vectorlen - 1); @@ -889,13 +874,13 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, So zset_db.RangeByRank(key, spec, &member_scores, nullptr); for (auto &member_score : member_scores) { - str_vec.emplace_back(member_score.member); + str_vec.emplace_back(std::move(member_score.member)); } } else { zset_db.GetAllMemberScores(key, &member_scores); for (auto &member_score : member_scores) { - str_vec.emplace_back(member_score.member); + str_vec.emplace_back(std::move(member_score.member)); } } } else { @@ -925,17 +910,18 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, So sort_vec[i].v = byval; } } else { - try { - sort_vec[i].v = std::stod(byval); - } catch (const std::exception &e) { + auto double_byval = ParseFloat(byval); + if (!double_byval) { *res = SortResult::DOUBLE_CONVERT_ERROR; - return rocksdb::Status::OK(); + } else { + sort_vec[i].v = *double_byval; } } } - std::sort(sort_vec.begin(), sort_vec.end(), - [args](const RedisSortObject &a, const RedisSortObject &b) { return SortCompare(a, b, args); }); + std::sort(sort_vec.begin(), sort_vec.end(), [args](const RedisSortObject &a, const RedisSortObject &b) { + return RedisSortObject::SortCompare(a, b, args); + }); // Gets the element specified by Limit if (offset != 0 || count != vectorlen) { @@ -943,34 +929,33 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, So } } - // Get the output + // Perform storage for (auto &elem : sort_vec) { if (args.getpatterns.empty()) { - output_vec->emplace_back(elem.obj); + elems->emplace_back(elem.obj); } for (const std::string &pattern : args.getpatterns) { std::string val = lookupKeyByPattern(pattern, elem.obj); if (val.empty()) { - output_vec->emplace_back(redis::NilString(version)); + elems->emplace_back(redis::NilString(version)); } else { - output_vec->emplace_back(val); + elems->emplace_back(val); } } } - // Perform storage if (!args.storekey.empty()) { + std::vector store_elems(elems->begin(), elems->end()); redis::List list_db(storage_, namespace_); - std::vector elems(output_vec->begin(), output_vec->end()); list_db.Trim(args.storekey, 0, -1); uint64_t new_size = 0; - list_db.Push(args.storekey, elems, false, &new_size); + list_db.Push(args.storekey, store_elems, false, &new_size); } return rocksdb::Status::OK(); } -bool SortCompare(const RedisSortObject &a, const RedisSortObject &b, const SortArgument &args) { +bool RedisSortObject::SortCompare(const RedisSortObject &a, const RedisSortObject &b, const SortArgument &args) { if (!args.alpha) { double score_a = std::get(a.v); double score_b = std::get(b.v); diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index 694642b395b..ee87f3baf11 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -46,20 +46,19 @@ struct SortArgument { struct RedisSortObject { std::string obj; std::variant v; -}; - -/// SortCompare is a helper function that enables `RedisSortObject` to be sorted based on `SortArgument`. -/// -/// It can assist in implementing the third parameter `Compare comp` required by `std::sort` -/// -/// \param args The basis used to compare two RedisSortObjects. -/// If `args.alpha` is false, `RedisSortObject.v` will be taken as double for comparison -/// If `args.alpha` is true and `args.sortby` is not empty, `RedisSortObject.v` will be taken as string for comparison -/// If `args.alpha` is true and `args.sortby` is empty, the comparison is by `RedisSortObject.obj`. -/// -/// \return If `desc` is false, returns true when `a < b`, otherwise returns true when `a > b` -bool SortCompare(const RedisSortObject &a, const RedisSortObject &b, const SortArgument &args); + /// SortCompare is a helper function that enables `RedisSortObject` to be sorted based on `SortArgument`. + /// + /// It can assist in implementing the third parameter `Compare comp` required by `std::sort` + /// + /// \param args The basis used to compare two RedisSortObjects. + /// If `args.alpha` is false, `RedisSortObject.v` will be taken as double for comparison + /// If `args.alpha` is true and `args.sortby` is not empty, `RedisSortObject.v` will be taken as string for comparison + /// If `args.alpha` is true and `args.sortby` is empty, the comparison is by `RedisSortObject.obj`. + /// + /// \return If `desc` is false, returns true when `a < b`, otherwise returns true when `a > b` + static bool SortCompare(const RedisSortObject &a, const RedisSortObject &b, const SortArgument &args); +}; /// Database is a wrapper of underlying storage engine, it provides /// some common operations for redis commands. @@ -139,8 +138,8 @@ class Database { [[nodiscard]] rocksdb::Status Copy(const std::string &key, const std::string &new_key, bool nx, bool delete_old, CopyResult *res); enum class SortResult { UNKNOWN_TYPE, DOUBLE_CONVERT_ERROR, DONE }; - [[nodiscard]] rocksdb::Status Sort(const RedisType &type, const std::string &key, SortArgument &args, - const RESP &version, std::vector *output_vec, SortResult *res); + [[nodiscard]] rocksdb::Status Sort(const RedisType &type, const std::string &key, const SortArgument &args, + const RESP &version, std::vector *elems, SortResult *res); protected: engine::Storage *storage_; @@ -153,7 +152,20 @@ class Database { // Already internal keys [[nodiscard]] rocksdb::Status existsInternal(const std::vector &keys, int *ret); [[nodiscard]] rocksdb::Status typeInternal(const Slice &key, RedisType *type); - // Sort helper + + /// lookupKeyByPattern is a helper function of `Sort` to support `GET` and `BY` fields. + /// + /// \param pattern can be the value of a `BY` or `GET` field + /// \param subst is used to replace the "*" or "#" matched in the pattern string. + /// \return Return the value associated to the key with a name obtained using the following rules: + /// 1) The first occurrence of '*' in 'pattern' is substituted with 'subst'. + /// 2) If 'pattern' matches the "->" string, everything on the left of + /// the arrow is treated as the name of a hash field, and the part on the + /// left as the key name containing a hash. The value of the specified + /// field is returned. + /// 3) If 'pattern' equals "#", the function simply returns 'subst' itself so + /// that the SORT command can be used like: SORT key GET # to retrieve + /// the Set/List elements directly. std::string lookupKeyByPattern(const std::string &pattern, const std::string &subst); }; class LatestSnapShot { diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go index 3bbd65a2abc..3b45c445a06 100644 --- a/tests/gocase/unit/sort/sort_test.go +++ b/tests/gocase/unit/sort/sort_test.go @@ -47,19 +47,19 @@ func TestSortParser(t *testing.T) { require.EqualError(t, err, "ERR syntax error") _, err = rdb.Do(ctx, "Sort", "bad-case-key", "LIMIT").Result() - require.EqualError(t, err, "ERR syntax error") + require.EqualError(t, err, "ERR no more item to parse") _, err = rdb.Do(ctx, "Sort", "bad-case-key", "LIMIT", 1).Result() - require.EqualError(t, err, "ERR syntax error") + require.EqualError(t, err, "ERR no more item to parse") _, err = rdb.Do(ctx, "Sort", "bad-case-key", "LIMIT", 1, "not-number").Result() require.EqualError(t, err, "ERR not started as an integer") _, err = rdb.Do(ctx, "Sort", "bad-case-key", "STORE").Result() - require.EqualError(t, err, "ERR syntax error") + require.EqualError(t, err, "ERR no more item to parse") _, err = rdb.Do(ctx, "Sort_RO", "bad-case-key", "STORE", "store_ro_key").Result() - require.EqualError(t, err, "ERR syntax error") + require.EqualError(t, err, "ERR SORT_RO is read-only and does not support the STORE parameter") }) } From 75b3e4c28a0f49e8274d33961cdeb8a0257cae13 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Fri, 26 Apr 2024 12:06:36 +0800 Subject: [PATCH 16/23] fix: sort in case of get empty and add test --- src/commands/cmd_key.cc | 8 +++- src/storage/redis_db.cc | 4 +- src/storage/redis_db.h | 2 +- tests/gocase/unit/sort/sort_test.go | 59 +++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 5 deletions(-) diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index b94a10a5dd9..55284ba7095 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -503,7 +503,7 @@ class CommandSort : public Commander { std::vector sorted_elems; Database::SortResult res = Database::SortResult::DONE; - s = redis.Sort(type, args_[1], sort_argument_, conn->GetProtocolVersion(), &sorted_elems, &res); + s = redis.Sort(type, args_[1], sort_argument_, &sorted_elems, &res); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; @@ -518,7 +518,11 @@ class CommandSort : public Commander { break; case Database::SortResult::DONE: if (sort_argument_.storekey.empty()) { - *output = ArrayOfBulkStrings(sorted_elems); + std::vector output_vec; + for (const auto &elem : sorted_elems) { + output_vec.emplace_back(elem.empty() ? conn->NilString() : redis::BulkString(elem)); + } + *output = redis::Array(output_vec); } else { *output = Integer(sorted_elems.size()); } diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index e6a6c4f5663..c3290716d3d 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -823,7 +823,7 @@ std::string Database::lookupKeyByPattern(const std::string &pattern, const std:: } rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, const SortArgument &args, - const RESP &version, std::vector *elems, SortResult *res) { + std::vector *elems, SortResult *res) { // Obtain the length of the object to sort. const std::string ns_key = AppendNamespacePrefix(key); Metadata metadata(type, false); @@ -937,7 +937,7 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co for (const std::string &pattern : args.getpatterns) { std::string val = lookupKeyByPattern(pattern, elem.obj); if (val.empty()) { - elems->emplace_back(redis::NilString(version)); + elems->emplace_back(""); } else { elems->emplace_back(val); } diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index ee87f3baf11..e8fef670d20 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -139,7 +139,7 @@ class Database { CopyResult *res); enum class SortResult { UNKNOWN_TYPE, DOUBLE_CONVERT_ERROR, DONE }; [[nodiscard]] rocksdb::Status Sort(const RedisType &type, const std::string &key, const SortArgument &args, - const RESP &version, std::vector *elems, SortResult *res); + std::vector *elems, SortResult *res); protected: engine::Storage *storage_; diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go index 3b45c445a06..c156b7ceeaf 100644 --- a/tests/gocase/unit/sort/sort_test.go +++ b/tests/gocase/unit/sort/sort_test.go @@ -245,6 +245,12 @@ func TestListSort(t *testing.T) { sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level", Get: []string{"user_info_*->name"}}).Result() require.NoError(t, err) require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + // get empty + rdb.LPush(ctx, "uid_get_empty", 4, 5, 6, 7) + getResult, err := rdb.Do(ctx, "Sort", "uid_get_empty", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", nil, nil, nil}, getResult) }) t.Run("SORT STORE", func(t *testing.T) { @@ -266,6 +272,17 @@ func TestListSort(t *testing.T) { sortResult, err = rdb.LRange(ctx, "no-alpha-sorted", 0, -1).Result() require.NoError(t, err) require.Equal(t, []string{"21", "3", "123"}, sortResult) + + // get empty + rdb.LPush(ctx, "uid_get_empty_store", 4, 5, 6, 7) + rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_store", "Get", "user_name_*", "STORE", "get_empty_store").Result() + require.NoError(t, err) + require.Equal(t, int64(4), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "", "", ""}, sortResult) }) } @@ -451,6 +468,12 @@ func TestSetSort(t *testing.T) { sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level", Get: []string{"user_info_*->name"}}).Result() require.NoError(t, err) require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + // get empty + rdb.SAdd(ctx, "uid_get_empty", 4, 5, 6, 7) + getResult, err := rdb.Do(ctx, "Sort", "uid_get_empty", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", nil, nil, nil}, getResult) }) t.Run("SORT STORE", func(t *testing.T) { @@ -472,6 +495,17 @@ func TestSetSort(t *testing.T) { sortResult, err = rdb.LRange(ctx, "alpha-sorted", 0, -1).Result() require.NoError(t, err) require.Equal(t, []string{"123", "21", "3"}, sortResult) + + // get empty + rdb.SAdd(ctx, "uid_get_empty_store", 4, 5, 6, 7) + rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_store", "Get", "user_name_*", "STORE", "get_empty_store").Result() + require.NoError(t, err) + require.Equal(t, int64(4), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "", "", ""}, sortResult) }) } @@ -673,6 +707,16 @@ func TestZSetSort(t *testing.T) { sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level", Get: []string{"user_info_*->name"}}).Result() require.NoError(t, err) require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + // get empty + rdb.ZAdd(ctx, "uid_get_empty", + redis.Z{Score: 4, Member: "7"}, + redis.Z{Score: 5, Member: "6"}, + redis.Z{Score: 6, Member: "5"}, + redis.Z{Score: 7, Member: "4"}) + getResult, err := rdb.Do(ctx, "Sort", "uid_get_empty", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", nil, nil, nil}, getResult) }) t.Run("SORT STORE", func(t *testing.T) { @@ -710,5 +754,20 @@ func TestZSetSort(t *testing.T) { sortResult, err = rdb.LRange(ctx, "no-alpha-sorted", 0, -1).Result() require.NoError(t, err) require.Equal(t, []string{"123", "3", "21"}, sortResult) + + // get empty + rdb.ZAdd(ctx, "uid_get_empty_store", + redis.Z{Score: 4, Member: "7"}, + redis.Z{Score: 5, Member: "6"}, + redis.Z{Score: 6, Member: "5"}, + redis.Z{Score: 7, Member: "4"}) + rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_store", "Get", "user_name_*", "STORE", "get_empty_store").Result() + require.NoError(t, err) + require.Equal(t, int64(4), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "", "", ""}, sortResult) }) } From b2e728b1e0eac85a625346c15b81ed0b9fc1a4d3 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Fri, 26 Apr 2024 12:25:15 +0800 Subject: [PATCH 17/23] fix: clang-tidy --- src/commands/cmd_key.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index 55284ba7095..3e0d817ab27 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -519,6 +519,7 @@ class CommandSort : public Commander { case Database::SortResult::DONE: if (sort_argument_.storekey.empty()) { std::vector output_vec; + output_vec.reverse(sorted_elems.size()); for (const auto &elem : sorted_elems) { output_vec.emplace_back(elem.empty() ? conn->NilString() : redis::BulkString(elem)); } From 866fe85a9e4490e16c17805c91b92486a08e9a9c Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Fri, 26 Apr 2024 12:27:59 +0800 Subject: [PATCH 18/23] fix: clang-tidy --- src/commands/cmd_key.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index 3e0d817ab27..b37c486765d 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -519,7 +519,7 @@ class CommandSort : public Commander { case Database::SortResult::DONE: if (sort_argument_.storekey.empty()) { std::vector output_vec; - output_vec.reverse(sorted_elems.size()); + output_vec.reserve(sorted_elems.size()); for (const auto &elem : sorted_elems) { output_vec.emplace_back(elem.empty() ? conn->NilString() : redis::BulkString(elem)); } From 1b6d115a288d488ba903ca3e65aebfd90c2557cc Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sat, 27 Apr 2024 20:50:19 +0800 Subject: [PATCH 19/23] fix: distinguish between nil and empty string --- src/commands/cmd_key.cc | 22 ++-- src/storage/redis_db.cc | 50 ++++----- src/storage/redis_db.h | 5 +- tests/gocase/unit/sort/sort_test.go | 155 ++++++++++++++++++++-------- 4 files changed, 152 insertions(+), 80 deletions(-) diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index b37c486765d..5717ee0b14f 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -474,18 +474,15 @@ class CommandSort : public Commander { Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::Database redis(srv->storage, conn->GetNamespace()); RedisType type = kRedisNone; - auto s = redis.Type(args_[1], &type); - if (s.ok()) { - if (type >= RedisTypeNames.size()) { - return {Status::RedisExecErr, "Invalid type"}; - } else if (type != RedisType::kRedisList && type != RedisType::kRedisSet && type != RedisType::kRedisZSet) { - *output = Error("WRONGTYPE Operation against a key holding the wrong kind of value"); - return Status::OK(); - } - } else { + if (auto s = redis.Type(args_[1], &type); !s.ok()) { return {Status::RedisExecErr, s.ToString()}; } + if (type != RedisType::kRedisList && type != RedisType::kRedisSet && type != RedisType::kRedisZSet) { + *output = Error("WRONGTYPE Operation against a key holding the wrong kind of value"); + return Status::OK(); + } + /* When sorting a set with no sort specified, we must sort the output * so the result is consistent across scripting and replication. * @@ -501,11 +498,10 @@ class CommandSort : public Commander { sort_argument_.sortby = ""; } - std::vector sorted_elems; + std::vector> sorted_elems; Database::SortResult res = Database::SortResult::DONE; - s = redis.Sort(type, args_[1], sort_argument_, &sorted_elems, &res); - if (!s.ok()) { + if (auto s = redis.Sort(type, args_[1], sort_argument_, &sorted_elems, &res); !s.ok()) { return {Status::RedisExecErr, s.ToString()}; } @@ -521,7 +517,7 @@ class CommandSort : public Commander { std::vector output_vec; output_vec.reserve(sorted_elems.size()); for (const auto &elem : sorted_elems) { - output_vec.emplace_back(elem.empty() ? conn->NilString() : redis::BulkString(elem)); + output_vec.emplace_back(elem.has_value() ? redis::BulkString(elem.value()) : conn->NilString()); } *output = redis::Array(output_vec); } else { diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index c3290716d3d..ddaa5341f40 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -782,14 +782,14 @@ rocksdb::Status Database::Copy(const std::string &key, const std::string &new_ke return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); } -std::string Database::lookupKeyByPattern(const std::string &pattern, const std::string &subst) { +std::optional Database::lookupKeyByPattern(const std::string &pattern, const std::string &subst) { if (pattern == "#") { return subst; } auto match_pos = pattern.find('*'); if (match_pos == std::string::npos) { - return ""; + return std::nullopt; } // hash field @@ -806,16 +806,16 @@ std::string Database::lookupKeyByPattern(const std::string &pattern, const std:: if (!field.empty()) { auto hash_db = redis::Hash(storage_, namespace_); RedisType type = RedisType::kRedisNone; - if (auto s = hash_db.Type(key, &type); !s.ok() || type >= RedisTypeNames.size()) { - return ""; + if (auto s = hash_db.Type(key, &type); !s.ok() || type != RedisType::kRedisHash) { + return std::nullopt; } hash_db.Get(key, field, &value); } else { auto string_db = redis::String(storage_, namespace_); RedisType type = RedisType::kRedisNone; - if (auto s = string_db.Type(key, &type); !s.ok() || type >= RedisTypeNames.size()) { - return ""; + if (auto s = string_db.Type(key, &type); !s.ok() || type != RedisType::kRedisString) { + return std::nullopt; } string_db.Get(key, &value); } @@ -823,7 +823,7 @@ std::string Database::lookupKeyByPattern(const std::string &pattern, const std:: } rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, const SortArgument &args, - std::vector *elems, SortResult *res) { + std::vector> *elems, SortResult *res) { // Obtain the length of the object to sort. const std::string ns_key = AppendNamespacePrefix(key); Metadata metadata(type, false); @@ -897,25 +897,23 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co // Sort by BY, ALPHA, ASC/DESC if (!args.dontsort) { for (size_t i = 0; i < sort_vec.size(); ++i) { - std::string byval; + std::optional byval; if (!args.sortby.empty()) { byval = lookupKeyByPattern(args.sortby, str_vec[i]); - if (byval.empty()) continue; + if (!byval.has_value()) continue; } else { byval = str_vec[i]; } - if (args.alpha) { - if (!args.sortby.empty()) { - sort_vec[i].v = byval; - } - } else { - auto double_byval = ParseFloat(byval); + if (args.alpha && !args.sortby.empty()) { + sort_vec[i].v = byval.value(); + } else if (!args.alpha && !byval.value().empty()) { + auto double_byval = ParseFloat(byval.value()); if (!double_byval) { *res = SortResult::DOUBLE_CONVERT_ERROR; - } else { - sort_vec[i].v = *double_byval; + return rocksdb::Status::OK(); } + sort_vec[i].v = *double_byval; } } @@ -935,21 +933,25 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co elems->emplace_back(elem.obj); } for (const std::string &pattern : args.getpatterns) { - std::string val = lookupKeyByPattern(pattern, elem.obj); - if (val.empty()) { - elems->emplace_back(""); + std::optional val = lookupKeyByPattern(pattern, elem.obj); + if (val.has_value()) { + elems->emplace_back(val.value()); } else { - elems->emplace_back(val); + elems->emplace_back(std::nullopt); } } } if (!args.storekey.empty()) { - std::vector store_elems(elems->begin(), elems->end()); + std::vector store_elems; + store_elems.reserve(elems->size()); + for (const auto &e : *elems) { + store_elems.emplace_back(e.value_or("")); + } redis::List list_db(storage_, namespace_); - list_db.Trim(args.storekey, 0, -1); + list_db.Trim(args.storekey, -1, 0); uint64_t new_size = 0; - list_db.Push(args.storekey, store_elems, false, &new_size); + list_db.Push(args.storekey, std::vector(store_elems.cbegin(), store_elems.cend()), false, &new_size); } return rocksdb::Status::OK(); diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index e8fef670d20..3d510d3f330 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -21,6 +21,7 @@ #pragma once #include +#include #include #include #include @@ -139,7 +140,7 @@ class Database { CopyResult *res); enum class SortResult { UNKNOWN_TYPE, DOUBLE_CONVERT_ERROR, DONE }; [[nodiscard]] rocksdb::Status Sort(const RedisType &type, const std::string &key, const SortArgument &args, - std::vector *elems, SortResult *res); + std::vector> *elems, SortResult *res); protected: engine::Storage *storage_; @@ -166,7 +167,7 @@ class Database { /// 3) If 'pattern' equals "#", the function simply returns 'subst' itself so /// that the SORT command can be used like: SORT key GET # to retrieve /// the Set/List elements directly. - std::string lookupKeyByPattern(const std::string &pattern, const std::string &subst); + std::optional lookupKeyByPattern(const std::string &pattern, const std::string &subst); }; class LatestSnapShot { public: diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go index c156b7ceeaf..7c75cf752b8 100644 --- a/tests/gocase/unit/sort/sort_test.go +++ b/tests/gocase/unit/sort/sort_test.go @@ -246,11 +246,26 @@ func TestListSort(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) - // get empty - rdb.LPush(ctx, "uid_get_empty", 4, 5, 6, 7) - getResult, err := rdb.Do(ctx, "Sort", "uid_get_empty", "Get", "user_name_*").Slice() + // get/by empty and nil + rdb.LPush(ctx, "uid_empty_nil", 4, 5, 6) + rdb.MSet(ctx, "user_name_5", "tom", "user_level_5", -1) + + getResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", nil}, getResult) + byResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + + getResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", ""}, getResult) + + byResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() require.NoError(t, err) - require.Equal(t, []interface{}{"mary", nil, nil, nil}, getResult) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) }) t.Run("SORT STORE", func(t *testing.T) { @@ -273,16 +288,26 @@ func TestListSort(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{"21", "3", "123"}, sortResult) - // get empty - rdb.LPush(ctx, "uid_get_empty_store", 4, 5, 6, 7) - rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") - storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_store", "Get", "user_name_*", "STORE", "get_empty_store").Result() + // get empty and nil + rdb.LPush(ctx, "uid_get_empty_nil", 4, 5, 6) + rdb.MSet(ctx, "user_name_4", "mary", "user_level_4", 70, "user_name_5", "tom", "user_level_5", -1) + + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() require.NoError(t, err) - require.Equal(t, int64(4), storedLen) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) - sortResult, err = rdb.LRange(ctx, "get_empty_store", 0, -1).Result() + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() require.NoError(t, err) - require.Equal(t, []string{"mary", "", "", ""}, sortResult) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) }) } @@ -469,11 +494,26 @@ func TestSetSort(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) - // get empty - rdb.SAdd(ctx, "uid_get_empty", 4, 5, 6, 7) - getResult, err := rdb.Do(ctx, "Sort", "uid_get_empty", "Get", "user_name_*").Slice() + // get/by empty and nil + rdb.SAdd(ctx, "uid_empty_nil", 4, 5, 6) + rdb.MSet(ctx, "user_name_5", "tom", "user_level_5", -1) + + getResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", nil}, getResult) + byResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + + getResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", ""}, getResult) + + byResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() require.NoError(t, err) - require.Equal(t, []interface{}{"mary", nil, nil, nil}, getResult) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) }) t.Run("SORT STORE", func(t *testing.T) { @@ -496,16 +536,26 @@ func TestSetSort(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{"123", "21", "3"}, sortResult) - // get empty - rdb.SAdd(ctx, "uid_get_empty_store", 4, 5, 6, 7) - rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") - storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_store", "Get", "user_name_*", "STORE", "get_empty_store").Result() + // get empty and nil + rdb.SAdd(ctx, "uid_get_empty_nil", 4, 5, 6) + rdb.MSet(ctx, "user_name_4", "mary", "user_level_4", 70, "user_name_5", "tom", "user_level_5", -1) + + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() require.NoError(t, err) - require.Equal(t, int64(4), storedLen) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) - sortResult, err = rdb.LRange(ctx, "get_empty_store", 0, -1).Result() + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() require.NoError(t, err) - require.Equal(t, []string{"mary", "", "", ""}, sortResult) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) }) } @@ -708,15 +758,29 @@ func TestZSetSort(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) - // get empty - rdb.ZAdd(ctx, "uid_get_empty", - redis.Z{Score: 4, Member: "7"}, - redis.Z{Score: 5, Member: "6"}, - redis.Z{Score: 6, Member: "5"}, - redis.Z{Score: 7, Member: "4"}) - getResult, err := rdb.Do(ctx, "Sort", "uid_get_empty", "Get", "user_name_*").Slice() + // get/by empty and nil + rdb.ZAdd(ctx, "uid_empty_nil", + redis.Z{Score: 4, Member: "6"}, + redis.Z{Score: 5, Member: "5"}, + redis.Z{Score: 6, Member: "4"}) + rdb.MSet(ctx, "user_name_5", "tom", "user_level_5", -1) + + getResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", nil}, getResult) + byResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + + getResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() require.NoError(t, err) - require.Equal(t, []interface{}{"mary", nil, nil, nil}, getResult) + require.Equal(t, []interface{}{"mary", "tom", ""}, getResult) + + byResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) }) t.Run("SORT STORE", func(t *testing.T) { @@ -755,19 +819,28 @@ func TestZSetSort(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{"123", "3", "21"}, sortResult) - // get empty - rdb.ZAdd(ctx, "uid_get_empty_store", - redis.Z{Score: 4, Member: "7"}, - redis.Z{Score: 5, Member: "6"}, - redis.Z{Score: 6, Member: "5"}, - redis.Z{Score: 7, Member: "4"}) - rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") - storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_store", "Get", "user_name_*", "STORE", "get_empty_store").Result() + // get empty and nil + rdb.ZAdd(ctx, "uid_get_empty_nil", + redis.Z{Score: 4, Member: "6"}, + redis.Z{Score: 5, Member: "5"}, + redis.Z{Score: 6, Member: "4"}) + rdb.MSet(ctx, "user_name_4", "mary", "user_level_4", 70, "user_name_5", "tom", "user_level_5", -1) + + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() require.NoError(t, err) - require.Equal(t, int64(4), storedLen) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) - sortResult, err = rdb.LRange(ctx, "get_empty_store", 0, -1).Result() + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() require.NoError(t, err) - require.Equal(t, []string{"mary", "", "", ""}, sortResult) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) }) } From 2c3ae39bb37fcf857d3ca399e7e57af26bbc3472 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sat, 4 May 2024 16:49:08 +0800 Subject: [PATCH 20/23] refactor: use move_iterator --- src/storage/redis_db.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index ddaa5341f40..67499e8eacd 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -860,7 +860,8 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co set_db.Members(key, &str_vec); if (args.dontsort) { - str_vec = std::vector(str_vec.begin() + offset, str_vec.begin() + offset + count); + str_vec = std::vector(std::make_move_iterator(str_vec.begin() + offset), + std::make_move_iterator(str_vec.begin() + offset + count)); } } else if (type == RedisType::kRedisZSet) { auto zset_db = redis::ZSet(storage_, namespace_); @@ -923,7 +924,8 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co // Gets the element specified by Limit if (offset != 0 || count != vectorlen) { - sort_vec = std::vector(sort_vec.begin() + offset, sort_vec.begin() + offset + count); + sort_vec = std::vector(std::make_move_iterator(sort_vec.begin() + offset), + std::make_move_iterator(sort_vec.begin() + offset + count)); } } @@ -951,7 +953,10 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co redis::List list_db(storage_, namespace_); list_db.Trim(args.storekey, -1, 0); uint64_t new_size = 0; - list_db.Push(args.storekey, std::vector(store_elems.cbegin(), store_elems.cend()), false, &new_size); + list_db.Push( + args.storekey, + std::vector(std::make_move_iterator(store_elems.cbegin()), std::make_move_iterator(store_elems.cend())), + false, &new_size); } return rocksdb::Status::OK(); From 7667559d476e0039656add1a81975f057f70e311 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sat, 4 May 2024 17:43:29 +0800 Subject: [PATCH 21/23] refactor: remove move_iterator on vector --- src/storage/redis_db.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index d770fcc9b86..b9da335f49f 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -944,10 +944,7 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co redis::List list_db(storage_, namespace_); list_db.Trim(args.storekey, -1, 0); uint64_t new_size = 0; - list_db.Push( - args.storekey, - std::vector(std::make_move_iterator(store_elems.cbegin()), std::make_move_iterator(store_elems.cend())), - false, &new_size); + list_db.Push(args.storekey, std::vector(store_elems.cbegin(), store_elems.cend()), false, &new_size); } return rocksdb::Status::OK(); From b7d7f68fb1c8d265bcc1c3f97a97152adae0f75a Mon Sep 17 00:00:00 2001 From: Zhou SiLe Date: Mon, 6 May 2024 11:12:11 +0800 Subject: [PATCH 22/23] fix: Return => Returns Co-authored-by: mwish --- src/storage/redis_db.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index 3d510d3f330..86056431490 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -158,7 +158,7 @@ class Database { /// /// \param pattern can be the value of a `BY` or `GET` field /// \param subst is used to replace the "*" or "#" matched in the pattern string. - /// \return Return the value associated to the key with a name obtained using the following rules: + /// \return Returns the value associated to the key with a name obtained using the following rules: /// 1) The first occurrence of '*' in 'pattern' is substituted with 'subst'. /// 2) If 'pattern' matches the "->" string, everything on the left of /// the arrow is treated as the name of a hash field, and the part on the From 6cfa1988c0032fd4a18a6e6737f932762d6cdc6e Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Mon, 6 May 2024 22:01:04 +0800 Subject: [PATCH 23/23] refactor: modify code according to reviewer suggestions --- src/commands/cmd_key.cc | 5 +++ src/storage/redis_db.cc | 62 ++++++++++++++++++----------- src/storage/redis_db.h | 18 ++++++++- tests/gocase/unit/sort/sort_test.go | 35 ++++++++++++++++ 4 files changed, 94 insertions(+), 26 deletions(-) diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index 5717ee0b14f..24d8fe29c4d 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -431,6 +431,7 @@ class CommandSort : public Commander { CommandParser parser(args, 2); while (parser.Good()) { if (parser.EatEqICase("BY")) { + if (!sort_argument_.sortby.empty()) return {Status::InvalidArgument, "don't use multiple BY parameters"}; sort_argument_.sortby = GET_OR_RET(parser.TakeStr()); if (sort_argument_.sortby.find('*') == std::string::npos) { @@ -512,6 +513,10 @@ class CommandSort : public Commander { case Database::SortResult::DOUBLE_CONVERT_ERROR: *output = redis::Error("One or more scores can't be converted into double"); break; + case Database::SortResult::LIMIT_EXCEEDED: + *output = redis::Error("The number of elements to be sorted exceeds SORT_LENGTH_LIMIT = " + + std::to_string(SORT_LENGTH_LIMIT)); + break; case Database::SortResult::DONE: if (sort_argument_.storekey.empty()) { std::vector output_vec; diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index b9da335f49f..13f9bd2f5c0 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -794,36 +794,41 @@ std::optional Database::lookupKeyByPattern(const std::string &patte key.replace(match_pos, 1, subst); std::string value; + RedisType type = RedisType::kRedisNone; if (!field.empty()) { auto hash_db = redis::Hash(storage_, namespace_); - RedisType type = RedisType::kRedisNone; if (auto s = hash_db.Type(key, &type); !s.ok() || type != RedisType::kRedisHash) { return std::nullopt; } - hash_db.Get(key, field, &value); + if (auto s = hash_db.Get(key, field, &value); !s.ok()) { + return std::nullopt; + } } else { auto string_db = redis::String(storage_, namespace_); - RedisType type = RedisType::kRedisNone; if (auto s = string_db.Type(key, &type); !s.ok() || type != RedisType::kRedisString) { return std::nullopt; } - string_db.Get(key, &value); + if (auto s = string_db.Get(key, &value); !s.ok()) { + return std::nullopt; + } } return value; } -rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, const SortArgument &args, +rocksdb::Status Database::Sort(RedisType type, const std::string &key, const SortArgument &args, std::vector> *elems, SortResult *res) { // Obtain the length of the object to sort. const std::string ns_key = AppendNamespacePrefix(key); Metadata metadata(type, false); auto s = GetMetadata(GetOptions{}, {type}, ns_key, &metadata); - if (!s.ok()) { - return s; - } + if (!s.ok()) return s; - int vectorlen = static_cast(metadata.size); + if (metadata.size > SORT_LENGTH_LIMIT) { + *res = SortResult::LIMIT_EXCEEDED; + return rocksdb::Status::OK(); + } + auto vectorlen = static_cast(metadata.size); // Adjust the offset and count of the limit int offset = args.offset >= vectorlen ? 0 : std::clamp(args.offset, 0, vectorlen - 1); @@ -838,17 +843,21 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co if (args.dontsort) { if (args.desc) { - list_db.Range(key, -count - offset, -1 - offset, &str_vec); + s = list_db.Range(key, -count - offset, -1 - offset, &str_vec); + if (!s.ok()) return s; std::reverse(str_vec.begin(), str_vec.end()); } else { - list_db.Range(key, offset, offset + count - 1, &str_vec); + s = list_db.Range(key, offset, offset + count - 1, &str_vec); + if (!s.ok()) return s; } } else { - list_db.Range(key, 0, -1, &str_vec); + s = list_db.Range(key, 0, -1, &str_vec); + if (!s.ok()) return s; } } else if (type == RedisType::kRedisSet) { auto set_db = redis::Set(storage_, namespace_); - set_db.Members(key, &str_vec); + s = set_db.Members(key, &str_vec); + if (!s.ok()) return s; if (args.dontsort) { str_vec = std::vector(std::make_move_iterator(str_vec.begin() + offset), @@ -863,13 +872,15 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co spec.start = offset; spec.stop = offset + count - 1; spec.reversed = args.desc; - zset_db.RangeByRank(key, spec, &member_scores, nullptr); + s = zset_db.RangeByRank(key, spec, &member_scores, nullptr); + if (!s.ok()) return s; for (auto &member_score : member_scores) { str_vec.emplace_back(std::move(member_score.member)); } } else { - zset_db.GetAllMemberScores(key, &member_scores); + s = zset_db.GetAllMemberScores(key, &member_scores); + if (!s.ok()) return s; for (auto &member_score : member_scores) { str_vec.emplace_back(std::move(member_score.member)); @@ -889,18 +900,19 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co // Sort by BY, ALPHA, ASC/DESC if (!args.dontsort) { for (size_t i = 0; i < sort_vec.size(); ++i) { - std::optional byval; + std::string byval; if (!args.sortby.empty()) { - byval = lookupKeyByPattern(args.sortby, str_vec[i]); - if (!byval.has_value()) continue; + auto lookup = lookupKeyByPattern(args.sortby, str_vec[i]); + if (!lookup.has_value()) continue; + byval = std::move(lookup.value()); } else { byval = str_vec[i]; } if (args.alpha && !args.sortby.empty()) { - sort_vec[i].v = byval.value(); - } else if (!args.alpha && !byval.value().empty()) { - auto double_byval = ParseFloat(byval.value()); + sort_vec[i].v = byval; + } else if (!args.alpha && !byval.empty()) { + auto double_byval = ParseFloat(byval); if (!double_byval) { *res = SortResult::DOUBLE_CONVERT_ERROR; return rocksdb::Status::OK(); @@ -909,7 +921,7 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co } } - std::sort(sort_vec.begin(), sort_vec.end(), [args](const RedisSortObject &a, const RedisSortObject &b) { + std::sort(sort_vec.begin(), sort_vec.end(), [&args](const RedisSortObject &a, const RedisSortObject &b) { return RedisSortObject::SortCompare(a, b, args); }); @@ -942,9 +954,11 @@ rocksdb::Status Database::Sort(const RedisType &type, const std::string &key, co store_elems.emplace_back(e.value_or("")); } redis::List list_db(storage_, namespace_); - list_db.Trim(args.storekey, -1, 0); + s = list_db.Trim(args.storekey, -1, 0); + if (!s.ok()) return s; uint64_t new_size = 0; - list_db.Push(args.storekey, std::vector(store_elems.cbegin(), store_elems.cend()), false, &new_size); + s = list_db.Push(args.storekey, std::vector(store_elems.cbegin(), store_elems.cend()), false, &new_size); + if (!s.ok()) return s; } return rocksdb::Status::OK(); diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index 86056431490..84579b107fd 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -33,6 +33,12 @@ namespace redis { +/// SORT_LENGTH_LIMIT limits the number of elements to be sorted +/// to avoid using too much memory and causing system crashes. +/// TODO: Expect to expand or eliminate SORT_LENGTH_LIMIT +/// through better mechanisms such as memory restriction logic. +constexpr uint64_t SORT_LENGTH_LIMIT = 512; + struct SortArgument { std::string sortby; // BY bool dontsort = false; // DONT SORT @@ -138,8 +144,16 @@ class Database { enum class CopyResult { KEY_NOT_EXIST, KEY_ALREADY_EXIST, DONE }; [[nodiscard]] rocksdb::Status Copy(const std::string &key, const std::string &new_key, bool nx, bool delete_old, CopyResult *res); - enum class SortResult { UNKNOWN_TYPE, DOUBLE_CONVERT_ERROR, DONE }; - [[nodiscard]] rocksdb::Status Sort(const RedisType &type, const std::string &key, const SortArgument &args, + enum class SortResult { UNKNOWN_TYPE, DOUBLE_CONVERT_ERROR, LIMIT_EXCEEDED, DONE }; + /// Sort sorts keys of the specified type according to SortArgument + /// + /// \param type is the type of sort key, which must be LIST, SET or ZSET + /// \param key is to be sorted + /// \param args provide the parameters to sort by + /// \param elems contain the sorted results + /// \param res represents the sorted result type. + /// When status is not ok, `res` should not been checked, otherwise it should be checked whether `res` is `DONE` + [[nodiscard]] rocksdb::Status Sort(RedisType type, const std::string &key, const SortArgument &args, std::vector> *elems, SortResult *res); protected: diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go index 7c75cf752b8..6715ed783a0 100644 --- a/tests/gocase/unit/sort/sort_test.go +++ b/tests/gocase/unit/sort/sort_test.go @@ -21,6 +21,7 @@ package sort import ( "context" + "fmt" "testing" "github.com/redis/go-redis/v9" @@ -58,11 +59,44 @@ func TestSortParser(t *testing.T) { _, err = rdb.Do(ctx, "Sort", "bad-case-key", "STORE").Result() require.EqualError(t, err, "ERR no more item to parse") + rdb.MSet(ctx, "rank_1", 1, "rank_2", "rank_3", 3, "rank_4", 4, "rank_5", 5) + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "BY", "dontsort", "BY", "rank_*").Result() + require.EqualError(t, err, "ERR don't use multiple BY parameters") + _, err = rdb.Do(ctx, "Sort_RO", "bad-case-key", "STORE", "store_ro_key").Result() require.EqualError(t, err, "ERR SORT_RO is read-only and does not support the STORE parameter") }) } +func TestSortLengthLimit(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("SORT Length Limit", func(t *testing.T) { + for i := 0; i <= 512; i++ { + rdb.LPush(ctx, "many-list-elems-key", i) + } + _, err := rdb.Sort(ctx, "many-list-elems-key", &redis.Sort{}).Result() + require.EqualError(t, err, "The number of elements to be sorted exceeds SORT_LENGTH_LIMIT = 512") + + for i := 0; i <= 512; i++ { + rdb.SAdd(ctx, "many-set-elems-key", i) + } + _, err = rdb.Sort(ctx, "many-set-elems-key", &redis.Sort{}).Result() + require.EqualError(t, err, "The number of elements to be sorted exceeds SORT_LENGTH_LIMIT = 512") + + for i := 0; i <= 512; i++ { + rdb.ZAdd(ctx, "many-zset-elems-key", redis.Z{Score: float64(i), Member: fmt.Sprintf("%d", i)}) + } + _, err = rdb.Sort(ctx, "many-zset-elems-key", &redis.Sort{}).Result() + require.EqualError(t, err, "The number of elements to be sorted exceeds SORT_LENGTH_LIMIT = 512") + }) +} + func TestListSort(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -514,6 +548,7 @@ func TestSetSort(t *testing.T) { byResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() require.NoError(t, err) require.Equal(t, []interface{}{"5", "6", "4"}, byResult) + }) t.Run("SORT STORE", func(t *testing.T) {