diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc index acddad82ddf..05cdde2eb43 100644 --- a/src/commands/cmd_zset.cc +++ b/src/commands/cmd_zset.cc @@ -1419,6 +1419,99 @@ class CommandZRandMember : public Commander { bool no_parameters_ = true; }; +class CommandZDiff : public Commander { + public: + Status Parse(const std::vector &args) override { + auto parse_result = ParseInt(args[1], 10); + if (!parse_result) return {Status::RedisParseErr, errValueNotInteger}; + + numkeys_ = *parse_result; + if (numkeys_ > args.size() - 2) return {Status::RedisParseErr, errInvalidSyntax}; + + size_t j = 0; + while (j < numkeys_) { + keys_.emplace_back(args[j + 2]); + j++; + } + + if (auto i = 2 + numkeys_; i < args.size()) { + if (util::ToLower(args[i]) == "withscores") { + with_scores_ = true; + } + } + + return Commander::Parse(args); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::ZSet zset_db(srv->storage, conn->GetNamespace()); + + std::vector members_with_scores; + auto s = zset_db.Diff(keys_, &members_with_scores); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + output->append(redis::MultiLen(members_with_scores.size() * (with_scores_ ? 2 : 1))); + for (const auto &ms : members_with_scores) { + output->append(redis::BulkString(ms.member)); + if (with_scores_) output->append(redis::BulkString(util::Float2String(ms.score))); + } + + return Status::OK(); + } + + static CommandKeyRange Range(const std::vector &args) { + int num_key = *ParseInt(args[1], 10); + return {2, 2 + num_key, 1}; + } + + protected: + size_t numkeys_ = 0; + std::vector keys_; + bool with_scores_ = false; +}; + +class CommandZDiffStore : public Commander { + public: + Status Parse(const std::vector &args) override { + auto parse_result = ParseInt(args[2], 10); + if (!parse_result) return {Status::RedisParseErr, errValueNotInteger}; + + numkeys_ = *parse_result; + if (numkeys_ > args.size() - 3) return {Status::RedisParseErr, errInvalidSyntax}; + + size_t j = 0; + while (j < numkeys_) { + keys_.emplace_back(args[j + 3]); + j++; + } + + return Commander::Parse(args); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::ZSet zset_db(srv->storage, conn->GetNamespace()); + + uint64_t stored_count = 0; + auto s = zset_db.DiffStore(args_[1], keys_, &stored_count); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + *output = redis::Integer(stored_count); + return Status::OK(); + } + + static CommandKeyRange Range(const std::vector &args) { + int num_key = *ParseInt(args[1], 10); + return {3, 2 + num_key, 1}; + } + + protected: + size_t numkeys_ = 0; + std::vector keys_; +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zcard", 2, "read-only", 1, 1, 1), MakeCmdAttr("zcount", 4, "read-only", 1, 1, 1), @@ -1451,6 +1544,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zscan", -3, "read-only", 1, 1, 1), MakeCmdAttr("zunionstore", -4, "write", CommandZUnionStore::Range), MakeCmdAttr("zunion", -3, "read-only", CommandZUnion::Range), - MakeCmdAttr("zrandmember", -2, "read-only", 1, 1, 1)) + MakeCmdAttr("zrandmember", -2, "read-only", 1, 1, 1), + MakeCmdAttr("zdiff", -3, "read-only", CommandZDiff::Range), + MakeCmdAttr("zdiffstore", -3, "read-only", CommandZDiffStore::Range), ) } // namespace redis diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc index 9215d6211f5..7532a5f35c7 100644 --- a/src/types/redis_zset.cc +++ b/src/types/redis_zset.cc @@ -931,4 +931,42 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key, int64_t command_count, return rocksdb::Status::OK(); } +rocksdb::Status ZSet::Diff(const std::vector &keys, MemberScores *members) { + members->clear(); + MemberScores source_member_scores; + RangeScoreSpec spec; + uint64_t size = 0; + auto s = RangeByScore(keys[0], spec, &source_member_scores, &size); + if (!s.ok()) return s; + + if (size == 0) { + return rocksdb::Status::OK(); + } + + std::map exclude_members; + MemberScores target_member_scores; + for (size_t i = 1; i < keys.size(); i++) { + uint64_t size = 0; + s = RangeByScore(keys[i], spec, &target_member_scores, &size); + if (!s.ok()) return s; + for (const auto &member_score : target_member_scores) { + exclude_members[member_score.member] = true; + } + } + for (const auto &member_score : source_member_scores) { + if (exclude_members.find(member_score.member) == exclude_members.end()) { + members->push_back(member_score); + } + } + return rocksdb::Status::OK(); +} + +rocksdb::Status ZSet::DiffStore(const Slice &dst, const std::vector &keys, uint64_t *stored_count) { + MemberScores mscores; + auto s = Diff(keys, &mscores); + if (!s.ok()) return s; + *stored_count = mscores.size(); + return Overwrite(dst, mscores); +} + } // namespace redis diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h index 397ca10b126..d806d57e3cf 100644 --- a/src/types/redis_zset.h +++ b/src/types/redis_zset.h @@ -116,6 +116,8 @@ class ZSet : public SubKeyScanner { AggregateMethod aggregate_method, uint64_t *saved_cnt); rocksdb::Status Union(const std::vector &keys_weights, AggregateMethod aggregate_method, std::vector *members); + rocksdb::Status Diff(const std::vector &keys, MemberScores *members); + rocksdb::Status DiffStore(const Slice &dst, const std::vector &keys, uint64_t *stored_count); rocksdb::Status MGet(const Slice &user_key, const std::vector &members, std::map *scores); rocksdb::Status GetMetadata(const Slice &ns_key, ZSetMetadata *metadata); diff --git a/tests/cppunit/types/zset_test.cc b/tests/cppunit/types/zset_test.cc index 34c71d78c2f..da2ce71469c 100644 --- a/tests/cppunit/types/zset_test.cc +++ b/tests/cppunit/types/zset_test.cc @@ -535,3 +535,81 @@ TEST_F(RedisZSetTest, RandMember) { auto s = zset_->Del(key_); EXPECT_TRUE(s.ok()); } + +TEST_F(RedisZSetTest, Diff) { + uint64_t ret = 0; + + std::string k1 = "key1"; + std::vector k1_mscores = {{"a", -100.1}, {"b", -100.1}, {"c", 0}, {"d", 1.234}}; + + std::string k2 = "key2"; + std::vector k2_mscores = {{"c", -150.1}}; + + std::string k3 = "key3"; + std::vector k3_mscores = {{"a", -1000.1}, {"c", -100.1}, {"e", 8000.9}}; + + auto s = zset_->Add(k1, ZAddFlags::Default(), &k1_mscores, &ret); + EXPECT_EQ(ret, 4); + zset_->Add(k2, ZAddFlags::Default(), &k2_mscores, &ret); + EXPECT_EQ(ret, 1); + zset_->Add(k3, ZAddFlags::Default(), &k3_mscores, &ret); + EXPECT_EQ(ret, 3); + + std::vector mscores; + zset_->Diff({k1, k2, k3}, &mscores); + + EXPECT_EQ(2, mscores.size()); + std::vector expected_mscores = {{"b", -100.1}, {"d", 1.234}}; + int index = 0; + for (const auto &mscore : expected_mscores) { + EXPECT_EQ(mscore.member, mscores[index].member); + EXPECT_EQ(mscore.score, mscores[index].score); + index++; + } + + s = zset_->Del(k1); + EXPECT_TRUE(s.ok()); + s = zset_->Del(k2); + EXPECT_TRUE(s.ok()); + s = zset_->Del(k3); + EXPECT_TRUE(s.ok()); +} + +TEST_F(RedisZSetTest, DiffStore) { + uint64_t ret = 0; + + std::string k1 = "key1"; + std::vector k1_mscores = {{"a", -100.1}, {"b", -100.1}, {"c", 0}, {"d", 1.234}}; + + std::string k2 = "key2"; + std::vector k2_mscores = {{"c", -150.1}}; + + auto s = zset_->Add(k1, ZAddFlags::Default(), &k1_mscores, &ret); + EXPECT_EQ(ret, 4); + zset_->Add(k2, ZAddFlags::Default(), &k2_mscores, &ret); + EXPECT_EQ(ret, 1); + + uint64_t stored_count = 0; + zset_->DiffStore("zsetdiff", {k1, k2}, &stored_count); + EXPECT_EQ(stored_count, 3); + + RangeScoreSpec spec; + std::vector mscores; + zset_->RangeByScore("zsetdiff", spec, &mscores, nullptr); + EXPECT_EQ(mscores.size(), 3); + + std::vector expected_mscores = {{"a", -100.1}, {"b", -100.1}, {"d", 1.234}}; + int index = 0; + for (const auto &mscore : expected_mscores) { + EXPECT_EQ(mscore.member, mscores[index].member); + EXPECT_EQ(mscore.score, mscores[index].score); + index++; + } + + s = zset_->Del(k1); + EXPECT_TRUE(s.ok()); + s = zset_->Del(k2); + EXPECT_TRUE(s.ok()); + s = zset_->Del("zsetdiff"); + EXPECT_TRUE(s.ok()); +} diff --git a/tests/gocase/unit/type/zset/zset_test.go b/tests/gocase/unit/type/zset/zset_test.go index 86adceda403..d7bc434e924 100644 --- a/tests/gocase/unit/type/zset/zset_test.go +++ b/tests/gocase/unit/type/zset/zset_test.go @@ -1463,6 +1463,167 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s ).Err(), ".*weight.*not.*double.*") }) } + + t.Run(fmt.Sprintf("ZDIFF with two sets - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + cmd := rdb.ZDiff(ctx, "zseta", "zsetb") + require.NoError(t, cmd.Err()) + sort.Strings(cmd.Val()) + require.EqualValues(t, []string{"a", "d", "e"}, cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with three sets - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 5, Member: "e"}, + }) + cmd := rdb.ZDiff(ctx, "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + sort.Strings(cmd.Val()) + require.EqualValues(t, []string{"a"}, cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with three sets with scores - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 4, Member: "c"}, + {Score: 5, Member: "e"}, + }) + cmd := rdb.ZDiffWithScores(ctx, "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with empty sets - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{}) + createZset(rdb, ctx, "zsetb", []redis.Z{}) + cmd := rdb.ZDiff(ctx, "zseta", "zsetb") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []string([]string{}), cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with non existing sets - %s", encoding), func(t *testing.T) { + rdb.Del(ctx, "zseta") + rdb.Del(ctx, "zsetb") + cmd := rdb.ZDiff(ctx, "zseta", "zsetb") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []string([]string{}), cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with missing set with scores - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + rdb.Del(ctx, "zsetc") + cmd := rdb.ZDiffWithScores(ctx, "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with empty sets with scores - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{}) + createZset(rdb, ctx, "zsetb", []redis.Z{}) + cmd := rdb.ZDiffWithScores(ctx, "zseta", "zsetb") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []redis.Z([]redis.Z{}), cmd.Val()) + }) + + t.Run("ZDIFFSTORE with three sets - ", func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 4, Member: "c"}, + {Score: 5, Member: "e"}, + }) + cmd := rdb.ZDiffStore(ctx, "setres", "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, int64(2), cmd.Val()) + require.Equal(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), rdb.ZRangeWithScores(ctx, "setres", 0, -1).Val()) + }) + + t.Run("ZDIFFSTORE with missing sets - ", func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + {Score: 4, Member: "e"}, + }) + rdb.Del(ctx, "zsetc") + cmd := rdb.ZDiffStore(ctx, "setres", "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, int64(2), cmd.Val()) + require.Equal(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), rdb.ZRangeWithScores(ctx, "setres", 0, -1).Val()) + }) + + t.Run("ZDIFFSTORE with missing sets - ", func(t *testing.T) { + rdb.Del(ctx, "zseta") + rdb.Del(ctx, "zsetb") + rdb.Del(ctx, "zsetc") + cmd := rdb.ZDiffStore(ctx, "setres", "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, int64(0), cmd.Val()) + require.Equal(t, []redis.Z([]redis.Z{}), rdb.ZRangeWithScores(ctx, "setres", 0, -1).Val()) + }) } func stressTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding string) {