Skip to content

Commit

Permalink
Fix crash in zset store getkeys, fix zdiff/bzmpop range, add tests
Browse files Browse the repository at this point in the history
These following cases will crash the server, the reason is that
the index of numkeys is wrong:
```
command getkeys zdiffstore dst 2 src1 src2
command getkeys zinterstore dst 2 src1 src2
command getkeys zunionstore dst 2 src1 src2
```

These following getkeys output is wrong:
```
> command getkeys zdiff 2 key1 key2
1) "key1"
2) "key2"
3) (nil)

> command getkeys bzmpop 0 2 key1 key2
1) "key1"
```

These are ok:
```
command getkeys zinter 2 key1 key2
command getkeys zunion 2 key1 key2
command getkeys sintercard 2 key1 key2
command getkeys zintercard 2 key1 key2
command getkeys zmpop 2 key1 key2
command getkeys lmpop 2 key1 key2
command getkeys blmpop 0 2 key1 key2
```

However, at present, there is still a problem with our zset store.
We do not support returning dst key, but let's do it later...
```
127.0.0.1:6379> command getkeys zinterstore dst 2 src1 src2
1) "dst"
2) "src1"
3) "src2"

127.0.0.1:6666> command getkeys zinterstore dst 2 src1 src2
1) "src1"
2) "src2"
```
  • Loading branch information
enjoy-binbin committed Jan 25, 2024
1 parent 7716813 commit 16f8e96
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/commands/cmd_zset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ class CommandBZMPop : public BlockingCommander {

static CommandKeyRange Range(const std::vector<std::string> &args) {
int num_key = *ParseInt<int>(args[2], 10);
return {3, 1 + num_key, 1};
return {3, 2 + num_key, 1};
}

private:
Expand Down Expand Up @@ -1223,7 +1223,7 @@ class CommandZUnionStore : public Commander {
}

static CommandKeyRange Range(const std::vector<std::string> &args) {
int num_key = *ParseInt<int>(args[1], 10);
int num_key = *ParseInt<int>(args[2], 10);
return {3, 2 + num_key, 1};
}

Expand All @@ -1250,7 +1250,7 @@ class CommandZInterStore : public CommandZUnionStore {
}

static CommandKeyRange Range(const std::vector<std::string> &args) {
int num_key = *ParseInt<int>(args[1], 10);
int num_key = *ParseInt<int>(args[2], 10);
return {3, 2 + num_key, 1};
}
};
Expand Down Expand Up @@ -1464,7 +1464,7 @@ class CommandZDiff : public Commander {

static CommandKeyRange Range(const std::vector<std::string> &args) {
int num_key = *ParseInt<int>(args[1], 10);
return {2, 2 + num_key, 1};
return {2, 1 + num_key, 1};
}

protected:
Expand Down Expand Up @@ -1504,7 +1504,7 @@ class CommandZDiffStore : public Commander {
}

static CommandKeyRange Range(const std::vector<std::string> &args) {
int num_key = *ParseInt<int>(args[1], 10);
int num_key = *ParseInt<int>(args[2], 10);
return {3, 2 + num_key, 1};
}

Expand Down
108 changes: 108 additions & 0 deletions tests/gocase/unit/command/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,112 @@ func TestCommand(t *testing.T) {
require.Len(t, vs, 1)
require.Equal(t, "test", vs[0])
})

t.Run("COMMAND GETKEYS SINTERCARD", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "SINTERCARD", "2", "key1", "key2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "key1", vs[0])
require.Equal(t, "key2", vs[1])
})

t.Run("COMMAND GETKEYS ZINTER", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZINTER", "2", "key1", "key2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "key1", vs[0])
require.Equal(t, "key2", vs[1])
})

t.Run("COMMAND GETKEYS ZINTERSTORE", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZINTERSTORE", "dst", "2", "src1", "src2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "src1", vs[0])
require.Equal(t, "src2", vs[1])
})

t.Run("COMMAND GETKEYS ZINTERCARD", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZINTERCARD", "2", "key1", "key2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "key1", vs[0])
require.Equal(t, "key2", vs[1])
})

t.Run("COMMAND GETKEYS ZUNION", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZUNION", "2", "key1", "key2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "key1", vs[0])
require.Equal(t, "key2", vs[1])
})

t.Run("COMMAND GETKEYS ZUNIONSTORE", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZUNIONSTORE", "dst", "2", "src1", "src2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "src1", vs[0])
require.Equal(t, "src2", vs[1])
})

t.Run("COMMAND GETKEYS ZDIFF", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZDIFF", "2", "key1", "key2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "key1", vs[0])
require.Equal(t, "key2", vs[1])
})

t.Run("COMMAND GETKEYS ZDIFFSTORE", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZDIFFSTORE", "dst", "2", "src1", "src2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "src1", vs[0])
require.Equal(t, "src2", vs[1])
})

t.Run("COMMAND GETKEYS ZMPOP", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZMPOP", "2", "key1", "key2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "key1", vs[0])
require.Equal(t, "key2", vs[1])
})

t.Run("COMMAND GETKEYS BZMPOP", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "BZMPOP", "0", "2", "key1", "key2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "key1", vs[0])
require.Equal(t, "key2", vs[1])
})

t.Run("COMMAND GETKEYS LMPOP", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "LMPOP", "2", "key1", "key2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "key1", vs[0])
require.Equal(t, "key2", vs[1])
})

t.Run("COMMAND GETKEYS BLMPOP", func(t *testing.T) {
r := rdb.Do(ctx, "COMMAND", "GETKEYS", "BLMPOP", "0", "2", "key1", "key2")
vs, err := r.Slice()
require.NoError(t, err)
require.Len(t, vs, 2)
require.Equal(t, "key1", vs[0])
require.Equal(t, "key2", vs[1])
})
}

0 comments on commit 16f8e96

Please sign in to comment.