From f06c76460e3dded247c2430d886cba8b7f48c788 Mon Sep 17 00:00:00 2001 From: git-hulk Date: Thu, 18 Jan 2024 22:52:47 +0800 Subject: [PATCH 1/4] Rename ArrayOfSet to SetOfBulkString --- src/commands/cmd_set.cc | 12 ++++++------ src/server/redis_connection.cc | 2 +- src/server/redis_connection.h | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/commands/cmd_set.cc b/src/commands/cmd_set.cc index 079b2d267cc..64012f56ba7 100644 --- a/src/commands/cmd_set.cc +++ b/src/commands/cmd_set.cc @@ -93,7 +93,7 @@ class CommandSMembers : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->ArrayOfSet(members); + *output = conn->SetOfBulkString(members); return Status::OK(); } }; @@ -171,7 +171,7 @@ class CommandSPop : public Commander { } if (with_count_) { - *output = conn->ArrayOfSet(members); + *output = conn->SetOfBulkString(members); } else { if (members.size() > 0) { *output = redis::BulkString(members.front()); @@ -211,7 +211,7 @@ class CommandSRandMember : public Commander { if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->ArrayOfSet(members); + *output = conn->SetOfBulkString(members); return Status::OK(); } @@ -249,7 +249,7 @@ class CommandSDiff : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->ArrayOfSet(members); + *output = conn->SetOfBulkString(members); return Status::OK(); } }; @@ -269,7 +269,7 @@ class CommandSUnion : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->ArrayOfSet(members); + *output = conn->SetOfBulkString(members); return Status::OK(); } }; @@ -289,7 +289,7 @@ class CommandSInter : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->ArrayOfSet(members); + *output = conn->SetOfBulkString(members); return Status::OK(); } }; diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index 5c93a214874..a1ff738783e 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -163,7 +163,7 @@ std::string Connection::MultiBulkString(const std::vector &values, return result; } -std::string Connection::ArrayOfSet(const std::vector &elems) const { +std::string Connection::SetOfBulkString(const std::vector &elems) const { std::string result; result += SizeOfSet(elems.size()); for (const auto &elem : elems) { diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h index c2a0d8cb78e..c1dff4677ec 100644 --- a/src/server/redis_connection.h +++ b/src/server/redis_connection.h @@ -74,7 +74,7 @@ class Connection : public EvbufCallbackBase { std::string SizeOfSet(T len) const { return protocol_version_ == RESP::v3 ? "~" + std::to_string(len) + CRLF : MultiLen(len); } - std::string ArrayOfSet(const std::vector &elems) const; + std::string SetOfBulkString(const std::vector &elems) const; using UnsubscribeCallback = std::function; void SubscribeChannel(const std::string &channel); From 21e286d1ddc0f19e9b6abd67331b1a752f9ea8a2 Mon Sep 17 00:00:00 2001 From: git-hulk Date: Thu, 18 Jan 2024 23:52:44 +0800 Subject: [PATCH 2/4] Implement RESP3 map type --- src/commands/cmd_function.cc | 6 +- src/commands/cmd_hash.cc | 2 +- src/commands/cmd_server.cc | 13 +- src/commands/cmd_set.cc | 12 +- src/commands/cmd_stream.cc | 8 +- src/server/redis_connection.cc | 11 +- src/server/redis_connection.h | 7 +- src/storage/scripting.cc | 39 ++-- src/storage/scripting.h | 7 +- tests/gocase/unit/config/config_test.go | 13 ++ tests/gocase/unit/debug/debug_test.go | 2 + tests/gocase/unit/hello/hello_test.go | 13 +- tests/gocase/unit/protocol/protocol_test.go | 4 +- tests/gocase/unit/scripting/function_test.go | 220 +++++++++++++------ tests/gocase/unit/type/hash/hash_test.go | 24 ++ tests/gocase/unit/type/stream/stream_test.go | 14 +- 16 files changed, 284 insertions(+), 111 deletions(-) diff --git a/src/commands/cmd_function.cc b/src/commands/cmd_function.cc index 2123cc72ec7..2d7ce193e49 100644 --- a/src/commands/cmd_function.cc +++ b/src/commands/cmd_function.cc @@ -53,18 +53,18 @@ struct CommandFunction : Commander { with_code = true; } - return lua::FunctionList(srv, libname, with_code, output); + return lua::FunctionList(srv, conn, libname, with_code, output); } else if (parser.EatEqICase("listfunc")) { std::string funcname; if (parser.EatEqICase("funcname")) { funcname = GET_OR_RET(parser.TakeStr()); } - return lua::FunctionListFunc(srv, funcname, output); + return lua::FunctionListFunc(srv, conn, funcname, output); } else if (parser.EatEqICase("listlib")) { auto libname = GET_OR_RET(parser.TakeStr().Prefixed("expect a library name")); - return lua::FunctionListLib(srv, libname, output); + return lua::FunctionListLib(srv, conn, libname, output); } else if (parser.EatEqICase("delete")) { auto libname = GET_OR_RET(parser.TakeStr()); if (!lua::FunctionIsLibExist(conn, libname)) { diff --git a/src/commands/cmd_hash.cc b/src/commands/cmd_hash.cc index 8b9526f97fc..3eac5ab86ef 100644 --- a/src/commands/cmd_hash.cc +++ b/src/commands/cmd_hash.cc @@ -306,7 +306,7 @@ class CommandHGetAll : public Commander { kv_pairs.emplace_back(p.field); kv_pairs.emplace_back(p.value); } - *output = conn->MultiBulkString(kv_pairs, false); + *output = conn->MapOfBulkStrings(kv_pairs); return Status::OK(); } diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index 6ad11e4c806..2f608cfec32 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -252,7 +252,7 @@ class CommandConfig : public Commander { } else if (args_.size() == 3 && sub_command == "get") { std::vector values; config->Get(args_[2], &values); - *output = conn->MultiBulkString(values); + *output = conn->MapOfBulkStrings(values); } else if (args_.size() == 4 && sub_command == "set") { Status s = config->Set(srv, args_[2], args_[3]); if (!s.IsOK()) { @@ -617,6 +617,12 @@ class CommandDebug : public Commander { for (int i = 0; i < 3; i++) { *output += redis::Integer(i); } + } else if (protocol_type_ == "map") { + *output = conn->SizeOfMap(3); + for (int i = 0; i < 3; i++) { + *output += redis::Integer(i); + *output += conn->Bool(i == 1); + } } else if (protocol_type_ == "true") { *output = conn->Bool(true); } else if (protocol_type_ == "false") { @@ -783,7 +789,10 @@ class CommandHello final : public Commander { } else { output_list.push_back(redis::BulkString("standalone")); } - *output = redis::Array(output_list); + *output = conn->SizeOfMap(output_list.size() / 2); + for (const auto &item : output_list) { + *output += item; + } return Status::OK(); } }; diff --git a/src/commands/cmd_set.cc b/src/commands/cmd_set.cc index 64012f56ba7..ced252234b2 100644 --- a/src/commands/cmd_set.cc +++ b/src/commands/cmd_set.cc @@ -93,7 +93,7 @@ class CommandSMembers : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->SetOfBulkString(members); + *output = conn->SetOfBulkStrings(members); return Status::OK(); } }; @@ -171,7 +171,7 @@ class CommandSPop : public Commander { } if (with_count_) { - *output = conn->SetOfBulkString(members); + *output = conn->SetOfBulkStrings(members); } else { if (members.size() > 0) { *output = redis::BulkString(members.front()); @@ -211,7 +211,7 @@ class CommandSRandMember : public Commander { if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->SetOfBulkString(members); + *output = conn->SetOfBulkStrings(members); return Status::OK(); } @@ -249,7 +249,7 @@ class CommandSDiff : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->SetOfBulkString(members); + *output = conn->SetOfBulkStrings(members); return Status::OK(); } }; @@ -269,7 +269,7 @@ class CommandSUnion : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->SetOfBulkString(members); + *output = conn->SetOfBulkStrings(members); return Status::OK(); } }; @@ -289,7 +289,7 @@ class CommandSInter : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->SetOfBulkString(members); + *output = conn->SetOfBulkStrings(members); return Status::OK(); } }; diff --git a/src/commands/cmd_stream.cc b/src/commands/cmd_stream.cc index f82497fce4d..db871b07976 100644 --- a/src/commands/cmd_stream.cc +++ b/src/commands/cmd_stream.cc @@ -445,9 +445,9 @@ class CommandXInfo : public Commander { } if (!full_) { - output->append(redis::MultiLen(14)); + output->append(conn->SizeOfMap(7)); } else { - output->append(redis::MultiLen(12)); + output->append(conn->SizeOfMap(6)); } output->append(redis::BulkString("length")); output->append(redis::Integer(info.size)); @@ -503,7 +503,7 @@ class CommandXInfo : public Commander { output->append(redis::MultiLen(result_vector.size())); for (auto const &it : result_vector) { - output->append(redis::MultiLen(12)); + output->append(conn->SizeOfMap(6)); output->append(redis::BulkString("name")); output->append(redis::BulkString(it.first)); output->append(redis::BulkString("consumers")); @@ -545,7 +545,7 @@ class CommandXInfo : public Commander { output->append(redis::MultiLen(result_vector.size())); auto now = util::GetTimeStampMS(); for (auto const &it : result_vector) { - output->append(redis::MultiLen(8)); + output->append(conn->SizeOfMap(4)); output->append(redis::BulkString("name")); output->append(redis::BulkString(it.first)); output->append(redis::BulkString("pending")); diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index a1ff738783e..dddfe53cdff 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -163,7 +163,7 @@ std::string Connection::MultiBulkString(const std::vector &values, return result; } -std::string Connection::SetOfBulkString(const std::vector &elems) const { +std::string Connection::SetOfBulkStrings(const std::vector &elems) const { std::string result; result += SizeOfSet(elems.size()); for (const auto &elem : elems) { @@ -172,6 +172,15 @@ std::string Connection::SetOfBulkString(const std::vector &elems) c return result; } +std::string Connection::MapOfBulkStrings(const std::vector &elems) const { + std::string result; + result += SizeOfMap(elems.size() / 2); + for (const auto &elem : elems) { + result += BulkString(elem); + } + return result; +} + void Connection::SendFile(int fd) { // NOTE: we don't need to close the fd, the libevent will do that auto output = bufferevent_get_output(bev_); diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h index c1dff4677ec..1ab076e48ec 100644 --- a/src/server/redis_connection.h +++ b/src/server/redis_connection.h @@ -74,7 +74,12 @@ class Connection : public EvbufCallbackBase { std::string SizeOfSet(T len) const { return protocol_version_ == RESP::v3 ? "~" + std::to_string(len) + CRLF : MultiLen(len); } - std::string SetOfBulkString(const std::vector &elems) const; + std::string SetOfBulkStrings(const std::vector &elems) const; + template , int> = 0> + std::string SizeOfMap(T len) const { + return protocol_version_ == RESP::v3 ? "%" + std::to_string(len) + CRLF : MultiLen(len * 2); + } + std::string MapOfBulkStrings(const std::vector &elems) const; using UnsubscribeCallback = std::function; void SubscribeChannel(const std::string &channel); diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index a6e73fa040d..9e6cedf4b2d 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -425,7 +425,8 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: } // list all library names and their code (enabled via `with_code`) -Status FunctionList(Server *srv, const std::string &libname, bool with_code, std::string *output) { +Status FunctionList(Server *srv, const redis::Connection *conn, const std::string &libname, bool with_code, + std::string *output) { std::string start_key = engine::kLuaLibCodePrefix + libname; std::string end_key = start_key; end_key.back()++; @@ -445,12 +446,13 @@ Status FunctionList(Server *srv, const std::string &libname, bool with_code, std result.emplace_back(lib.ToString(), iter->value().ToString()); } - output->append(redis::MultiLen(result.size() * (with_code ? 4 : 2))); + output->append(redis::MultiLen(result.size())); for (const auto &[lib, code] : result) { - output->append(redis::SimpleString("library_name")); - output->append(redis::SimpleString(lib)); + output->append(conn->SizeOfMap(with_code ? 2 : 1)); + output->append(redis::BulkString("library_name")); + output->append(redis::BulkString(lib)); if (with_code) { - output->append(redis::SimpleString("library_code")); + output->append(redis::BulkString("library_code")); output->append(redis::BulkString(code)); } } @@ -460,7 +462,7 @@ Status FunctionList(Server *srv, const std::string &libname, bool with_code, std // extension to Redis Function // list all function names and their corresponding library names -Status FunctionListFunc(Server *srv, const std::string &funcname, std::string *output) { +Status FunctionListFunc(Server *srv, const redis::Connection *conn, const std::string &funcname, std::string *output) { std::string start_key = engine::kLuaFuncLibPrefix + funcname; std::string end_key = start_key; end_key.back()++; @@ -480,12 +482,13 @@ Status FunctionListFunc(Server *srv, const std::string &funcname, std::string *o result.emplace_back(func.ToString(), iter->value().ToString()); } - output->append(redis::MultiLen(result.size() * 4)); + output->append(redis::MultiLen(result.size())); for (const auto &[func, lib] : result) { - output->append(redis::SimpleString("function_name")); - output->append(redis::SimpleString(func)); - output->append(redis::SimpleString("from_library")); - output->append(redis::SimpleString(lib)); + output->append(conn->SizeOfMap(2)); + output->append(redis::BulkString("function_name")); + output->append(redis::BulkString(func)); + output->append(redis::BulkString("from_library")); + output->append(redis::BulkString(lib)); } return Status::OK(); @@ -495,7 +498,7 @@ Status FunctionListFunc(Server *srv, const std::string &funcname, std::string *o // list detailed informantion of a specific library // NOTE: it is required to load the library to lua runtime before listing (calling this function) // i.e. it will output nothing if the library is only in storage but not loaded -Status FunctionListLib(Server *srv, const std::string &libname, std::string *output) { +Status FunctionListLib(Server *srv, const redis::Connection *conn, const std::string &libname, std::string *output) { auto lua = srv->Lua(); lua_getglobal(lua, REDIS_FUNCTION_LIBRARIES); @@ -511,11 +514,11 @@ Status FunctionListLib(Server *srv, const std::string &libname, std::string *out return {Status::NotOK, "The library is not found or not loaded from storage"}; } - output->append(redis::MultiLen(6)); - output->append(redis::SimpleString("library_name")); - output->append(redis::SimpleString(libname)); - output->append(redis::SimpleString("engine")); - output->append(redis::SimpleString("lua")); + output->append(conn->SizeOfMap(3)); + output->append(redis::BulkString("library_name")); + output->append(redis::BulkString(libname)); + output->append(redis::BulkString("engine")); + output->append(redis::BulkString("lua")); auto count = lua_objlen(lua, -1); output->append(redis::SimpleString("functions")); @@ -524,7 +527,7 @@ Status FunctionListLib(Server *srv, const std::string &libname, std::string *out for (size_t i = 1; i <= count; ++i) { lua_rawgeti(lua, -1, static_cast(i)); auto func = lua_tostring(lua, -1); - output->append(redis::SimpleString(func)); + output->append(redis::BulkString(func)); lua_pop(lua, 1); } diff --git a/src/storage/scripting.h b/src/storage/scripting.h index 0d9ce46c316..a2c90b90ae0 100644 --- a/src/storage/scripting.h +++ b/src/storage/scripting.h @@ -66,9 +66,10 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee std::string *lib_name, bool read_only = false); Status FunctionCall(redis::Connection *conn, const std::string &name, const std::vector &keys, const std::vector &argv, std::string *output, bool read_only = false); -Status FunctionList(Server *srv, const std::string &libname, bool with_code, std::string *output); -Status FunctionListFunc(Server *srv, const std::string &funcname, std::string *output); -Status FunctionListLib(Server *srv, const std::string &libname, std::string *output); +Status FunctionList(Server *srv, const redis::Connection *conn, const std::string &libname, bool with_code, + std::string *output); +Status FunctionListFunc(Server *srv, const redis::Connection *conn, const std::string &funcname, std::string *output); +Status FunctionListLib(Server *srv, const redis::Connection *conn, const std::string &libname, std::string *output); Status FunctionDelete(Server *srv, const std::string &name); bool FunctionIsLibExist(redis::Connection *conn, const std::string &libname, bool need_check_storage = true, bool read_only = false); diff --git a/tests/gocase/unit/config/config_test.go b/tests/gocase/unit/config/config_test.go index dd880367847..d2643c2cc71 100644 --- a/tests/gocase/unit/config/config_test.go +++ b/tests/gocase/unit/config/config_test.go @@ -133,6 +133,19 @@ func TestConfigSetCompression(t *testing.T) { require.ErrorContains(t, rdb.ConfigSet(ctx, configKey, "unsupported").Err(), "invalid enum option") } +func TestConfigGetRESP3(t *testing.T) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": "yes", + }) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + val := rdb.ConfigGet(ctx, "resp3-enabled").Val() + require.EqualValues(t, "yes", val["resp3-enabled"]) +} + func TestStartWithoutConfigurationFile(t *testing.T) { srv := util.StartServerWithCLIOptions(t, false, map[string]string{}, []string{}) defer srv.Close() diff --git a/tests/gocase/unit/debug/debug_test.go b/tests/gocase/unit/debug/debug_test.go index a8a158c3fa6..faa2981141d 100644 --- a/tests/gocase/unit/debug/debug_test.go +++ b/tests/gocase/unit/debug/debug_test.go @@ -45,6 +45,7 @@ func TestDebugProtocolV2(t *testing.T) { "integer": int64(12345), "array": []interface{}{int64(0), int64(1), int64(2)}, "set": []interface{}{int64(0), int64(1), int64(2)}, + "map": []interface{}{int64(0), int64(0), int64(1), int64(1), int64(2), int64(0)}, "true": int64(1), "false": int64(0), } @@ -87,6 +88,7 @@ func TestDebugProtocolV3(t *testing.T) { "integer": int64(12345), "array": []interface{}{int64(0), int64(1), int64(2)}, "set": []interface{}{int64(0), int64(1), int64(2)}, + "map": map[interface{}]interface{}{int64(0): false, int64(1): true, int64(2): false}, "true": true, "false": false, } diff --git a/tests/gocase/unit/hello/hello_test.go b/tests/gocase/unit/hello/hello_test.go index d965b29c28e..36296b05f09 100644 --- a/tests/gocase/unit/hello/hello_test.go +++ b/tests/gocase/unit/hello/hello_test.go @@ -86,15 +86,16 @@ func TestEnableRESP3(t *testing.T) { rdb := srv.NewClient() defer func() { require.NoError(t, rdb.Close()) }() - r := rdb.Do(ctx, "HELLO", "2") - rList := r.Val().([]interface{}) + r, err := rdb.Do(ctx, "HELLO", "2").Result() + require.NoError(t, err) + rList := r.([]interface{}) require.EqualValues(t, rList[2], "proto") require.EqualValues(t, rList[3], 2) - r = rdb.Do(ctx, "HELLO", "3") - rList = r.Val().([]interface{}) - require.EqualValues(t, rList[2], "proto") - require.EqualValues(t, rList[3], 3) + r, err = rdb.Do(ctx, "HELLO", "3").Result() + require.NoError(t, err) + rMap := r.(map[interface{}]interface{}) + require.EqualValues(t, rMap["proto"], 3) } func TestHelloWithAuth(t *testing.T) { diff --git a/tests/gocase/unit/protocol/protocol_test.go b/tests/gocase/unit/protocol/protocol_test.go index d7f57145423..7919bf77f86 100644 --- a/tests/gocase/unit/protocol/protocol_test.go +++ b/tests/gocase/unit/protocol/protocol_test.go @@ -155,6 +155,7 @@ func TestProtocolRESP2(t *testing.T) { "integer": {":12345"}, "array": {"*3", ":0", ":1", ":2"}, "set": {"*3", ":0", ":1", ":2"}, + "map": {"*6", ":0", ":0", ":1", ":1", ":2", ":0"}, "true": {":1"}, "false": {":0"}, "null": {"$-1"}, @@ -198,7 +199,7 @@ func TestProtocolRESP3(t *testing.T) { t.Run("debug protocol string", func(t *testing.T) { require.NoError(t, c.WriteArgs("HELLO", "3")) - values := []string{"*6", "$6", "server", "$5", "redis", "$5", "proto", ":3", "$4", "mode", "$10", "standalone"} + values := []string{"%3", "$6", "server", "$5", "redis", "$5", "proto", ":3", "$4", "mode", "$10", "standalone"} for _, line := range values { c.MustRead(t, line) } @@ -208,6 +209,7 @@ func TestProtocolRESP3(t *testing.T) { "integer": {":12345"}, "array": {"*3", ":0", ":1", ":2"}, "set": {"~3", ":0", ":1", ":2"}, + "map": {"%3", ":0", "#f", ":1", "#t", ":2", "#f"}, "true": {"#t"}, "false": {"#f"}, "null": {"_"}, diff --git a/tests/gocase/unit/scripting/function_test.go b/tests/gocase/unit/scripting/function_test.go index 3e262db218a..d0c12ec157a 100644 --- a/tests/gocase/unit/scripting/function_test.go +++ b/tests/gocase/unit/scripting/function_test.go @@ -25,6 +25,8 @@ import ( "strings" "testing" + "github.com/redis/go-redis/v9" + "github.com/apache/kvrocks/tests/gocase/util" "github.com/stretchr/testify/require" ) @@ -38,8 +40,74 @@ var luaMylib2 string //go:embed mylib3.lua var luaMylib3 string -func TestFunction(t *testing.T) { - srv := util.StartServer(t, map[string]string{}) +type ListFuncResult struct { + Name string + Library string +} + +func decodeListFuncResult(t *testing.T, v interface{}) ListFuncResult { + switch res := v.(type) { + case []interface{}: + require.EqualValues(t, 4, len(res)) + require.EqualValues(t, "function_name", res[0]) + require.EqualValues(t, "from_library", res[2]) + return ListFuncResult{ + Name: res[1].(string), + Library: res[3].(string), + } + case map[interface{}]interface{}: + require.EqualValues(t, 2, len(res)) + return ListFuncResult{ + Name: res["function_name"].(string), + Library: res["from_library"].(string), + } + } + require.Fail(t, "unexpected type") + return ListFuncResult{} +} + +type ListLibResult struct { + Name string + Engine string + Functions []interface{} +} + +func decodeListLibResult(t *testing.T, v interface{}) ListLibResult { + switch res := v.(type) { + case []interface{}: + require.EqualValues(t, 6, len(res)) + require.EqualValues(t, "library_name", res[0]) + require.EqualValues(t, "engine", res[2]) + require.EqualValues(t, "functions", res[4]) + return ListLibResult{ + Name: res[1].(string), + Engine: res[3].(string), + Functions: res[5].([]interface{}), + } + case map[interface{}]interface{}: + require.EqualValues(t, 3, len(res)) + return ListLibResult{ + Name: res["library_name"].(string), + Engine: res["engine"].(string), + Functions: res["functions"].([]interface{}), + } + } + require.Fail(t, "unexpected type") + return ListLibResult{} +} + +func TestFunctionsWithRESP3(t *testing.T) { + testFunctions(t, "yes") +} + +func TestFunctionsWithoutRESP2(t *testing.T) { + testFunctions(t, "no") +} + +var testFunctions = func(t *testing.T, enabledRESP3 string) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": enabledRESP3, + }) defer srv.Close() ctx := context.Background() @@ -65,17 +133,22 @@ func TestFunction(t *testing.T) { }) t.Run("FUNCTION LIST and FUNCTION LISTFUNC mylib1", func(t *testing.T) { - list := rdb.Do(ctx, "FUNCTION", "LIST", "WITHCODE").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[3].(string), luaMylib1) - require.Equal(t, len(list), 4) - - list = rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) - require.Equal(t, list[1].(string), "add") - require.Equal(t, list[3].(string), "mylib1") - require.Equal(t, list[5].(string), "inc") - require.Equal(t, list[7].(string), "mylib1") - require.Equal(t, len(list), 8) + libraries, err := rdb.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + }).Result() + require.NoError(t, err) + require.EqualValues(t, 1, len(libraries)) + require.Equal(t, "mylib1", libraries[0].Name) + require.Equal(t, luaMylib1, libraries[0].Code) + + list := rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) + require.EqualValues(t, 2, len(list)) + f1 := decodeListFuncResult(t, list[0]) + require.Equal(t, "add", f1.Name) + require.Equal(t, "mylib1", f1.Library) + f2 := decodeListFuncResult(t, list[1]) + require.Equal(t, "inc", f2.Name) + require.Equal(t, "mylib1", f2.Library) }) t.Run("FUNCTION LOAD and FCALL mylib2", func(t *testing.T) { @@ -87,23 +160,25 @@ func TestFunction(t *testing.T) { }) t.Run("FUNCTION LIST and FUNCTION LISTFUNC mylib2", func(t *testing.T) { - list := rdb.Do(ctx, "FUNCTION", "LIST", "WITHCODE").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[3].(string), luaMylib1) - require.Equal(t, list[5].(string), "mylib2") - require.Equal(t, list[7].(string), luaMylib2) - require.Equal(t, len(list), 8) - - list = rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) - require.Equal(t, list[1].(string), "add") - require.Equal(t, list[3].(string), "mylib1") - require.Equal(t, list[5].(string), "hello") - require.Equal(t, list[7].(string), "mylib2") - require.Equal(t, list[9].(string), "inc") - require.Equal(t, list[11].(string), "mylib1") - require.Equal(t, list[13].(string), "reverse") - require.Equal(t, list[15].(string), "mylib2") - require.Equal(t, len(list), 16) + libraries, err := rdb.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + }).Result() + require.NoError(t, err) + require.EqualValues(t, 2, len(libraries)) + + list := rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) + expected := []ListFuncResult{ + {Name: "add", Library: "mylib1"}, + {Name: "hello", Library: "mylib2"}, + {Name: "inc", Library: "mylib1"}, + {Name: "reverse", Library: "mylib2"}, + } + require.EqualValues(t, len(expected), len(list)) + for i, f := range expected { + actual := decodeListFuncResult(t, list[i]) + require.Equal(t, f.Name, actual.Name) + require.Equal(t, f.Library, actual.Library) + } }) t.Run("FUNCTION DELETE", func(t *testing.T) { @@ -113,17 +188,24 @@ func TestFunction(t *testing.T) { util.ErrorRegexp(t, rdb.Do(ctx, "FCALL", "reverse", 0, "x").Err(), ".*No such function name.*") require.Equal(t, rdb.Do(ctx, "FCALL", "inc", 0, 3).Val(), int64(4)) - list := rdb.Do(ctx, "FUNCTION", "LIST", "WITHCODE").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[3].(string), luaMylib1) - require.Equal(t, len(list), 4) - - list = rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) - require.Equal(t, list[1].(string), "add") - require.Equal(t, list[3].(string), "mylib1") - require.Equal(t, list[5].(string), "inc") - require.Equal(t, list[7].(string), "mylib1") - require.Equal(t, len(list), 8) + libraries, err := rdb.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + }).Result() + require.NoError(t, err) + require.EqualValues(t, 1, len(libraries)) + require.Equal(t, "mylib1", libraries[0].Name) + + list := rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) + expected := []ListFuncResult{ + {Name: "add", Library: "mylib1"}, + {Name: "inc", Library: "mylib1"}, + } + require.EqualValues(t, len(expected), len(list)) + for i, f := range expected { + actual := decodeListFuncResult(t, list[i]) + require.Equal(t, f.Name, actual.Name) + require.Equal(t, f.Library, actual.Library) + } }) t.Run("FUNCTION LOAD REPLACE", func(t *testing.T) { @@ -135,17 +217,24 @@ func TestFunction(t *testing.T) { require.Equal(t, rdb.Do(ctx, "FCALL", "reverse", 0, "xyz").Val(), "zyx") util.ErrorRegexp(t, rdb.Do(ctx, "FCALL", "inc", 0, 1).Err(), ".*No such function name.*") - list := rdb.Do(ctx, "FUNCTION", "LIST", "WITHCODE").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[3].(string), code) - require.Equal(t, len(list), 4) - - list = rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) - require.Equal(t, list[1].(string), "hello") - require.Equal(t, list[3].(string), "mylib1") - require.Equal(t, list[5].(string), "reverse") - require.Equal(t, list[7].(string), "mylib1") - require.Equal(t, len(list), 8) + libraries, err := rdb.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + }).Result() + require.NoError(t, err) + require.EqualValues(t, 1, len(libraries)) + require.Equal(t, "mylib1", libraries[0].Name) + + list := rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) + expected := []ListFuncResult{ + {Name: "hello", Library: "mylib1"}, + {Name: "reverse", Library: "mylib1"}, + } + require.EqualValues(t, len(expected), len(list)) + for i, f := range expected { + actual := decodeListFuncResult(t, list[i]) + require.Equal(t, f.Name, actual.Name) + require.Equal(t, f.Library, actual.Library) + } }) t.Run("FCALL_RO", func(t *testing.T) { @@ -167,19 +256,24 @@ func TestFunction(t *testing.T) { require.Equal(t, rdb.Do(ctx, "FCALL", "myget", 1, "x").Val(), "2") require.Equal(t, rdb.Do(ctx, "FCALL", "hello", 0, "xxx").Val(), "Hello, xxx!") - list := rdb.Do(ctx, "FUNCTION", "LIST").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[3].(string), "mylib3") - require.Equal(t, len(list), 4) + libraries, err := rdb.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + }).Result() + require.NoError(t, err) + require.EqualValues(t, 2, len(libraries)) + require.Equal(t, libraries[0].Name, "mylib1") + require.Equal(t, libraries[1].Name, "mylib3") }) t.Run("FUNCTION LISTLIB", func(t *testing.T) { - list := rdb.Do(ctx, "FUNCTION", "LISTLIB", "mylib1").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[5].([]interface{}), []interface{}{"hello", "reverse"}) - - list = rdb.Do(ctx, "FUNCTION", "LISTLIB", "mylib3").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib3") - require.Equal(t, list[5].([]interface{}), []interface{}{"myget", "myset"}) + r := rdb.Do(ctx, "FUNCTION", "LISTLIB", "mylib1").Val() + require.EqualValues(t, ListLibResult{ + Name: "mylib1", Engine: "lua", Functions: []interface{}{"hello", "reverse"}, + }, decodeListLibResult(t, r)) + + r = rdb.Do(ctx, "FUNCTION", "LISTLIB", "mylib3").Val() + require.EqualValues(t, ListLibResult{ + Name: "mylib3", Engine: "lua", Functions: []interface{}{"myget", "myset"}, + }, decodeListLibResult(t, r)) }) } diff --git a/tests/gocase/unit/type/hash/hash_test.go b/tests/gocase/unit/type/hash/hash_test.go index bf93d268860..e6c8bebac31 100644 --- a/tests/gocase/unit/type/hash/hash_test.go +++ b/tests/gocase/unit/type/hash/hash_test.go @@ -835,6 +835,30 @@ func TestHash(t *testing.T) { } } +func TestHGetAllWithRESP3(t *testing.T) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": "yes", + }) + defer srv.Close() + + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + ctx := context.Background() + + testKey := "test-hash-1" + require.NoError(t, rdb.Del(ctx, testKey).Err()) + require.NoError(t, rdb.HSet(ctx, testKey, "key1", "value1", "key2", "value2", "key3", "value3").Err()) + result, err := rdb.HGetAll(ctx, testKey).Result() + require.NoError(t, err) + require.Len(t, result, 3) + require.EqualValues(t, map[string]string{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + }, result) +} + func TestHashWithAsyncIOEnabled(t *testing.T) { srv := util.StartServer(t, map[string]string{ "rocksdb.read_options.async_io": "yes", diff --git a/tests/gocase/unit/type/stream/stream_test.go b/tests/gocase/unit/type/stream/stream_test.go index d3a1f8d273a..7dee10b6b3e 100644 --- a/tests/gocase/unit/type/stream/stream_test.go +++ b/tests/gocase/unit/type/stream/stream_test.go @@ -34,8 +34,18 @@ import ( "github.com/stretchr/testify/require" ) -func TestStream(t *testing.T) { - srv := util.StartServer(t, map[string]string{}) +func TestStreamWithRESP2(t *testing.T) { + streamTests(t, "no") +} + +func TestStreamWithRESP3(t *testing.T) { + streamTests(t, "yes") +} + +var streamTests = func(t *testing.T, enabledRESP3 string) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": enabledRESP3, + }) defer srv.Close() ctx := context.Background() rdb := srv.NewClient() From 2dff54f5ba01558e92c978f20333dce096e70af7 Mon Sep 17 00:00:00 2001 From: git-hulk Date: Sat, 20 Jan 2024 23:23:16 +0800 Subject: [PATCH 3/4] Add DCHECK inside MapOfBulkStrings --- src/server/redis_connection.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index dddfe53cdff..2e5a4a4492c 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -173,6 +173,8 @@ std::string Connection::SetOfBulkStrings(const std::vector &elems) } std::string Connection::MapOfBulkStrings(const std::vector &elems) const { + CHECK(elems.size() % 2 == 0); + std::string result; result += SizeOfMap(elems.size() / 2); for (const auto &elem : elems) { From ac79c724cdd99c2e516e68132c07c3d18848d673 Mon Sep 17 00:00:00 2001 From: git-hulk Date: Sun, 21 Jan 2024 15:26:16 +0800 Subject: [PATCH 4/4] Rename functions --- src/commands/cmd_server.cc | 6 +++--- src/commands/cmd_stream.cc | 8 ++++---- src/server/redis_connection.cc | 4 ++-- src/server/redis_connection.h | 4 ++-- src/storage/scripting.cc | 6 +++--- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index 2f608cfec32..d34c55956c8 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -613,12 +613,12 @@ class CommandDebug : public Commander { *output += redis::Integer(i); } } else if (protocol_type_ == "set") { - *output = conn->SizeOfSet(3); + *output = conn->HeaderOfSet(3); for (int i = 0; i < 3; i++) { *output += redis::Integer(i); } } else if (protocol_type_ == "map") { - *output = conn->SizeOfMap(3); + *output = conn->HeaderOfMap(3); for (int i = 0; i < 3; i++) { *output += redis::Integer(i); *output += conn->Bool(i == 1); @@ -789,7 +789,7 @@ class CommandHello final : public Commander { } else { output_list.push_back(redis::BulkString("standalone")); } - *output = conn->SizeOfMap(output_list.size() / 2); + *output = conn->HeaderOfMap(output_list.size() / 2); for (const auto &item : output_list) { *output += item; } diff --git a/src/commands/cmd_stream.cc b/src/commands/cmd_stream.cc index db871b07976..7ba408859c1 100644 --- a/src/commands/cmd_stream.cc +++ b/src/commands/cmd_stream.cc @@ -445,9 +445,9 @@ class CommandXInfo : public Commander { } if (!full_) { - output->append(conn->SizeOfMap(7)); + output->append(conn->HeaderOfMap(7)); } else { - output->append(conn->SizeOfMap(6)); + output->append(conn->HeaderOfMap(6)); } output->append(redis::BulkString("length")); output->append(redis::Integer(info.size)); @@ -503,7 +503,7 @@ class CommandXInfo : public Commander { output->append(redis::MultiLen(result_vector.size())); for (auto const &it : result_vector) { - output->append(conn->SizeOfMap(6)); + output->append(conn->HeaderOfMap(6)); output->append(redis::BulkString("name")); output->append(redis::BulkString(it.first)); output->append(redis::BulkString("consumers")); @@ -545,7 +545,7 @@ class CommandXInfo : public Commander { output->append(redis::MultiLen(result_vector.size())); auto now = util::GetTimeStampMS(); for (auto const &it : result_vector) { - output->append(conn->SizeOfMap(4)); + output->append(conn->HeaderOfMap(4)); output->append(redis::BulkString("name")); output->append(redis::BulkString(it.first)); output->append(redis::BulkString("pending")); diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index 2e5a4a4492c..83e4a1980e5 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -165,7 +165,7 @@ std::string Connection::MultiBulkString(const std::vector &values, std::string Connection::SetOfBulkStrings(const std::vector &elems) const { std::string result; - result += SizeOfSet(elems.size()); + result += HeaderOfSet(elems.size()); for (const auto &elem : elems) { result += BulkString(elem); } @@ -176,7 +176,7 @@ std::string Connection::MapOfBulkStrings(const std::vector &elems) CHECK(elems.size() % 2 == 0); std::string result; - result += SizeOfMap(elems.size() / 2); + result += HeaderOfMap(elems.size() / 2); for (const auto &elem : elems) { result += BulkString(elem); } diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h index 1ab076e48ec..7f622fe59de 100644 --- a/src/server/redis_connection.h +++ b/src/server/redis_connection.h @@ -71,12 +71,12 @@ class Connection : public EvbufCallbackBase { std::string MultiBulkString(const std::vector &values, const std::vector &statuses) const; template , int> = 0> - std::string SizeOfSet(T len) const { + std::string HeaderOfSet(T len) const { return protocol_version_ == RESP::v3 ? "~" + std::to_string(len) + CRLF : MultiLen(len); } std::string SetOfBulkStrings(const std::vector &elems) const; template , int> = 0> - std::string SizeOfMap(T len) const { + std::string HeaderOfMap(T len) const { return protocol_version_ == RESP::v3 ? "%" + std::to_string(len) + CRLF : MultiLen(len * 2); } std::string MapOfBulkStrings(const std::vector &elems) const; diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 9e6cedf4b2d..4d94cd4ac6f 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -448,7 +448,7 @@ Status FunctionList(Server *srv, const redis::Connection *conn, const std::strin output->append(redis::MultiLen(result.size())); for (const auto &[lib, code] : result) { - output->append(conn->SizeOfMap(with_code ? 2 : 1)); + output->append(conn->HeaderOfMap(with_code ? 2 : 1)); output->append(redis::BulkString("library_name")); output->append(redis::BulkString(lib)); if (with_code) { @@ -484,7 +484,7 @@ Status FunctionListFunc(Server *srv, const redis::Connection *conn, const std::s output->append(redis::MultiLen(result.size())); for (const auto &[func, lib] : result) { - output->append(conn->SizeOfMap(2)); + output->append(conn->HeaderOfMap(2)); output->append(redis::BulkString("function_name")); output->append(redis::BulkString(func)); output->append(redis::BulkString("from_library")); @@ -514,7 +514,7 @@ Status FunctionListLib(Server *srv, const redis::Connection *conn, const std::st return {Status::NotOK, "The library is not found or not loaded from storage"}; } - output->append(conn->SizeOfMap(3)); + output->append(conn->HeaderOfMap(3)); output->append(redis::BulkString("library_name")); output->append(redis::BulkString(libname)); output->append(redis::BulkString("engine"));