diff --git a/src/commands/cmd_string.cc b/src/commands/cmd_string.cc index 9832716b6b8..3bf99559382 100644 --- a/src/commands/cmd_string.cc +++ b/src/commands/cmd_string.cc @@ -613,7 +613,7 @@ REDIS_REGISTER_COMMANDS( MakeCmdAttr("append", 3, "write", 1, 1, 1), MakeCmdAttr("set", -3, "write", 1, 1, 1), MakeCmdAttr("setex", 4, "write", 1, 1, 1), MakeCmdAttr("psetex", 4, "write", 1, 1, 1), MakeCmdAttr("setnx", 3, "write", 1, 1, 1), - MakeCmdAttr("msetnx", -3, "write exclusive", 1, -1, 2), + MakeCmdAttr("msetnx", -3, "write", 1, -1, 2), MakeCmdAttr("mset", -3, "write", 1, -1, 2), MakeCmdAttr("incrby", 3, "write", 1, 1, 1), MakeCmdAttr("incrbyfloat", 3, "write", 1, 1, 1), MakeCmdAttr("incr", 2, "write", 1, 1, 1), MakeCmdAttr("decrby", 3, "write", 1, 1, 1), diff --git a/src/types/redis_string.cc b/src/types/redis_string.cc index d1b5b958f1c..e2f1ea6fc93 100644 --- a/src/types/redis_string.cc +++ b/src/types/redis_string.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include "parse_util.h" @@ -213,12 +214,12 @@ rocksdb::Status String::GetDel(const std::string &user_key, std::string *value) rocksdb::Status String::Set(const std::string &user_key, const std::string &value) { std::vector pairs{StringPair{user_key, value}}; - return MSet(pairs, 0); + return MSet(pairs, /*ttl=*/0, /*lock=*/true); } rocksdb::Status String::SetEX(const std::string &user_key, const std::string &value, uint64_t ttl) { std::vector pairs{StringPair{user_key, value}}; - return MSet(pairs, ttl); + return MSet(pairs, /*ttl=*/ttl, /*lock=*/true); } rocksdb::Status String::SetNX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag) { @@ -363,7 +364,7 @@ rocksdb::Status String::IncrByFloat(const std::string &user_key, double incremen return updateRawValue(ns_key, raw_value); } -rocksdb::Status String::MSet(const std::vector &pairs, uint64_t ttl) { +rocksdb::Status String::MSet(const std::vector &pairs, uint64_t ttl, bool lock) { uint64_t expire = 0; if (ttl > 0) { uint64_t now = util::GetTimeStampMS(); @@ -384,7 +385,10 @@ rocksdb::Status String::MSet(const std::vector &pairs, uint64_t ttl) batch->PutLogData(log_data.Encode()); AppendNamespacePrefix(pair.key, &ns_key); batch->Put(metadata_cf_handle_, ns_key, bytes); - LockGuard guard(storage_->GetLockManager(), ns_key); + std::optional guard; + if (lock) { + guard.emplace(storage_->GetLockManager(), ns_key); + } auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); if (!s.ok()) return s; } @@ -394,41 +398,29 @@ rocksdb::Status String::MSet(const std::vector &pairs, uint64_t ttl) rocksdb::Status String::MSetNX(const std::vector &pairs, uint64_t ttl, bool *flag) { *flag = false; - uint64_t expire = 0; - if (ttl > 0) { - uint64_t now = util::GetTimeStampMS(); - expire = now + ttl; - } - int exists = 0; + std::string ns_key; + std::vector lock_keys; + lock_keys.reserve(pairs.size()); std::vector keys; keys.reserve(pairs.size()); + for (StringPair pair : pairs) { + AppendNamespacePrefix(pair.key, &ns_key); + lock_keys.emplace_back(ns_key); keys.emplace_back(pair.key); } + + // Lock these keys before doing anything. + MultiLockGuard guard(storage_->GetLockManager(), lock_keys); + if (Exists(keys, &exists).ok() && exists > 0) { return rocksdb::Status::OK(); } - std::string ns_key; - for (StringPair pair : pairs) { - AppendNamespacePrefix(pair.key, &ns_key); - LockGuard guard(storage_->GetLockManager(), ns_key); - if (Exists({pair.key}, &exists).ok() && exists == 1) { - return rocksdb::Status::OK(); - } - std::string bytes; - Metadata metadata(kRedisString, false); - metadata.expire = expire; - metadata.Encode(&bytes); - bytes.append(pair.value.data(), pair.value.size()); - auto batch = storage_->GetWriteBatchBase(); - WriteBatchLogData log_data(kRedisString); - batch->PutLogData(log_data.Encode()); - batch->Put(metadata_cf_handle_, ns_key, bytes); - auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); - if (!s.ok()) return s; - } + rocksdb::Status s = MSet(pairs, /*ttl=*/ttl, /*lock=*/false); + if (!s.ok()) return s; + *flag = true; return rocksdb::Status::OK(); } diff --git a/src/types/redis_string.h b/src/types/redis_string.h index 32ee7a9d26c..41be5bddd0e 100644 --- a/src/types/redis_string.h +++ b/src/types/redis_string.h @@ -50,7 +50,7 @@ class String : public Database { rocksdb::Status IncrBy(const std::string &user_key, int64_t increment, int64_t *new_value); rocksdb::Status IncrByFloat(const std::string &user_key, double increment, double *new_value); std::vector MGet(const std::vector &keys, std::vector *values); - rocksdb::Status MSet(const std::vector &pairs, uint64_t ttl = 0); + rocksdb::Status MSet(const std::vector &pairs, uint64_t ttl = 0, bool lock = true); rocksdb::Status MSetNX(const std::vector &pairs, uint64_t ttl, bool *flag); rocksdb::Status CAS(const std::string &user_key, const std::string &old_value, const std::string &new_value, uint64_t ttl, int *flag); diff --git a/tests/gocase/unit/type/strings/strings_test.go b/tests/gocase/unit/type/strings/strings_test.go index 86981912700..1d5cea1b46e 100644 --- a/tests/gocase/unit/type/strings/strings_test.go +++ b/tests/gocase/unit/type/strings/strings_test.go @@ -324,6 +324,21 @@ func TestString(t *testing.T) { require.Equal(t, "yyy", rdb.Get(ctx, "y2").Val()) }) + t.Run("MSETNX with already existent key - same key", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "x").Err()) + require.NoError(t, rdb.Set(ctx, "x", "v0", 0).Err()) + require.Equal(t, int64(0), rdb.Do(ctx, "MSETNX", "x", "v1", "x", "v2").Val()) + require.EqualValues(t, 1, rdb.Exists(ctx, "x").Val()) + require.Equal(t, "v0", rdb.Get(ctx, "x").Val()) + }) + + t.Run("MSETNX with not existing keys - same key", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "x").Err()) + require.Equal(t, int64(1), rdb.Do(ctx, "MSETNX", "x", "v1", "x", "v2").Val()) + require.EqualValues(t, 1, rdb.Exists(ctx, "x").Val()) + require.Equal(t, "v2", rdb.Get(ctx, "x").Val()) + }) + t.Run("STRLEN against non-existing key", func(t *testing.T) { require.EqualValues(t, 0, rdb.StrLen(ctx, "notakey").Val()) })