Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor eval commands #1313

Merged
merged 1 commit into from
Mar 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 20 additions & 28 deletions src/commands/cmd_script.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,44 +26,36 @@

namespace Redis {

class CommandEval : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override { return Status::OK(); }

Status Execute(Server *svr, Connection *conn, std::string *output) override {
return Lua::evalGenericCommand(conn, args_, false, output);
}
};

class CommandEvalSHA : public Commander {
template <bool evalsha, bool read_only>
class CommandEvalImpl : public Commander {
public:
Status Execute(Server *svr, Connection *conn, std::string *output) override {
if (args_[1].size() != 40) {
if (evalsha && args_[1].size() != 40) {
*output = Redis::Error(errNoMatchingScript);
return Status::OK();
}
return Lua::evalGenericCommand(conn, args_, true, output);
}
};

class CommandEvalRO : public Commander {
public:
Status Execute(Server *svr, Connection *conn, std::string *output) override {
return Lua::evalGenericCommand(conn, args_, false, output, true);
}
};

class CommandEvalSHARO : public Commander {
public:
Status Execute(Server *svr, Connection *conn, std::string *output) override {
if (args_[1].size() != 40) {
*output = Redis::Error(errNoMatchingScript);
return Status::OK();
int64_t numkeys = GET_OR_RET(ParseInt<int64_t>(args_[2], 10));
if (numkeys > int64_t(args_.size() - 3)) {
return {Status::NotOK, "Number of keys can't be greater than number of args"};
} else if (numkeys < -1) {
return {Status::NotOK, "Number of keys can't be negative"};
}
return Lua::evalGenericCommand(conn, args_, true, output, true);

return Lua::evalGenericCommand(
conn, args_[1], std::vector<std::string>(args_.begin() + 3, args_.begin() + 3 + numkeys),
std::vector<std::string>(args_.begin() + 3 + numkeys, args_.end()), evalsha, output, read_only);
}
};

class CommandEval : public CommandEvalImpl<false, false> {};

class CommandEvalSHA : public CommandEvalImpl<true, false> {};

class CommandEvalRO : public CommandEvalImpl<false, true> {};

class CommandEvalSHARO : public CommandEvalImpl<true, true> {};

class CommandScript : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
Expand Down
101 changes: 50 additions & 51 deletions src/storage/scripting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,33 +134,30 @@ void loadFuncs(lua_State *lua, bool read_only) {
* Note that when the error is in the C function we want to report the
* information about the caller, that's what makes sense from the point
* of view of the user debugging a script. */
{
const char *err_func =
"local dbg = debug\n"
"function __redis__err__handler(err)\n"
" local i = dbg.getinfo(2,'nSl')\n"
" if i and i.what == 'C' then\n"
" i = dbg.getinfo(3,'nSl')\n"
" end\n"
" if i then\n"
" return i.source .. ':' .. i.currentline .. ': ' .. err\n"
" else\n"
" return err\n"
" end\n"
"end\n";
luaL_loadbuffer(lua, err_func, strlen(err_func), "@err_handler_def");
lua_pcall(lua, 0, 0, 0);
}
{
const char *compare_func =
"function __redis__compare_helper(a,b)\n"
" if a == false then a = '' end\n"
" if b == false then b = '' end\n"
" return a<b\n"
"end\n";
luaL_loadbuffer(lua, compare_func, strlen(compare_func), "@cmp_func_def");
lua_pcall(lua, 0, 0, 0);
}
const char *err_func =
"local dbg = debug\n"
"function __redis__err__handler(err)\n"
" local i = dbg.getinfo(2,'nSl')\n"
" if i and i.what == 'C' then\n"
" i = dbg.getinfo(3,'nSl')\n"
" end\n"
" if i then\n"
" return i.source .. ':' .. i.currentline .. ': ' .. err\n"
" else\n"
" return err\n"
" end\n"
"end\n";
luaL_loadbuffer(lua, err_func, strlen(err_func), "@err_handler_def");
lua_pcall(lua, 0, 0, 0);

const char *compare_func =
"function __redis__compare_helper(a,b)\n"
" if a == false then a = '' end\n"
" if b == false then b = '' end\n"
" return a<b\n"
"end\n";
luaL_loadbuffer(lua, compare_func, strlen(compare_func), "@cmp_func_def");
lua_pcall(lua, 0, 0, 0);
}

int redisLogCommand(lua_State *lua) {
Expand Down Expand Up @@ -204,29 +201,22 @@ int redisLogCommand(lua_State *lua) {
return 0;
}

Status evalGenericCommand(Redis::Connection *conn, const std::vector<std::string> &args, bool evalsha,
std::string *output, bool read_only) {
Status evalGenericCommand(Redis::Connection *conn, const std::string &body_or_sha, const std::vector<std::string> &keys,
const std::vector<std::string> &argv, bool evalsha, std::string *output, bool read_only) {
Server *srv = conn->GetServer();

// Use the worker's private Lua VM when entering the read-only mode
lua_State *lua = read_only ? conn->Owner()->Lua() : srv->Lua();

int64_t numkeys = GET_OR_RET(ParseInt<int64_t>(args[2], 10));
if (numkeys > int64_t(args.size() - 3)) {
return {Status::NotOK, "Number of keys can't be greater than number of args"};
} else if (numkeys < -1) {
return {Status::NotOK, "Number of keys can't be negative"};
}

/* We obtain the script SHA1, then check if this function is already
* defined into the Lua state */
char funcname[2 + 40 + 1] = REDIS_LUA_FUNC_SHA_PREFIX;

if (!evalsha) {
SHA1Hex(funcname + 2, args[1].c_str(), args[1].size());
SHA1Hex(funcname + 2, body_or_sha.c_str(), body_or_sha.size());
} else {
for (int j = 0; j < 40; j++) {
funcname[j + 2] = static_cast<char>(tolower(args[1][j]));
funcname[j + 2] = static_cast<char>(tolower(body_or_sha[j]));
}
}

Expand All @@ -245,10 +235,10 @@ Status evalGenericCommand(Redis::Connection *conn, const std::vector<std::string
return {Status::NotOK, "NOSCRIPT No matching script. Please use EVAL"};
}
} else {
body = args[1];
body = body_or_sha;
}

std::string sha;
std::string sha = funcname + 2;
auto s = createFunction(srv, body, &sha, lua, false);
if (!s.IsOK()) {
lua_pop(lua, 1); /* remove the error handler from the stack. */
Expand All @@ -260,8 +250,8 @@ Status evalGenericCommand(Redis::Connection *conn, const std::vector<std::string

/* Populate the argv and keys table accordingly to the arguments that
* EVAL received. */
setGlobalArray(lua, "KEYS", std::vector<std::string>(args.begin() + 3, args.begin() + 3 + numkeys));
setGlobalArray(lua, "ARGV", std::vector<std::string>(args.begin() + 3 + numkeys, args.end()));
setGlobalArray(lua, "KEYS", keys);
setGlobalArray(lua, "ARGV", argv);

if (lua_pcall(lua, 0, 1, -2)) {
auto msg = fmt::format("ERR running script (call to {}): {}", funcname, lua_tostring(lua, -1));
Expand All @@ -272,22 +262,27 @@ Status evalGenericCommand(Redis::Connection *conn, const std::vector<std::string
lua_pop(lua, 1);
}

// clean global variables to prevent information leak in function commands
lua_pushnil(lua);
lua_setglobal(lua, "KEYS");
lua_pushnil(lua);
lua_setglobal(lua, "ARGV");

/* Call the Lua garbage collector from time to time to avoid a
* full cycle performed by Lua, which adds too latency.
*
* The call is performed every LUA_GC_CYCLE_PERIOD executed commands
* (and for LUA_GC_CYCLE_PERIOD collection steps) because calling it
* for every command uses too much CPU. */
constexpr int64_t LUA_GC_CYCLE_PERIOD = 50;
{
static int64_t gc_count = 0;
static int64_t gc_count = 0;

gc_count++;
if (gc_count == LUA_GC_CYCLE_PERIOD) {
lua_gc(lua, LUA_GCSTEP, LUA_GC_CYCLE_PERIOD);
gc_count = 0;
}
gc_count++;
if (gc_count == LUA_GC_CYCLE_PERIOD) {
lua_gc(lua, LUA_GCSTEP, LUA_GC_CYCLE_PERIOD);
gc_count = 0;
}

return Status::OK();
}

Expand Down Expand Up @@ -875,8 +870,12 @@ int redisMathRandomSeed(lua_State *L) {
Status createFunction(Server *srv, const std::string &body, std::string *sha, lua_State *lua, bool need_to_store) {
char funcname[2 + 40 + 1] = REDIS_LUA_FUNC_SHA_PREFIX;

SHA1Hex(funcname + 2, body.c_str(), body.size());
*sha = funcname + 2;
if (sha->empty()) {
SHA1Hex(funcname + 2, body.c_str(), body.size());
*sha = funcname + 2;
} else {
std::copy(sha->begin(), sha->end(), funcname + 2);
}

auto funcdef = fmt::format("function {}() {}\nend", funcname, body);

Expand Down
5 changes: 3 additions & 2 deletions src/storage/scripting.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ int redisLogCommand(lua_State *lua);

Status createFunction(Server *srv, const std::string &body, std::string *sha, lua_State *lua, bool need_to_store);

Status evalGenericCommand(Redis::Connection *conn, const std::vector<std::string> &args, bool evalsha,
std::string *output, bool read_only = false);
Status evalGenericCommand(Redis::Connection *conn, const std::string &body_or_sha, const std::vector<std::string> &keys,
const std::vector<std::string> &argv, bool evalsha, std::string *output,
bool read_only = false);

const char *redisProtocolToLuaType(lua_State *lua, const char *reply);
const char *redisProtocolToLuaType_Int(lua_State *lua, const char *reply);
Expand Down