Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the command hello #881

Merged
merged 10 commits into from
Sep 18, 2022
124 changes: 108 additions & 16 deletions src/redis_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "scripting.h"
#include "slot_import.h"
#include "slot_migrate.h"
#include "parse_util.h"

namespace Redis {

Expand All @@ -73,29 +74,47 @@ const char *errUnbalacedStreamList =
const char *errTimeoutIsNegative = "timeout is negative";
const char *errLimitOptionNotAllowed = "syntax error, LIMIT cannot be used without the special ~ option";

enum class AuthResult {
OK,
INVALID_PASSWORD,
NO_REQUIRE_PASS,
};

AuthResult AuthenticateUser(Connection *conn, Config* config, const std::string& user_password) {
auto iter = config->tokens.find(user_password);
if (iter != config->tokens.end()) {
conn->SetNamespace(iter->second);
conn->BecomeUser();
return AuthResult::OK;
}
const auto& requirepass = config->requirepass;
if (!requirepass.empty() && user_password != requirepass) {
return AuthResult::INVALID_PASSWORD;
}
conn->SetNamespace(kDefaultNamespace);
conn->BecomeAdmin();
if (requirepass.empty()) {
return AuthResult::NO_REQUIRE_PASS;
}
return AuthResult::OK;
}

enjoy-binbin marked this conversation as resolved.
Show resolved Hide resolved
class CommandAuth : public Commander {
public:
Status Execute(Server *svr, Connection *conn, std::string *output) override {
Config *config = svr->GetConfig();
auto user_password = args_[1];
auto iter = config->tokens.find(user_password);
if (iter != config->tokens.end()) {
conn->SetNamespace(iter->second);
conn->BecomeUser();
auto& user_password = args_[1];
AuthResult result = AuthenticateUser(conn, config, user_password);
switch (result) {
case AuthResult::OK:
*output = Redis::SimpleString("OK");
return Status::OK();
}
const auto requirepass = config->requirepass;
if (!requirepass.empty() && user_password != requirepass) {
break;
case AuthResult::INVALID_PASSWORD:
*output = Redis::Error("ERR invalid password");
return Status::OK();
}
conn->SetNamespace(kDefaultNamespace);
conn->BecomeAdmin();
if (requirepass.empty()) {
break;
case AuthResult::NO_REQUIRE_PASS:
*output = Redis::Error("ERR Client sent AUTH, but no password is set");
} else {
*output = Redis::SimpleString("OK");
break;
}
return Status::OK();
}
Expand Down Expand Up @@ -4132,6 +4151,78 @@ class CommandEcho : public Commander {
}
};

/* HELLO [<protocol-version> [AUTH <password>] [SETNAME <name>] ] */
class CommandHello final : public Commander {
public:
Status Execute(Server *svr, Connection *conn, std::string *output) override {
size_t next_arg = 1;
if (args_.size() >= 2) {
int64_t protocol;
auto parseResult = ParseInt<int64_t>(args_[next_arg], /* base= */ 10);
++next_arg;
if (!parseResult.IsOK()) {
*output = Redis::Error("Protocol version is not an integer or out of range");
return parseResult.ToStatus();
}
protocol = parseResult.GetValue();

// In redis, it will check protocol < 2 or protocol > 3,
// but kvrocks only supports REPL2 by now.
if (protocol != 2) {
*output = Redis::Error("-NOPROTO unsupported protocol version");
return Status::OK();
}
}

// Handling AUTH and SETNAME
for (; next_arg < args_.size(); ++next_arg) {
size_t moreargs = args_.size() - next_arg - 1;
const std::string& opt = args_[next_arg];
if (opt == "AUTH" && moreargs != 0) {
const auto& user_password = args_[next_arg + 1];
auto authResult = AuthenticateUser(conn, svr->GetConfig(), user_password);
switch (authResult) {
case AuthResult::INVALID_PASSWORD:
*output = Redis::Error("ERR invalid password");
break;
case AuthResult::NO_REQUIRE_PASS:
*output = Redis::Error("ERR Client sent AUTH, but no password is set");
break;
case AuthResult::OK:
break;
}
if (authResult != AuthResult::OK) {
return Status::OK();
}
mapleFU marked this conversation as resolved.
Show resolved Hide resolved
next_arg += 1;
} else if (opt == "SETNAME" && moreargs != 0) {
const std::string& name = args_[next_arg + 1];
conn->SetName(name);
next_arg += 1;
} else {
*output = Redis::Error("Syntax error in HELLO option " + opt);
return Status::OK();
}
}

std::vector<std::string> output_list;
output_list.push_back(Redis::BulkString("server"));
output_list.push_back(Redis::BulkString("redis"));
output_list.push_back(Redis::BulkString("proto"));
output_list.push_back(Redis::Integer(2));

output_list.push_back(Redis::BulkString("mode"));
// Note: sentinel is not supported in kvrocks.
if (svr->GetConfig()->cluster_enabled) {
output_list.push_back(Redis::BulkString("cluster"));
} else {
output_list.push_back(Redis::BulkString("standalone"));
}
*output = Redis::Array(output_list);
return Status::OK();
}
};

class CommandScanBase : public Commander {
public:
Status ParseMatchAndCountParam(const std::string &type, std::string value) {
Expand Down Expand Up @@ -5680,6 +5771,7 @@ CommandAttributes redisCommandTable[] = {
ADD_CMD("debug", -2, "read-only exclusive", 0, 0, 0, CommandDebug),
ADD_CMD("command", -1, "read-only", 0, 0, 0, CommandCommand),
ADD_CMD("echo", 2, "read-only", 0, 0, 0, CommandEcho),
ADD_CMD("hello", -1, "read-only", 0, 0, 0, CommandHello),

ADD_CMD("ttl", 2, "read-only", 1, 1, 1, CommandTTL),
ADD_CMD("pttl", 2, "read-only", 1, 1, 1, CommandPTTL),
Expand Down
33 changes: 18 additions & 15 deletions src/redis_reply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,36 +45,39 @@ std::string MultiLen(int64_t len) {
return "*"+std::to_string(len)+"\r\n";
}

std::string MultiBulkString(std::vector<std::string> values, bool output_nil_for_empty_string) {
std::string MultiBulkString(const std::vector<std::string>& values, bool output_nil_for_empty_string) {
std::string result = "*" + std::to_string(values.size()) + CRLF;
for (size_t i = 0; i < values.size(); i++) {
if (values[i].empty() && output_nil_for_empty_string) {
values[i] = NilString();
result += NilString();
} else {
values[i] = BulkString(values[i]);
result += BulkString(values[i]);
}
}
return Array(values);
return result;
}


std::string MultiBulkString(std::vector<std::string> values, const std::vector<rocksdb::Status> &statuses) {
std::string MultiBulkString(const std::vector<std::string>& values, const std::vector<rocksdb::Status> &statuses) {
std::string result = "*" + std::to_string(values.size()) + CRLF;
for (size_t i = 0; i < values.size(); i++) {
if (i < statuses.size() && !statuses[i].ok()) {
values[i] = NilString();
result += NilString();
} else {
values[i] = BulkString(values[i]);
result += BulkString(values[i]);
}
}
return Array(values);
return result;
}
std::string Array(std::vector<std::string> list) {
std::string::size_type n = std::accumulate(
list.begin(), list.end(), std::string::size_type(0),
[] ( std::string::size_type n, const std::string &s ) { return ( n += s.size() ); });

std::string Array(const std::vector<std::string>& list) {
size_t n = std::accumulate(
list.begin(), list.end(), 0, [] (size_t n, const std::string &s) { return n + s.size(); });
std::string result = "*" + std::to_string(list.size()) + CRLF;
result.reserve(n);
return std::accumulate(list.begin(), list.end(), result,
[](std::string &dest, std::string const &item) -> std::string& {dest += item; return dest;});
std::string::size_type final_size = result.size() + n;
result.reserve(final_size);
for (const auto& i : list) result += i;
return result;
}

std::string Command2RESP(const std::vector<std::string> &cmd_args) {
Expand Down
6 changes: 3 additions & 3 deletions src/redis_reply.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ std::string Integer(int64_t data);
std::string BulkString(const std::string &data);
std::string NilString();
std::string MultiLen(int64_t len);
std::string Array(std::vector<std::string> list);
std::string MultiBulkString(std::vector<std::string> values, bool output_nil_for_empty_string = true);
std::string MultiBulkString(std::vector<std::string> values, const std::vector<rocksdb::Status> &statuses);
std::string Array(const std::vector<std::string>& list);
std::string MultiBulkString(const std::vector<std::string>& values, bool output_nil_for_empty_string = true);
std::string MultiBulkString(const std::vector<std::string>& values, const std::vector<rocksdb::Status> &statuses);
std::string Command2RESP(const std::vector<std::string> &cmd_args);
} // namespace Redis
4 changes: 2 additions & 2 deletions src/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ Status DecimalStringToNum(const std::string &str, int64_t *n, int64_t min, int64
try {
*n = static_cast<int64_t>(std::stoll(str));
if (max > min && (*n < min || *n > max)) {
return Status(Status::NotOK, "value shoud between "+std::to_string(min)+" and "+std::to_string(max));
return Status(Status::NotOK, "value should between "+std::to_string(min)+" and "+std::to_string(max));
}
} catch (std::exception &e) {
return Status(Status::NotOK, "value is not an integer or out of range");
Expand All @@ -356,7 +356,7 @@ Status OctalStringToNum(const std::string &str, int64_t *n, int64_t min, int64_t
try {
*n = static_cast<int64_t>(std::stoll(str, nullptr, 8));
if (max > min && (*n < min || *n > max)) {
return Status(Status::NotOK, "value shoud between "+std::to_string(min)+" and "+std::to_string(max));
return Status(Status::NotOK, "value should between "+std::to_string(min)+" and "+std::to_string(max));
}
} catch (std::exception &e) {
return Status(Status::NotOK, "value is not an integer or out of range");
Expand Down
31 changes: 31 additions & 0 deletions tests/gocase/unit/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,34 @@ func TestAuth(t *testing.T) {
require.EqualValues(t, 101, rdb.Incr(ctx, "foo").Val())
})
}

func TestHello(t *testing.T) {
srv := util.StartServer(t, map[string]string{
"requirepass": "foobar",
})
defer srv.Close()

ctx := context.Background()
rdb := srv.NewClient()
defer func() { require.NoError(t, rdb.Close()) }()

t.Run("hello with wrong protocol", func(t *testing.T) {
r := rdb.Do(ctx, "HELLO 3")
require.ErrorContains(t, r.Err(), "-NOPROTO unsupported protocol version")
})

t.Run("AUTH succeeds when the right password is given", func(t *testing.T) {
r := rdb.Do(ctx, "AUTH", "foobar")
require.Equal(t, "OK", r.Val())
})

t.Run("hello with wrong protocol", func(t *testing.T) {
r := rdb.Do(ctx, "HELLO 5")
require.ErrorContains(t, r.Err(), "-NOPROTO unsupported protocol version")
})

t.Run("hello with non protocol", func(t *testing.T) {
r := rdb.Do(ctx, "HELLO AUTH")
require.ErrorContains(t, r.Err(), "Protocol version is not an integer or out of range")
})
}