diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index 2b37d6fdaab..4002aa0057d 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -657,7 +657,11 @@ class CommandTime : public Commander { } }; -/* HELLO [ [AUTH ] [SETNAME ] ] */ +/* + * HELLO [ [AUTH [| ]] [SETNAME ] ] + * Note that the should always be `default` if provided otherwise AUTH fails. + * And it is only meant to be aligning syntax with Redis HELLO. + */ class CommandHello final : public Commander { public: Status Execute(Server *svr, Connection *conn, std::string *output) override { @@ -683,7 +687,13 @@ class CommandHello final : public Commander { for (; next_arg < args_.size(); ++next_arg) { size_t more_args = args_.size() - next_arg - 1; const std::string &opt = args_[next_arg]; - if (opt == "AUTH" && more_args != 0) { + if (util::ToLower(opt) == "auth" && more_args != 0) { + if (more_args == 2 || more_args == 4) { + if (args_[next_arg + 1] != "default") { + return {Status::NotOK, "invalid password"}; + } + next_arg++; + } const auto &user_password = args_[next_arg + 1]; auto auth_result = AuthenticateUser(conn, svr->GetConfig(), user_password); switch (auth_result) { @@ -695,7 +705,7 @@ class CommandHello final : public Commander { break; } next_arg += 1; - } else if (opt == "SETNAME" && more_args != 0) { + } else if (util::ToLower(opt) == "setname" && more_args != 0) { const std::string &name = args_[next_arg + 1]; conn->SetName(name); next_arg += 1; diff --git a/tests/gocase/unit/hello/hello_test.go b/tests/gocase/unit/hello/hello_test.go index 21e33a4666d..cd3256ad420 100644 --- a/tests/gocase/unit/hello/hello_test.go +++ b/tests/gocase/unit/hello/hello_test.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/apache/incubator-kvrocks/tests/gocase/util" + "github.com/go-redis/redis/v9" "github.com/stretchr/testify/require" ) @@ -90,6 +91,11 @@ func TestHelloWithAuth(t *testing.T) { require.ErrorContains(t, r.Err(), "invalid password") }) + t.Run("AUTH fails when a wrong username is given", func(t *testing.T) { + r := rdb.Do(ctx, "HELLO", "3", "AUTH", "wrong!", "foobar") + require.ErrorContains(t, r.Err(), "invalid password") + }) + t.Run("Arbitrary command gives an error when AUTH is required", func(t *testing.T) { r := rdb.Set(ctx, "foo", "bar", 0) require.ErrorContains(t, r.Err(), "NOAUTH Authentication required.") @@ -100,6 +106,11 @@ func TestHelloWithAuth(t *testing.T) { t.Log(r) }) + t.Run("AUTH succeeds when the right username and password are given", func(t *testing.T) { + r := rdb.Do(ctx, "HELLO", "3", "AUTH", "default", "foobar") + t.Log(r) + }) + t.Run("Once AUTH succeeded we can actually send commands to the server", func(t *testing.T) { require.Equal(t, "OK", rdb.Set(ctx, "foo", 100, 0).Val()) require.EqualValues(t, 101, rdb.Incr(ctx, "foo").Val()) @@ -114,4 +125,30 @@ func TestHelloWithAuth(t *testing.T) { r = rdb.Do(ctx, "CLIENT", "GETNAME") require.EqualValues(t, r.Val(), "kvrocks") }) + + t.Run("hello with non protocol", func(t *testing.T) { + r := rdb.Do(ctx, "HELLO", "2", "AUTH", "default", "foobar", "SETNAME", "kvrocks") + rList := r.Val().([]interface{}) + require.EqualValues(t, rList[2], "proto") + require.EqualValues(t, rList[3], 2) + + r = rdb.Do(ctx, "CLIENT", "GETNAME") + require.EqualValues(t, r.Val(), "kvrocks") + }) +} + +func TestHelloWithAuthByGoRedis(t *testing.T) { + srv := util.StartServer(t, map[string]string{ + "requirepass": "foobar", + }) + defer srv.Close() + + t.Run("hello with auth sent by go-redis", func(t *testing.T) { + rdb := srv.NewClientWithOption(&redis.Options{ + Password: "foobar", + }) + defer func() { require.NoError(t, rdb.Close()) }() + + require.Equal(t, "PONG", rdb.Ping(context.Background()).Val()) + }) }