From 64e7834fea79a81d5de125a8fe3ae8e00dc82ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BA=AA=E5=8D=8E=E8=A3=95?= <8042833@qq.com> Date: Mon, 18 Dec 2023 18:13:48 +0800 Subject: [PATCH] Add the support of KEEPTLE and GET options to SET command(#1935) --- src/commands/cmd_string.cc | 50 +++++--- src/types/redis_string.cc | 116 ++++++++++++------ src/types/redis_string.h | 15 ++- tests/cppunit/types/string_test.cc | 6 +- .../gocase/unit/type/strings/strings_test.go | 88 +++++++++++++ 5 files changed, 212 insertions(+), 63 deletions(-) diff --git a/src/commands/cmd_string.cc b/src/commands/cmd_string.cc index 99783172c42..a0e1a690b5c 100644 --- a/src/commands/cmd_string.cc +++ b/src/commands/cmd_string.cc @@ -20,10 +20,12 @@ #include #include +#include #include "commander.h" #include "commands/command_parser.h" #include "error_constants.h" +#include "server/redis_reply.h" #include "server/server.h" #include "storage/redis_db.h" #include "time_util.h" @@ -131,16 +133,16 @@ class CommandGetSet : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::String string_db(srv->storage, conn->GetNamespace()); - std::string old_value; - auto s = string_db.GetSet(args_[1], args_[2], &old_value); - if (!s.ok() && !s.IsNotFound()) { + std::optional old_value; + auto s = string_db.GetSet(args_[1], args_[2], old_value); + if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } - if (s.IsNotFound()) { - *output = redis::NilString(); + if (old_value.has_value()) { + *output = redis::BulkString(old_value.value()); } else { - *output = redis::BulkString(old_value); + *output = redis::NilString(); } return Status::OK(); } @@ -281,10 +283,14 @@ class CommandSet : public Commander { while (parser.Good()) { if (auto v = GET_OR_RET(ParseTTL(parser, ttl_flag))) { ttl_ = *v; + } else if (parser.EatEqICaseFlag("KEEPTTL", ttl_flag)) { + keep_ttl_ = true; } else if (parser.EatEqICaseFlag("NX", set_flag)) { - set_flag_ = NX; + set_flag_ = StringSetType::NX; } else if (parser.EatEqICaseFlag("XX", set_flag)) { - set_flag_ = XX; + set_flag_ = StringSetType::XX; + } else if (parser.EatEqICase("GET")) { + get_ = true; } else { return parser.InvalidSyntax(); } @@ -294,7 +300,7 @@ class CommandSet : public Commander { } Status Execute(Server *srv, Connection *conn, std::string *output) override { - bool ret = false; + std::optional ret; redis::String string_db(srv->storage, conn->GetNamespace()); if (ttl_ < 0) { @@ -307,29 +313,33 @@ class CommandSet : public Commander { } rocksdb::Status s; - if (set_flag_ == NX) { - s = string_db.SetNX(args_[1], args_[2], ttl_, &ret); - } else if (set_flag_ == XX) { - s = string_db.SetXX(args_[1], args_[2], ttl_, &ret); - } else { - s = string_db.SetEX(args_[1], args_[2], ttl_); - } + s = string_db.Set(args_[1], args_[2], {ttl_, set_flag_, get_, keep_ttl_}, ret); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } - if (set_flag_ != NONE && !ret) { - *output = redis::NilString(); + if (get_) { + if (ret.has_value()) { + *output = redis::BulkString(ret.value()); + } else { + *output = redis::NilString(); + } } else { - *output = redis::SimpleString("OK"); + if (ret.has_value()) { + *output = redis::SimpleString("OK"); + } else { + *output = redis::NilString(); + } } return Status::OK(); } private: uint64_t ttl_ = 0; - enum { NONE, NX, XX } set_flag_ = NONE; + bool get_ = false; + bool keep_ttl_ = false; + StringSetType set_flag_ = StringSetType::NONE; }; class CommandSetEX : public Commander { diff --git a/src/types/redis_string.cc b/src/types/redis_string.cc index 6153fe4eeec..420d3994f9e 100644 --- a/src/types/redis_string.cc +++ b/src/types/redis_string.cc @@ -23,7 +23,6 @@ #include #include #include -#include #include #include @@ -187,20 +186,10 @@ rocksdb::Status String::GetEx(const std::string &user_key, std::string *value, u return rocksdb::Status::OK(); } -rocksdb::Status String::GetSet(const std::string &user_key, const std::string &new_value, std::string *old_value) { - std::string ns_key = AppendNamespacePrefix(user_key); - - LockGuard guard(storage_->GetLockManager(), ns_key); - rocksdb::Status s = getValue(ns_key, old_value); - if (!s.ok() && !s.IsNotFound()) return s; - - std::string raw_value; - Metadata metadata(kRedisString, false); - metadata.Encode(&raw_value); - raw_value.append(new_value); - auto write_status = updateRawValue(ns_key, raw_value); - // prev status was used to tell whether old value was empty or not - return !write_status.ok() ? write_status : s; +rocksdb::Status String::GetSet(const std::string &user_key, const std::string &new_value, + std::optional &old_value) { + auto s = Set(user_key, new_value, {/*ttl=*/0, StringSetType::NONE, /*get=*/true, /*keep_ttl=*/false}, old_value); + return s; } rocksdb::Status String::GetDel(const std::string &user_key, std::string *value) { std::string ns_key = AppendNamespacePrefix(user_key); @@ -217,38 +206,87 @@ rocksdb::Status String::Set(const std::string &user_key, const std::string &valu return MSet(pairs, /*ttl=*/0, /*lock=*/true); } -rocksdb::Status String::SetEX(const std::string &user_key, const std::string &value, uint64_t ttl) { - std::vector pairs{StringPair{user_key, value}}; - return MSet(pairs, /*ttl=*/ttl, /*lock=*/true); -} +rocksdb::Status String::Set(const std::string &user_key, const std::string &value, StringSetArgs args, + std::optional &ret) { + std::string ns_key = AppendNamespacePrefix(user_key); -rocksdb::Status String::SetNX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag) { - std::vector pairs{StringPair{user_key, value}}; - return MSetNX(pairs, ttl, flag); -} + LockGuard guard(storage_->GetLockManager(), ns_key); -rocksdb::Status String::SetXX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag) { - *flag = false; - int exists = 0; + // Get old value for NX/XX/GET/KEEPTTL option + std::string old_raw_value; + auto s = getRawValue(ns_key, &old_raw_value); + if (!s.ok() && !s.IsNotFound() && !s.IsInvalidArgument()) return s; + auto old_key_found = !s.IsNotFound(); + // The reply following Redis doc: https://redis.io/commands/set/ + // Handle GET option + if (args.get) { + if (s.IsInvalidArgument()) { + return s; + } + if (old_key_found) { + // if GET option given: return The previous value of the key. + auto offset = Metadata::GetOffsetAfterExpire(old_raw_value[0]); + ret = std::make_optional(old_raw_value.substr(offset)); + } else { + // if GET option given, the key didn't exist before: return nil + ret = std::nullopt; + } + } + + // Handle NX/XX option + if (old_key_found && args.type == StringSetType::NX) { + // if GET option not given, operation aborted: return nil + if (!args.get) ret = std::nullopt; + return rocksdb::Status::OK(); + } else if (!old_key_found && args.type == StringSetType::XX) { + // if GET option not given, operation aborted: return nil + if (!args.get) ret = std::nullopt; + return rocksdb::Status::OK(); + } else { + // if GET option not given, make ret not nil + if (!args.get) ret = ""; + } + + // Handle expire time uint64_t expire = 0; - if (ttl > 0) { + if (args.ttl > 0) { uint64_t now = util::GetTimeStampMS(); - expire = now + ttl; + expire = now + args.ttl; + } else if (args.keep_ttl && old_key_found) { + Metadata metadata(kRedisString, false); + auto s = metadata.Decode(old_raw_value); + if (!s.ok()) { + return s; + } + expire = metadata.expire; } - std::string ns_key = AppendNamespacePrefix(user_key); - LockGuard guard(storage_->GetLockManager(), ns_key); - auto s = Exists({user_key}, &exists); - if (!s.ok()) return s; - if (exists != 1) return rocksdb::Status::OK(); - - *flag = true; - std::string raw_value; + // Create new value + std::string new_raw_value; Metadata metadata(kRedisString, false); metadata.expire = expire; - metadata.Encode(&raw_value); - raw_value.append(value); - return updateRawValue(ns_key, raw_value); + metadata.Encode(&new_raw_value); + new_raw_value.append(value); + return updateRawValue(ns_key, new_raw_value); +} + +rocksdb::Status String::SetEX(const std::string &user_key, const std::string &value, uint64_t ttl) { + std::optional ret; + return Set(user_key, value, {ttl, StringSetType::NONE, /*get=*/false, /*keep_ttl=*/false}, ret); +} + +rocksdb::Status String::SetNX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag) { + std::optional ret; + auto s = Set(user_key, value, {ttl, StringSetType::NX, /*get=*/false, /*keep_ttl=*/false}, ret); + *flag = ret.has_value(); + return s; +} + +rocksdb::Status String::SetXX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag) { + std::optional ret; + auto s = Set(user_key, value, {ttl, StringSetType::XX, /*get=*/false, /*keep_ttl=*/false}, ret); + *flag = ret.has_value(); + return s; } rocksdb::Status String::SetRange(const std::string &user_key, size_t offset, const std::string &value, diff --git a/src/types/redis_string.h b/src/types/redis_string.h index 41be5bddd0e..bfb4ef99005 100644 --- a/src/types/redis_string.h +++ b/src/types/redis_string.h @@ -21,6 +21,7 @@ #pragma once #include +#include #include #include @@ -32,6 +33,15 @@ struct StringPair { Slice value; }; +enum class StringSetType { NONE, NX, XX }; + +struct StringSetArgs { + uint64_t ttl; + StringSetType type; + bool get; + bool keep_ttl; +}; + namespace redis { class String : public Database { @@ -40,9 +50,12 @@ class String : public Database { rocksdb::Status Append(const std::string &user_key, const std::string &value, uint64_t *new_size); rocksdb::Status Get(const std::string &user_key, std::string *value); rocksdb::Status GetEx(const std::string &user_key, std::string *value, uint64_t ttl, bool persist); - rocksdb::Status GetSet(const std::string &user_key, const std::string &new_value, std::string *old_value); + rocksdb::Status GetSet(const std::string &user_key, const std::string &new_value, + std::optional &old_value); rocksdb::Status GetDel(const std::string &user_key, std::string *value); rocksdb::Status Set(const std::string &user_key, const std::string &value); + rocksdb::Status Set(const std::string &user_key, const std::string &value, StringSetArgs args, + std::optional &ret); rocksdb::Status SetEX(const std::string &user_key, const std::string &value, uint64_t ttl); rocksdb::Status SetNX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag); rocksdb::Status SetXX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag); diff --git a/tests/cppunit/types/string_test.cc b/tests/cppunit/types/string_test.cc index fe916adc5ca..631b4ec8114 100644 --- a/tests/cppunit/types/string_test.cc +++ b/tests/cppunit/types/string_test.cc @@ -142,15 +142,15 @@ TEST_F(RedisStringTest, GetSet) { rocksdb::Env::Default()->GetCurrentTime(&now); std::vector values = {"a", "b", "c", "d"}; for (size_t i = 0; i < values.size(); i++) { - std::string old_value; + std::optional old_value; auto s = string_->Expire(key_, now * 1000 + 100000); - string_->GetSet(key_, values[i], &old_value); + string_->GetSet(key_, values[i], old_value); if (i != 0) { EXPECT_EQ(values[i - 1], old_value); auto s = string_->TTL(key_, &ttl); EXPECT_TRUE(ttl == -1); } else { - EXPECT_TRUE(old_value.empty()); + EXPECT_TRUE(!old_value.has_value()); } } auto s = string_->Del(key_); diff --git a/tests/gocase/unit/type/strings/strings_test.go b/tests/gocase/unit/type/strings/strings_test.go index f255228d18f..fc799fc5cd8 100644 --- a/tests/gocase/unit/type/strings/strings_test.go +++ b/tests/gocase/unit/type/strings/strings_test.go @@ -678,6 +678,94 @@ func TestString(t *testing.T) { util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second) }) + t.Run("Extended SET KEEPTTL and EX/PX/EXAT/PXAT option", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "foo").Err()) + require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl", "ex", "100").Err()) + require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl", "px", "100").Err()) + require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl", "exat", "100").Err()) + require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl", "pxat", "100").Err()) + }) + + t.Run("Extended SET KEEPTTL WITH option", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "foo").Err()) + require.Equal(t, "OK", rdb.SetArgs(ctx, "foo", "xx", redis.SetArgs{KeepTTL: true}).Val()) + ttl := rdb.TTL(ctx, "foo").Val() + require.Equal(t, time.Duration(-1), ttl) + require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar", 10*time.Second).Val()) + require.Equal(t, "OK", rdb.SetArgs(ctx, "foo", "xx", redis.SetArgs{KeepTTL: true}).Val()) + ttl = rdb.TTL(ctx, "foo").Val() + util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second) + }) + + t.Run("Extended SET GET option", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "foo").Err()) + require.Equal(t, "", rdb.SetArgs(ctx, "foo", "bar", redis.SetArgs{Get: true}).Val()) + require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx", redis.SetArgs{Get: true}).Val()) + require.Equal(t, "xx", rdb.Get(ctx, "foo").Val()) + }) + + t.Run("Extended SET GET and NX option", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "foo").Err()) + require.Equal(t, "", rdb.SetArgs(ctx, "foo", "xx", redis.SetArgs{Get: true, Mode: "NX"}).Val()) + require.Equal(t, "xx", rdb.Get(ctx, "foo").Val()) + require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar", 0).Val()) + require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx", redis.SetArgs{Get: true, Mode: "NX"}).Val()) + require.Equal(t, "bar", rdb.Get(ctx, "foo").Val()) + }) + + t.Run("Extended SET GET and XX option", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "foo").Err()) + require.Equal(t, "", rdb.SetArgs(ctx, "foo", "xx", redis.SetArgs{Get: true, Mode: "XX"}).Val()) + require.Equal(t, "", rdb.Get(ctx, "foo").Val()) + require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar", 0).Val()) + require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx", redis.SetArgs{Get: true, Mode: "XX"}).Val()) + require.Equal(t, "xx", rdb.Get(ctx, "foo").Val()) + }) + + t.Run("Extended SET GET and KEEPTTL option", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "foo").Err()) + require.Equal(t, "", rdb.SetArgs(ctx, "foo", "xx", redis.SetArgs{Get: true, KeepTTL: true}).Val()) + ttl := rdb.TTL(ctx, "foo").Val() + require.Equal(t, time.Duration(-1), ttl) + require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar", 10*time.Second).Val()) + require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx", redis.SetArgs{Get: true, KeepTTL: true}).Val()) + ttl = rdb.TTL(ctx, "foo").Val() + util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second) + }) + + t.Run("Extended SET GET and EX option", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "foo").Err()) + require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "ex", "10", "get").Val()) + ttl := rdb.TTL(ctx, "foo").Val() + util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second) + }) + + t.Run("Extended SET GET and PX option", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "foo").Err()) + require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "px", "10000", "get").Val()) + ttl := rdb.TTL(ctx, "foo").Val() + util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second) + }) + + t.Run("Extended SET GET and EXAT option", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "foo").Err()) + + expireAt := strconv.FormatInt(time.Now().Add(10*time.Second).Unix(), 10) + require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "exat", expireAt, "get").Val()) + ttl := rdb.TTL(ctx, "foo").Val() + util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second) + }) + + t.Run("Extended SET GET and PXAT option", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "foo").Err()) + + expireAt := strconv.FormatInt(time.Now().Add(10*time.Second).UnixMilli(), 10) + require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "pxat", expireAt, "get").Val()) + + ttl := rdb.TTL(ctx, "foo").Val() + util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second) + }) + t.Run("GETRANGE with huge ranges, Github issue redis/redis#1844", func(t *testing.T) { require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err()) require.Equal(t, "bar", rdb.GetRange(ctx, "foo", 0, 2094967291).Val())