Skip to content

Commit

Permalink
Fix MSETNX not allow overriding the same key
Browse files Browse the repository at this point in the history
  • Loading branch information
enjoy-binbin committed Aug 3, 2023
1 parent 72cc956 commit 3860daa
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/commands/cmd_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandAppend>("append", 3, "write", 1, 1, 1), MakeCmdAttr<CommandSet>("set", -3, "write", 1, 1, 1),
MakeCmdAttr<CommandSetEX>("setex", 4, "write", 1, 1, 1), MakeCmdAttr<CommandPSetEX>("psetex", 4, "write", 1, 1, 1),
MakeCmdAttr<CommandSetNX>("setnx", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandMSetNX>("msetnx", -3, "write exclusive", 1, -1, 2),
MakeCmdAttr<CommandMSetNX>("msetnx", -3, "write", 1, -1, 2),
MakeCmdAttr<CommandMSet>("mset", -3, "write", 1, -1, 2), MakeCmdAttr<CommandIncrBy>("incrby", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandIncrByFloat>("incrbyfloat", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandIncr>("incr", 2, "write", 1, 1, 1), MakeCmdAttr<CommandDecrBy>("decrby", 3, "write", 1, 1, 1),
Expand Down
45 changes: 16 additions & 29 deletions src/types/redis_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,12 @@ rocksdb::Status String::GetDel(const std::string &user_key, std::string *value)

rocksdb::Status String::Set(const std::string &user_key, const std::string &value) {
std::vector<StringPair> pairs{StringPair{user_key, value}};
return MSet(pairs, 0);
return MSet(pairs, 0, true);
}

rocksdb::Status String::SetEX(const std::string &user_key, const std::string &value, uint64_t ttl) {
std::vector<StringPair> pairs{StringPair{user_key, value}};
return MSet(pairs, ttl);
return MSet(pairs, ttl, true);
}

rocksdb::Status String::SetNX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag) {
Expand Down Expand Up @@ -363,7 +363,7 @@ rocksdb::Status String::IncrByFloat(const std::string &user_key, double incremen
return updateRawValue(ns_key, raw_value);
}

rocksdb::Status String::MSet(const std::vector<StringPair> &pairs, uint64_t ttl) {
rocksdb::Status String::MSet(const std::vector<StringPair> &pairs, uint64_t ttl, bool lock) {
uint64_t expire = 0;
if (ttl > 0) {
uint64_t now = util::GetTimeStampMS();
Expand All @@ -384,7 +384,7 @@ rocksdb::Status String::MSet(const std::vector<StringPair> &pairs, uint64_t ttl)
batch->PutLogData(log_data.Encode());
AppendNamespacePrefix(pair.key, &ns_key);
batch->Put(metadata_cf_handle_, ns_key, bytes);
LockGuard guard(storage_->GetLockManager(), ns_key);
if (lock) LockGuard guard(storage_->GetLockManager(), ns_key);
auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
if (!s.ok()) return s;
}
Expand All @@ -394,41 +394,28 @@ rocksdb::Status String::MSet(const std::vector<StringPair> &pairs, uint64_t ttl)
rocksdb::Status String::MSetNX(const std::vector<StringPair> &pairs, uint64_t ttl, bool *flag) {
*flag = false;

uint64_t expire = 0;
if (ttl > 0) {
uint64_t now = util::GetTimeStampMS();
expire = now + ttl;
}

int exists = 0;
std::string ns_key;
std::vector<std::string> lock_keys;
std::vector<Slice> keys;
keys.reserve(pairs.size());

for (StringPair pair : pairs) {
AppendNamespacePrefix(pair.key, &ns_key);
lock_keys.emplace_back(ns_key);
keys.emplace_back(pair.key);
}

// Lock these keys before doing anything.
MultiLockGuard guard(storage_->GetLockManager(), lock_keys);

if (Exists(keys, &exists).ok() && exists > 0) {
return rocksdb::Status::OK();
}

std::string ns_key;
for (StringPair pair : pairs) {
AppendNamespacePrefix(pair.key, &ns_key);
LockGuard guard(storage_->GetLockManager(), ns_key);
if (Exists({pair.key}, &exists).ok() && exists == 1) {
return rocksdb::Status::OK();
}
std::string bytes;
Metadata metadata(kRedisString, false);
metadata.expire = expire;
metadata.Encode(&bytes);
bytes.append(pair.value.data(), pair.value.size());
auto batch = storage_->GetWriteBatchBase();
WriteBatchLogData log_data(kRedisString);
batch->PutLogData(log_data.Encode());
batch->Put(metadata_cf_handle_, ns_key, bytes);
auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
if (!s.ok()) return s;
}
rocksdb::Status s = MSet(pairs, ttl, false);
if (!s.ok()) return s;

*flag = true;
return rocksdb::Status::OK();
}
Expand Down
2 changes: 1 addition & 1 deletion src/types/redis_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class String : public Database {
rocksdb::Status IncrBy(const std::string &user_key, int64_t increment, int64_t *new_value);
rocksdb::Status IncrByFloat(const std::string &user_key, double increment, double *new_value);
std::vector<rocksdb::Status> MGet(const std::vector<Slice> &keys, std::vector<std::string> *values);
rocksdb::Status MSet(const std::vector<StringPair> &pairs, uint64_t ttl = 0);
rocksdb::Status MSet(const std::vector<StringPair> &pairs, uint64_t ttl = 0, bool lock = true);
rocksdb::Status MSetNX(const std::vector<StringPair> &pairs, uint64_t ttl, bool *flag);
rocksdb::Status CAS(const std::string &user_key, const std::string &old_value, const std::string &new_value,
uint64_t ttl, int *flag);
Expand Down
15 changes: 15 additions & 0 deletions tests/gocase/unit/type/strings/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,21 @@ func TestString(t *testing.T) {
require.Equal(t, "yyy", rdb.Get(ctx, "y2").Val())
})

t.Run("MSETNX with already existent key - same key", func(t *testing.T) {
require.NoError(t, rdb.Del(ctx, "x").Err())
require.NoError(t, rdb.Set(ctx, "x", "v0", 0).Err())
require.Equal(t, int64(0), rdb.Do(ctx, "MSETNX", "x", "v1", "x", "v2").Val())
require.EqualValues(t, 1, rdb.Exists(ctx, "x").Val())
require.Equal(t, "v0", rdb.Get(ctx, "x").Val())
})

t.Run("MSETNX with not existing keys - same key", func(t *testing.T) {
require.NoError(t, rdb.Del(ctx, "x").Err())
require.Equal(t, int64(1), rdb.Do(ctx, "MSETNX", "x", "v1", "x", "v2").Val())
require.EqualValues(t, 1, rdb.Exists(ctx, "x").Val())
require.Equal(t, "v2", rdb.Get(ctx, "x").Val())
})

t.Run("STRLEN against non-existing key", func(t *testing.T) {
require.EqualValues(t, 0, rdb.StrLen(ctx, "notakey").Val())
})
Expand Down

0 comments on commit 3860daa

Please sign in to comment.