Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support of RESP3 in Lua #2119

Merged
merged 6 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 151 additions & 9 deletions src/storage/scripting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,12 @@ const char *RedisProtocolToLuaType(lua_State *lua, const char *reply) {
case ',':
p = RedisProtocolToLuaTypeDouble(lua, reply);
break;
case '(':
p = RedisProtocolToLuaTypeBigNumber(lua, reply);
break;
case '=':
p = RedisProtocolToLuaTypeVerbatimString(lua, reply);
break;
}
return p;
}
Expand Down Expand Up @@ -1009,13 +1015,36 @@ const char *RedisProtocolToLuaTypeAggregate(lua_State *lua, const char *reply, i
lua_pushboolean(lua, 0);
return p;
}
lua_newtable(lua);
for (j = 0; j < mbulklen; j++) {
lua_pushnumber(lua, j + 1);
p = RedisProtocolToLuaType(lua, p);
if (atype == '*') {
lua_newtable(lua);
for (j = 0; j < mbulklen; j++) {
lua_pushnumber(lua, j + 1);
p = RedisProtocolToLuaType(lua, p);
lua_settable(lua, -3);
}
return p;
}

CHECK(atype == '%' || atype == '~');
if (atype == '%' || atype == '~') {
lua_newtable(lua);
lua_pushstring(lua, atype == '%' ? "map" : "set");
lua_newtable(lua);
for (j = 0; j < mbulklen; j++) {
p = RedisProtocolToLuaType(lua, p);
if (atype == '%') { // map
p = RedisProtocolToLuaType(lua, p);
} else { // set
lua_pushboolean(lua, 1);
}
lua_settable(lua, -3);
}
lua_settable(lua, -3);
return p;
}
return p;

// Unreachable, return the original position if it did reach here.
return reply;
}

const char *RedisProtocolToLuaTypeNull(lua_State *lua, const char *reply) {
Expand Down Expand Up @@ -1051,6 +1080,36 @@ const char *RedisProtocolToLuaTypeDouble(lua_State *lua, const char *reply) {
return p + 2;
}

const char *RedisProtocolToLuaTypeBigNumber(lua_State *lua, const char *reply) {
const char *p = strchr(reply + 1, '\r');
lua_newtable(lua);
lua_pushstring(lua, "big_number");
lua_pushlstring(lua, reply + 1, p - reply - 1);
lua_settable(lua, -3);
return p + 2;
}

const char *RedisProtocolToLuaTypeVerbatimString(lua_State *lua, const char *reply) {
const char *p = strchr(reply + 1, '\r');
int64_t bulklen = ParseInt<int64_t>(std::string(reply + 1, p - reply - 1), 10).ValueOr(0);
p += 2; // skip \r\n

lua_newtable(lua);
lua_pushstring(lua, "verbatim_string");

lua_newtable(lua);
lua_pushstring(lua, "string");
lua_pushlstring(lua, p + 4, bulklen - 4);
lua_settable(lua, -3);

lua_pushstring(lua, "format");
lua_pushlstring(lua, p, 3);
lua_settable(lua, -3);

lua_settable(lua, -3);
return p + bulklen + 2;
}

/* This function is used in order to push an error on the Lua stack in the
* format used by redis.pcall to return errors, which is a lua table
* with a single "err" field set to the error string. Note that this
Expand Down Expand Up @@ -1094,7 +1153,7 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) {

/* Handle error reply. */
lua_pushstring(lua, "err");
lua_gettable(lua, -2);
lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TSTRING) {
output = redis::Error(lua_tostring(lua, -1));
Expand All @@ -1105,7 +1164,7 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) {

/* Handle status reply. */
lua_pushstring(lua, "ok");
lua_gettable(lua, -2);
lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TSTRING) {
obj_s = lua_tolstring(lua, -1, &obj_len);
Expand All @@ -1115,9 +1174,20 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) {
}
lua_pop(lua, 1); /* Discard the 'ok' field value we pushed */

/* Handle double reply. */
lua_pushstring(lua, "double");
lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TNUMBER) {
output = conn->Double(lua_tonumber(lua, -1));
lua_pop(lua, 1);
return output;
}
lua_pop(lua, 1); /* Discard the 'double' field value we pushed */

/* Handle big number reply. */
lua_pushstring(lua, "big_number");
lua_gettable(lua, -2);
lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TSTRING) {
obj_s = lua_tolstring(lua, -1, &obj_len);
Expand All @@ -1127,10 +1197,82 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) {
}
lua_pop(lua, 1); /* Discard the 'big_number' field value we pushed */

/* Handle verbatim reply. */
lua_pushstring(lua, "verbatim_string");
lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TTABLE) {
lua_pushstring(lua, "format");
lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TSTRING) {
const char *format = lua_tostring(lua, -1);
lua_pushstring(lua, "string");
lua_rawget(lua, -3);
t = lua_type(lua, -1);
if (t == LUA_TSTRING) {
obj_s = lua_tolstring(lua, -1, &obj_len);
output = conn->VerbatimString(std::string(format), std::string(obj_s, obj_len));
lua_pop(lua, 4);
return output;
}
// discard 'string'
lua_pop(lua, 1);
}
// discard 'format'
lua_pop(lua, 1);
}
lua_pop(lua, 1); /* Discard the 'verbatim_string' field value we pushed */

/* Handle map reply. */
lua_pushstring(lua, "map");
lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TTABLE) {
int map_len = 0;
std::string map_output;
lua_pushnil(lua);
while (lua_next(lua, -2)) {
lua_pushvalue(lua, -2);
// return key
map_output += ReplyToRedisReply(conn, lua);
lua_pop(lua, 1);
// return value
map_output += ReplyToRedisReply(conn, lua);
lua_pop(lua, 1);
map_len++;
}
output = conn->HeaderOfMap(map_len) + std::move(map_output);
lua_pop(lua, 1);
return output;
}
lua_pop(lua, 1); /* Discard the 'map' field value we pushed */

/* Handle set reply. */
lua_pushstring(lua, "set");
lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TTABLE) {
int set_len = 0;
std::string set_output;
lua_pushnil(lua);
while (lua_next(lua, -2)) {
lua_pop(lua, 1);
lua_pushvalue(lua, -1);
set_output += ReplyToRedisReply(conn, lua);
lua_pop(lua, 1);
set_len++;
}
output = conn->HeaderOfSet(set_len) + std::move(set_output);
lua_pop(lua, 1);
return output;
}
lua_pop(lua, 1); /* Discard the 'set' field value we pushed */

j = 1, mbulklen = 0;
while (true) {
lua_pushnumber(lua, j++);
lua_gettable(lua, -2);
lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TNIL) {
lua_pop(lua, 1);
Expand Down
2 changes: 2 additions & 0 deletions src/storage/scripting.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ const char *RedisProtocolToLuaTypeAggregate(lua_State *lua, const char *reply, i
const char *RedisProtocolToLuaTypeNull(lua_State *lua, const char *reply);
const char *RedisProtocolToLuaTypeBool(lua_State *lua, const char *reply, int tf);
const char *RedisProtocolToLuaTypeDouble(lua_State *lua, const char *reply);
const char *RedisProtocolToLuaTypeBigNumber(lua_State *lua, const char *reply);
const char *RedisProtocolToLuaTypeVerbatimString(lua_State *lua, const char *reply);

std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua);

Expand Down
64 changes: 64 additions & 0 deletions tests/gocase/unit/scripting/scripting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ package scripting
import (
"context"
"fmt"
"math/big"
"testing"

"github.com/apache/kvrocks/tests/gocase/util"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)

func TestScripting(t *testing.T) {
Expand Down Expand Up @@ -512,3 +514,65 @@ func TestScriptingMasterSlave(t *testing.T) {
require.Equal(t, []bool{false}, slaveClient.ScriptExists(ctx, sha).Val())
})
}

func TestScriptingWithRESP3(t *testing.T) {
srv := util.StartServer(t, map[string]string{
"resp3-enabled": "yes",
})
defer srv.Close()

rdb := srv.NewClient()
defer func() {
require.NoError(t, rdb.Close())
}()

ctx := context.Background()
t.Run("EVAL - Redis protocol type map conversion", func(t *testing.T) {
rdb.HSet(ctx, "myhash", "f1", "v1")
rdb.HSet(ctx, "myhash", "f2", "v2")
val, err := rdb.Eval(ctx, `return redis.call('hgetall', KEYS[1])`, []string{"myhash"}).Result()
require.NoError(t, err)
require.Equal(t, map[interface{}]interface{}{"f1": "v1", "f2": "v2"}, val)
})

t.Run("EVAL - Redis protocol type set conversion", func(t *testing.T) {
require.NoError(t, rdb.SAdd(ctx, "myset", "m0", "m1", "m2").Err())
val, err := rdb.Eval(ctx, `return redis.call('smembers', KEYS[1])`, []string{"myset"}).StringSlice()
require.NoError(t, err)
slices.Sort(val)
require.EqualValues(t, []string{"m0", "m1", "m2"}, val)
})

t.Run("EVAL - Redis protocol type double conversion", func(t *testing.T) {
require.NoError(t, rdb.ZAdd(ctx, "mydouble", redis.Z{Member: "z0", Score: 1.5}).Err())
val, err := rdb.Eval(ctx, `return redis.call('zscore', KEYS[1], KEYS[2])`, []string{"mydouble", "z0"}).Result()
require.NoError(t, err)
require.EqualValues(t, 1.5, val)
})

t.Run("EVAL - Redis protocol type bignumber conversion", func(t *testing.T) {
val, err := rdb.Eval(ctx, `return redis.call('debug', 'protocol', 'bignum')`, []string{}).Result()
require.NoError(t, err)

bignum, _ := big.NewInt(0).SetString("1234567999999999999999999999999999999", 10)
require.EqualValues(t, bignum, val)
})

t.Run("EVAL - Redis protocol type boolean conversion", func(t *testing.T) {
val, err := rdb.Eval(ctx, `return redis.call('debug', 'protocol', 'true')`, []string{}).Result()
require.NoError(t, err)
require.EqualValues(t, true, val)

val, err = rdb.Eval(ctx, `return redis.call('debug', 'protocol', 'false')`, []string{}).Result()
require.NoError(t, err)
require.EqualValues(t, false, val)
})

t.Run("EVAL - Redis protocol type verbatim conversion", func(t *testing.T) {
val, err := rdb.Eval(ctx, `return redis.call('debug', 'protocol', 'verbatim')`, []string{}).Result()
require.NoError(t, err)

require.EqualValues(t, "verbatim string", val)
})

}
Loading