Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the command hello #881

Merged
merged 10 commits into from
Sep 18, 2022
118 changes: 102 additions & 16 deletions src/redis_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "scripting.h"
#include "slot_import.h"
#include "slot_migrate.h"
#include "parse_util.h"

namespace Redis {

Expand All @@ -73,29 +74,47 @@ const char *errUnbalacedStreamList =
const char *errTimeoutIsNegative = "timeout is negative";
const char *errLimitOptionNotAllowed = "syntax error, LIMIT cannot be used without the special ~ option";

enum class AuthResult {
OK,
INVALID_PASSWORD,
NO_REQUIRE_PASS,
};

AuthResult AuthenticateUser(Connection *conn, Config* config, const std::string& user_password) {
auto iter = config->tokens.find(user_password);
if (iter != config->tokens.end()) {
conn->SetNamespace(iter->second);
conn->BecomeUser();
return AuthResult::OK;
}
const auto& requirepass = config->requirepass;
if (!requirepass.empty() && user_password != requirepass) {
return AuthResult::INVALID_PASSWORD;
}
conn->SetNamespace(kDefaultNamespace);
conn->BecomeAdmin();
if (requirepass.empty()) {
return AuthResult::NO_REQUIRE_PASS;
}
return AuthResult::OK;
}

enjoy-binbin marked this conversation as resolved.
Show resolved Hide resolved
class CommandAuth : public Commander {
public:
Status Execute(Server *svr, Connection *conn, std::string *output) override {
Config *config = svr->GetConfig();
auto user_password = args_[1];
auto iter = config->tokens.find(user_password);
if (iter != config->tokens.end()) {
conn->SetNamespace(iter->second);
conn->BecomeUser();
auto& user_password = args_[1];
AuthResult result = AuthenticateUser(conn, config, user_password);
switch (result) {
case AuthResult::OK:
*output = Redis::SimpleString("OK");
return Status::OK();
}
const auto requirepass = config->requirepass;
if (!requirepass.empty() && user_password != requirepass) {
break;
case AuthResult::INVALID_PASSWORD:
*output = Redis::Error("ERR invalid password");
return Status::OK();
}
conn->SetNamespace(kDefaultNamespace);
conn->BecomeAdmin();
if (requirepass.empty()) {
break;
case AuthResult::NO_REQUIRE_PASS:
*output = Redis::Error("ERR Client sent AUTH, but no password is set");
} else {
*output = Redis::SimpleString("OK");
break;
}
return Status::OK();
}
Expand Down Expand Up @@ -4132,6 +4151,72 @@ class CommandEcho : public Commander {
}
};

/* HELLO [<protocol-version> [AUTH <password>] [SETNAME <name>] ] */
class CommandHello final : public Commander {
public:
Status Execute(Server *svr, Connection *conn, std::string *output) override {
size_t next_arg = 1;
if (args_.size() >= 2) {
int64_t protocol;
auto parseResult = ParseInt<int64_t>(args_[next_arg], /* base= */ 10);
++next_arg;
if (!parseResult.IsOK()) {
return Status(Status::NotOK, "Protocol version is not an integer or out of range");
}
protocol = parseResult.GetValue();

// In redis, it will check protocol < 2 or protocol > 3,
// kvrocks only supports REPL2 by now, but for supporting some
// `hello 3`, it will not report error when using 3.
if (protocol < 2 || protocol > 3) {
return Status(Status::NotOK, "-NOPROTO unsupported protocol version");
}
}

// Handling AUTH and SETNAME
for (; next_arg < args_.size(); ++next_arg) {
size_t moreargs = args_.size() - next_arg - 1;
const std::string& opt = args_[next_arg];
if (opt == "AUTH" && moreargs != 0) {
const auto& user_password = args_[next_arg + 1];
auto authResult = AuthenticateUser(conn, svr->GetConfig(), user_password);
switch (authResult) {
case AuthResult::INVALID_PASSWORD:
return Status(Status::NotOK, "invalid password");
case AuthResult::NO_REQUIRE_PASS:
return Status(Status::NotOK, "Client sent AUTH, but no password is set");
case AuthResult::OK:
break;
}
next_arg += 1;
} else if (opt == "SETNAME" && moreargs != 0) {
const std::string& name = args_[next_arg + 1];
conn->SetName(name);
next_arg += 1;
} else {
*output = Redis::Error("Syntax error in HELLO option " + opt);
return Status::OK();
}
}

std::vector<std::string> output_list;
output_list.push_back(Redis::BulkString("server"));
output_list.push_back(Redis::BulkString("redis"));
output_list.push_back(Redis::BulkString("proto"));
output_list.push_back(Redis::Integer(2));

output_list.push_back(Redis::BulkString("mode"));
// Note: sentinel is not supported in kvrocks.
if (svr->GetConfig()->cluster_enabled) {
output_list.push_back(Redis::BulkString("cluster"));
} else {
output_list.push_back(Redis::BulkString("standalone"));
}
*output = Redis::Array(output_list);
return Status::OK();
}
};

class CommandScanBase : public Commander {
public:
Status ParseMatchAndCountParam(const std::string &type, std::string value) {
Expand Down Expand Up @@ -5680,6 +5765,7 @@ CommandAttributes redisCommandTable[] = {
ADD_CMD("debug", -2, "read-only exclusive", 0, 0, 0, CommandDebug),
ADD_CMD("command", -1, "read-only", 0, 0, 0, CommandCommand),
ADD_CMD("echo", 2, "read-only", 0, 0, 0, CommandEcho),
ADD_CMD("hello", -1, "read-only ok-loading", 0, 0, 0, CommandHello),

ADD_CMD("ttl", 2, "read-only", 1, 1, 1, CommandTTL),
ADD_CMD("pttl", 2, "read-only", 1, 1, 1, CommandPTTL),
Expand Down
37 changes: 14 additions & 23 deletions src/redis_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@
*
*/

#include <rocksdb/perf_context.h>
#include <rocksdb/iostats_context.h>
#include <glog/logging.h>
#include <rocksdb/iostats_context.h>
#include <rocksdb/perf_context.h>
#ifdef ENABLE_OPENSSL
#include <event2/bufferevent_ssl.h>
#endif

#include "redis_connection.h"
#include "worker.h"
#include "server.h"
#include "tls_util.h"
#include "worker.h"

namespace Redis {

Expand Down Expand Up @@ -74,9 +74,7 @@ void Connection::Close() {
owner_->FreeConnection(this);
}

void Connection::Detach() {
owner_->DetachConnection(this);
}
void Connection::Detach() { owner_->DetachConnection(this); }

void Connection::OnRead(struct bufferevent *bev, void *ctx) {
DLOG(INFO) << "[connection] on read: " << bufferevent_getfd(bev);
Expand Down Expand Up @@ -143,23 +141,21 @@ void Connection::SendFile(int fd) {
void Connection::SetAddr(std::string ip, int port) {
ip_ = std::move(ip);
port_ = port;
addr_ = ip_ +":"+ std::to_string(port_);
addr_ = ip_ + ":" + std::to_string(port_);
}

uint64_t Connection::GetAge() {
time_t now;
time(&now);
return static_cast<uint64_t>(now-create_time_);
return static_cast<uint64_t>(now - create_time_);
}

void Connection::SetLastInteraction() {
time(&last_interaction_);
}
void Connection::SetLastInteraction() { time(&last_interaction_); }

uint64_t Connection::GetIdleTime() {
time_t now;
time(&now);
return static_cast<uint64_t>(now-last_interaction_);
return static_cast<uint64_t>(now - last_interaction_);
}

// Currently, master connection is not handled in connection
Expand All @@ -185,17 +181,11 @@ std::string Connection::GetFlags() {
return flags;
}

void Connection::EnableFlag(Flag flag) {
flags_ |= flag;
}
void Connection::EnableFlag(Flag flag) { flags_ |= flag; }

void Connection::DisableFlag(Flag flag) {
flags_ &= (~flag);
}
void Connection::DisableFlag(Flag flag) { flags_ &= (~flag); }

bool Connection::IsFlagEnabled(Flag flag) {
return (flags_ & flag) > 0;
}
bool Connection::IsFlagEnabled(Flag flag) { return (flags_ & flag) > 0; }

void Connection::SubscribeChannel(const std::string &channel) {
for (const auto &chan : subscribe_channels_) {
Expand Down Expand Up @@ -333,7 +323,8 @@ void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
}

if (GetNamespace().empty()) {
if (!password.empty() && Util::ToLower(cmd_tokens.front()) != "auth") {
if (!password.empty() && Util::ToLower(cmd_tokens.front()) != "auth" &&
Util::ToLower(cmd_tokens.front()) != "hello") {
Reply(Redis::Error("NOAUTH Authentication required."));
continue;
}
Expand Down Expand Up @@ -402,7 +393,7 @@ void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
if (!s.IsOK()) {
if (IsFlagEnabled(Connection::kMultiExec)) multi_error_ = true;
Reply(Redis::Error(s.Msg()));
continue;;
continue;
}
}

Expand Down
33 changes: 18 additions & 15 deletions src/redis_reply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,36 +45,39 @@ std::string MultiLen(int64_t len) {
return "*"+std::to_string(len)+"\r\n";
}

std::string MultiBulkString(std::vector<std::string> values, bool output_nil_for_empty_string) {
std::string MultiBulkString(const std::vector<std::string>& values, bool output_nil_for_empty_string) {
std::string result = "*" + std::to_string(values.size()) + CRLF;
for (size_t i = 0; i < values.size(); i++) {
if (values[i].empty() && output_nil_for_empty_string) {
values[i] = NilString();
result += NilString();
} else {
values[i] = BulkString(values[i]);
result += BulkString(values[i]);
}
}
return Array(values);
return result;
}


std::string MultiBulkString(std::vector<std::string> values, const std::vector<rocksdb::Status> &statuses) {
std::string MultiBulkString(const std::vector<std::string>& values, const std::vector<rocksdb::Status> &statuses) {
std::string result = "*" + std::to_string(values.size()) + CRLF;
for (size_t i = 0; i < values.size(); i++) {
if (i < statuses.size() && !statuses[i].ok()) {
values[i] = NilString();
result += NilString();
} else {
values[i] = BulkString(values[i]);
result += BulkString(values[i]);
}
}
return Array(values);
return result;
}
std::string Array(std::vector<std::string> list) {
std::string::size_type n = std::accumulate(
list.begin(), list.end(), std::string::size_type(0),
[] ( std::string::size_type n, const std::string &s ) { return ( n += s.size() ); });

std::string Array(const std::vector<std::string>& list) {
size_t n = std::accumulate(
list.begin(), list.end(), 0, [] (size_t n, const std::string &s) { return n + s.size(); });
std::string result = "*" + std::to_string(list.size()) + CRLF;
result.reserve(n);
return std::accumulate(list.begin(), list.end(), result,
[](std::string &dest, std::string const &item) -> std::string& {dest += item; return dest;});
std::string::size_type final_size = result.size() + n;
result.reserve(final_size);
for (const auto& i : list) result += i;
return result;
}

std::string Command2RESP(const std::vector<std::string> &cmd_args) {
Expand Down
6 changes: 3 additions & 3 deletions src/redis_reply.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ std::string Integer(int64_t data);
std::string BulkString(const std::string &data);
std::string NilString();
std::string MultiLen(int64_t len);
std::string Array(std::vector<std::string> list);
std::string MultiBulkString(std::vector<std::string> values, bool output_nil_for_empty_string = true);
std::string MultiBulkString(std::vector<std::string> values, const std::vector<rocksdb::Status> &statuses);
std::string Array(const std::vector<std::string>& list);
std::string MultiBulkString(const std::vector<std::string>& values, bool output_nil_for_empty_string = true);
std::string MultiBulkString(const std::vector<std::string>& values, const std::vector<rocksdb::Status> &statuses);
std::string Command2RESP(const std::vector<std::string> &cmd_args);
} // namespace Redis
4 changes: 2 additions & 2 deletions src/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ Status DecimalStringToNum(const std::string &str, int64_t *n, int64_t min, int64
try {
*n = static_cast<int64_t>(std::stoll(str));
if (max > min && (*n < min || *n > max)) {
return Status(Status::NotOK, "value shoud between "+std::to_string(min)+" and "+std::to_string(max));
return Status(Status::NotOK, "value should between "+std::to_string(min)+" and "+std::to_string(max));
}
} catch (std::exception &e) {
return Status(Status::NotOK, "value is not an integer or out of range");
Expand All @@ -356,7 +356,7 @@ Status OctalStringToNum(const std::string &str, int64_t *n, int64_t min, int64_t
try {
*n = static_cast<int64_t>(std::stoll(str, nullptr, 8));
if (max > min && (*n < min || *n > max)) {
return Status(Status::NotOK, "value shoud between "+std::to_string(min)+" and "+std::to_string(max));
return Status(Status::NotOK, "value should between "+std::to_string(min)+" and "+std::to_string(max));
}
} catch (std::exception &e) {
return Status(Status::NotOK, "value is not an integer or out of range");
Expand Down
4 changes: 2 additions & 2 deletions tests/gocase/unit/command/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ func TestCommand(t *testing.T) {
rdb := srv.NewClient()
defer func() { require.NoError(t, rdb.Close()) }()

t.Run("Kvrocks supports 180 commands currently", func(t *testing.T) {
t.Run("Kvrocks supports 181 commands currently", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "COUNT")
v, err := r.Int()
require.NoError(t, err)
require.Equal(t, 180, v)
require.Equal(t, 181, v)
})

t.Run("acquire GET command info by COMMAND INFO", func(t *testing.T) {
Expand Down
Loading