Skip to content

Commit

Permalink
Improve code style and do some minor fix in scripting (#1312)
Browse files Browse the repository at this point in the history
  • Loading branch information
PragmaTwice authored Mar 10, 2023
1 parent c0ba641 commit 1ab1ae8
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 75 deletions.
2 changes: 1 addition & 1 deletion src/commands/cmd_script.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class CommandScript : public Commander {
}
} else if (args_.size() == 3 && subcommand_ == "load") {
std::string sha;
auto s = Lua::createFunction(svr, args_[2], &sha, svr->Lua());
auto s = Lua::createFunction(svr, args_[2], &sha, svr->Lua(), true);
if (!s.IsOK()) {
return s;
}
Expand Down
11 changes: 8 additions & 3 deletions src/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "commands/commander.h"
#include "config.h"
#include "fmt/format.h"
#include "lua.h"
#include "redis_connection.h"
#include "redis_request.h"
#include "storage/compaction_checker.h"
Expand Down Expand Up @@ -1498,6 +1499,12 @@ Status Server::LookupAndCreateCommand(const std::string &cmd_name, std::unique_p
}

Status Server::ScriptExists(const std::string &sha) {
lua_getglobal(lua_, (REDIS_LUA_FUNC_SHA_PREFIX + sha).c_str());
if (!lua_isnil(lua_, -1)) {
return Status::OK();
}
lua_pop(lua_, 1);

std::string body;
return ScriptGet(sha, &body);
}
Expand All @@ -1507,9 +1514,7 @@ Status Server::ScriptGet(const std::string &sha, std::string *body) {
auto cf = storage_->GetCFHandle(Engine::kPropagateColumnFamilyName);
auto s = storage_->Get(rocksdb::ReadOptions(), cf, func_name, body);
if (!s.ok()) {
if (s.IsNotFound()) return {Status::NotFound};

return {Status::NotOK, s.ToString()};
return {s.IsNotFound() ? Status::NotFound : Status::NotOK, s.ToString()};
}
return Status::OK();
}
Expand Down
132 changes: 63 additions & 69 deletions src/storage/scripting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <math.h>

#include <cctype>
#include <string>

#include "commands/commander.h"
Expand Down Expand Up @@ -109,12 +110,10 @@ void loadFuncs(lua_State *lua, bool read_only) {
lua_pushcfunction(lua, redisStatusReplyCommand);
lua_settable(lua, -3);

if (read_only) {
/* redis.read_only */
lua_pushstring(lua, "read_only");
lua_pushboolean(lua, 1);
lua_settable(lua, -3);
}
/* redis.read_only */
lua_pushstring(lua, "read_only");
lua_pushboolean(lua, read_only);
lua_settable(lua, -3);

lua_setglobal(lua, "redis");

Expand Down Expand Up @@ -165,7 +164,7 @@ void loadFuncs(lua_State *lua, bool read_only) {
}

int redisLogCommand(lua_State *lua) {
int j = 0, level = 0, argc = lua_gettop(lua);
int argc = lua_gettop(lua);

if (argc < 2) {
lua_pushstring(lua, "redis.log() requires two arguments or more.");
Expand All @@ -175,26 +174,20 @@ int redisLogCommand(lua_State *lua) {
lua_pushstring(lua, "First argument must be a number (log level).");
return lua_error(lua);
}
level = static_cast<int>(lua_tonumber(lua, -argc));
int level = static_cast<int>(lua_tonumber(lua, -argc));
if (level < LL_DEBUG || level > LL_WARNING) {
lua_pushstring(lua, "Invalid debug level.");
return lua_error(lua);
}
if (level < GetServer()->GetConfig()->log_level) {
return 0;
}

std::string log_message;
for (j = 1; j < argc; j++) {
for (int j = 1; j < argc; j++) {
size_t len = 0;
const char *s = nullptr;
s = lua_tolstring(lua, (-argc) + j, &len);
if (s) {
if (const char *s = lua_tolstring(lua, j - argc, &len)) {
if (j != 1) {
log_message += " " + std::string(s, len);
} else {
log_message = std::string(s, len);
log_message += " ";
}
log_message += std::string(s, len);
}
}

Expand All @@ -213,16 +206,12 @@ int redisLogCommand(lua_State *lua) {

Status evalGenericCommand(Redis::Connection *conn, const std::vector<std::string> &args, bool evalsha,
std::string *output, bool read_only) {
int64_t numkeys = 0;
char funcname[43];
Server *srv = conn->GetServer();
lua_State *lua = srv->Lua();
if (read_only) {
// Use the worker's private Lua VM when entering the read-only mode
lua = conn->Owner()->Lua();
}

numkeys = GET_OR_RET(ParseInt<int64_t>(args[2], 10));
// Use the worker's private Lua VM when entering the read-only mode
lua_State *lua = read_only ? conn->Owner()->Lua() : srv->Lua();

int64_t numkeys = GET_OR_RET(ParseInt<int64_t>(args[2], 10));
if (numkeys > int64_t(args.size() - 3)) {
return {Status::NotOK, "Number of keys can't be greater than number of args"};
} else if (numkeys < -1) {
Expand All @@ -231,16 +220,14 @@ Status evalGenericCommand(Redis::Connection *conn, const std::vector<std::string

/* We obtain the script SHA1, then check if this function is already
* defined into the Lua state */
funcname[0] = 'f';
funcname[1] = '_';
char funcname[2 + 40 + 1] = REDIS_LUA_FUNC_SHA_PREFIX;

if (!evalsha) {
SHA1Hex(funcname + 2, args[1].c_str(), args[1].size());
} else {
for (int j = 0; j < 40; j++) {
std::string sha = args[1];
funcname[j + 2] = (sha[j] >= 'A' && sha[j] <= 'Z') ? static_cast<char>(sha[j] + 'a' - 'A') : sha[j];
funcname[j + 2] = static_cast<char>(tolower(args[1][j]));
}
funcname[42] = '\0';
}

/* Push the pcall error handler function on the stack. */
Expand All @@ -260,8 +247,9 @@ Status evalGenericCommand(Redis::Connection *conn, const std::vector<std::string
} else {
body = args[1];
}

std::string sha;
auto s = createFunction(srv, body, &sha, lua);
auto s = createFunction(srv, body, &sha, lua, false);
if (!s.IsOK()) {
lua_pop(lua, 1); /* remove the error handler from the stack. */
return s;
Expand All @@ -274,9 +262,9 @@ Status evalGenericCommand(Redis::Connection *conn, const std::vector<std::string
* EVAL received. */
setGlobalArray(lua, "KEYS", std::vector<std::string>(args.begin() + 3, args.begin() + 3 + numkeys));
setGlobalArray(lua, "ARGV", std::vector<std::string>(args.begin() + 3 + numkeys, args.end()));
int err = lua_pcall(lua, 0, 1, -2);
if (err) {
std::string msg = std::string("ERR running script (call to ") + funcname + "): " + lua_tostring(lua, -1);

if (lua_pcall(lua, 0, 1, -2)) {
auto msg = fmt::format("ERR running script (call to {}): {}", funcname, lua_tostring(lua, -1));
*output = Redis::Error(msg);
lua_pop(lua, 2);
} else {
Expand Down Expand Up @@ -306,48 +294,54 @@ Status evalGenericCommand(Redis::Connection *conn, const std::vector<std::string
int redisCallCommand(lua_State *lua) { return redisGenericCommand(lua, 1); }

int redisPCallCommand(lua_State *lua) { return redisGenericCommand(lua, 0); }

// TODO: we do not want to repeat same logic as Connection::ExecuteCommands,
// so the function need to be refactored
int redisGenericCommand(lua_State *lua, int raise_error) {
int j = 0, argc = lua_gettop(lua);
std::vector<std::string> args;
lua_getglobal(lua, "redis");
lua_getfield(lua, -1, "read_only");
int read_only = lua_toboolean(lua, -1);
lua_pop(lua, 2);

int argc = lua_gettop(lua);
if (argc == 0) {
pushError(lua, "Please specify at least one argument for redis.call()");
return raise_error ? raiseError(lua) : 1;
}
for (j = 0; j < argc; j++) {
if (lua_type(lua, j + 1) == LUA_TNUMBER) {
lua_Number num = lua_tonumber(lua, j + 1);

std::vector<std::string> args;
for (int j = 1; j <= argc; j++) {
if (lua_type(lua, j) == LUA_TNUMBER) {
lua_Number num = lua_tonumber(lua, j);
args.emplace_back(fmt::format("{:.17g}", static_cast<double>(num)));
} else {
size_t obj_len = 0;
const char *obj_s = lua_tolstring(lua, j + 1, &obj_len);
if (obj_s == nullptr) break; /* no a string */
const char *obj_s = lua_tolstring(lua, j, &obj_len);
if (obj_s == nullptr) {
pushError(lua, "Lua redis() command arguments must be strings or integers");
return raise_error ? raiseError(lua) : 1;
}
args.emplace_back(obj_s, obj_len);
}
}
if (j != argc) {
pushError(lua, "Lua redis() command arguments must be strings or integers");
return raise_error ? raiseError(lua) : 1;
}

auto commands = Redis::GetCommands();
auto cmd_iter = commands->find(Util::ToLower(args[0]));
if (cmd_iter == commands->end()) {
pushError(lua, "Unknown Redis command called from Lua script");
return raise_error ? raiseError(lua) : 1;
}

auto redisCmd = cmd_iter->second;
if (read_only && redisCmd->is_write()) {
if (read_only && !(redisCmd->flags & Redis::kCmdReadOnly)) {
pushError(lua, "Write commands are not allowed from read-only scripts");
return raise_error ? raiseError(lua) : 1;
}

auto cmd = redisCmd->factory();
cmd->SetAttributes(redisCmd);
cmd->SetArgs(args);

int arity = cmd->GetAttributes()->arity;
if (((arity > 0 && argc != arity) || (arity < 0 && argc < -arity))) {
pushError(lua, "Wrong number of args calling Redis command From Lua script");
Expand All @@ -359,9 +353,10 @@ int redisGenericCommand(lua_State *lua, int raise_error) {
return raise_error ? raiseError(lua) : 1;
}

std::string output, cmd_name = Util::ToLower(args[0]);
std::string cmd_name = Util::ToLower(args[0]);
Server *srv = GetServer();
Config *config = srv->GetConfig();

Redis::Connection *conn = srv->GetCurrentConnection();
if (config->cluster_enabled) {
auto s = srv->cluster_->CanExecByMySelf(attributes, args, conn);
Expand All @@ -370,36 +365,42 @@ int redisGenericCommand(lua_State *lua, int raise_error) {
return raise_error ? raiseError(lua) : 1;
}
}

if (config->slave_readonly && srv->IsSlave() && attributes->is_write()) {
pushError(lua, "READONLY You can't write against a read only slave.");
return raise_error ? raiseError(lua) : 1;
}

if (!config->slave_serve_stale_data && srv->IsSlave() && cmd_name != "info" && cmd_name != "slaveof" &&
srv->GetReplicationState() != kReplConnected) {
pushError(lua,
"MASTERDOWN Link with MASTER is down "
"and slave-serve-stale-data is set to 'no'.");
return raise_error ? raiseError(lua) : 1;
}

auto s = cmd->Parse(args);
if (!s.IsOK()) {
if (!s) {
pushError(lua, s.Msg().data());
return raise_error ? raiseError(lua) : 1;
}

srv->stats_.IncrCalls(cmd_name);
auto start = std::chrono::high_resolution_clock::now();
bool is_profiling = conn->isProfilingEnabled(cmd_name);
std::string output;
s = cmd->Execute(GetServer(), srv->GetCurrentConnection(), &output);
auto end = std::chrono::high_resolution_clock::now();
uint64_t duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
if (is_profiling) conn->recordProfilingSampleIfNeed(cmd_name, duration);
srv->SlowlogPushEntryIfNeeded(&args, duration);
srv->stats_.IncrLatency(static_cast<uint64_t>(duration), cmd_name);
srv->FeedMonitorConns(conn, args);
if (!s.IsOK()) {
if (!s) {
pushError(lua, s.Msg().data());
return raise_error ? raiseError(lua) : 1;
}

redisProtocolToLuaType(lua, output.data());
return 1;
}
Expand Down Expand Up @@ -443,6 +444,7 @@ void loadLibraries(lua_State *lua) {
lua_pushstring(lua, libname);
lua_call(lua, 1, 0);
};

loadLib(lua, "", luaopen_base);
loadLib(lua, LUA_TABLIBNAME, luaopen_table);
loadLib(lua, LUA_STRLIBNAME, luaopen_string);
Expand Down Expand Up @@ -484,16 +486,16 @@ int redisStatusReplyCommand(lua_State *lua) { return redisReturnSingleFieldTable
* function used for sha1ing lua scripts. */
int redisSha1hexCommand(lua_State *lua) {
int argc = lua_gettop(lua);
char digest[41];
size_t len = 0;
const char *s = nullptr;

if (argc != 1) {
lua_pushstring(lua, "wrong number of arguments");
return lua_error(lua);
}

s = static_cast<const char *>(lua_tolstring(lua, 1, &len));
size_t len = 0;
const char *s = static_cast<const char *>(lua_tolstring(lua, 1, &len));

char digest[41];
SHA1Hex(digest, s, len);
lua_pushstring(lua, digest);
return 1;
Expand All @@ -513,13 +515,12 @@ void SHA1Hex(char *digest, const char *script, size_t len) {
SHA1_CTX ctx;
unsigned char hash[20];
const char *cset = "0123456789abcdef";
int j = 0;

SHA1Init(&ctx);
SHA1Update(&ctx, (const unsigned char *)script, len);
SHA1Final(hash, &ctx);

for (j = 0; j < 20; j++) {
for (int j = 0; j < 20; j++) {
digest[j * 2] = cset[((hash[j] & 0xF0) >> 4)];
digest[j * 2 + 1] = cset[(hash[j] & 0xF)];
}
Expand Down Expand Up @@ -871,33 +872,26 @@ int redisMathRandomSeed(lua_State *L) {
*
* If 'c' is not NULL, on error the client is informed with an appropriate
* error describing the nature of the problem and the Lua interpreter error. */
Status createFunction(Server *srv, const std::string &body, std::string *sha, lua_State *lua) {
char funcname[43];
Status createFunction(Server *srv, const std::string &body, std::string *sha, lua_State *lua, bool need_to_store) {
char funcname[2 + 40 + 1] = REDIS_LUA_FUNC_SHA_PREFIX;

funcname[0] = 'f';
funcname[1] = '_';
SHA1Hex(funcname + 2, body.c_str(), body.size());
*sha = funcname + 2;

std::string funcdef;
funcdef += "function ";
funcdef += funcname;
funcdef += "() ";
funcdef += body;
funcdef += "\nend";
auto funcdef = fmt::format("function {}() {}\nend", funcname, body);

if (luaL_loadbuffer(lua, funcdef.c_str(), funcdef.size(), "@user_script")) {
std::string errMsg = lua_tostring(lua, -1);
lua_pop(lua, 1);
return Status(Status::NotOK, "Error compiling script (new function): " + errMsg + "\n");
return {Status::NotOK, "Error compiling script (new function): " + errMsg + "\n"};
}
if (lua_pcall(lua, 0, 0, 0)) {
std::string errMsg = lua_tostring(lua, -1);
lua_pop(lua, 1);
return Status(Status::NotOK, "Error running script (new function): " + errMsg + "\n");
return {Status::NotOK, "Error running script (new function): " + errMsg + "\n"};
}
// would store lua function into propagate column family and propagate those scripts to slaves
return srv->ScriptSet(*sha, body);
return need_to_store ? srv->ScriptSet(*sha, body) : Status::OK();
}

} // namespace Lua
Loading

0 comments on commit 1ab1ae8

Please sign in to comment.