diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc index 1fa51ab227c..42e1740d861 100644 --- a/src/commands/cmd_zset.cc +++ b/src/commands/cmd_zset.cc @@ -1357,6 +1357,99 @@ class CommandZScan : public CommandSubkeyScanBase { } }; +class CommandZDiff : public Commander { + public: + Status Parse(const std::vector &args) override { + auto parse_result = ParseInt(args[1], 10); + if (!parse_result) return {Status::RedisParseErr, errValueNotInteger}; + + numkeys_ = *parse_result; + if (numkeys_ > args.size() - 2) return {Status::RedisParseErr, errInvalidSyntax}; + + size_t j = 0; + while (j < numkeys_) { + keys_.emplace_back(args[j + 2]); + j++; + } + + if (auto i = 2 + numkeys_; i < args.size()) { + if (util::ToLower(args[i]) == "withscores") { + with_scores_ = true; + } + } + + return Commander::Parse(args); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::ZSet zset_db(srv->storage, conn->GetNamespace()); + + std::vector members_with_scores; + auto s = zset_db.Diff(keys_, &members_with_scores); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + output->append(redis::MultiLen(members_with_scores.size() * (with_scores_ ? 2 : 1))); + for (const auto &ms : members_with_scores) { + output->append(redis::BulkString(ms.member)); + if (with_scores_) output->append(redis::BulkString(util::Float2String(ms.score))); + } + + return Status::OK(); + } + + static CommandKeyRange Range(const std::vector &args) { + int num_key = *ParseInt(args[1], 10); + return {2, 2 + num_key, 1}; + } + + protected: + size_t numkeys_ = 0; + std::vector keys_; + bool with_scores_ = false; +}; + +class CommandZDiffStore : public Commander { + public: + Status Parse(const std::vector &args) override { + auto parse_result = ParseInt(args[2], 10); + if (!parse_result) return {Status::RedisParseErr, errValueNotInteger}; + + numkeys_ = *parse_result; + if (numkeys_ > args.size() - 3) return {Status::RedisParseErr, errInvalidSyntax}; + + size_t j = 0; + while (j < numkeys_) { + keys_.emplace_back(args[j + 3]); + j++; + } + + return Commander::Parse(args); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::ZSet zset_db(srv->storage, conn->GetNamespace()); + + uint64_t stored_count = 0; + auto s = zset_db.DiffStore(args_[1], keys_, &stored_count); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + *output = redis::Integer(stored_count); + return Status::OK(); + } + + static CommandKeyRange Range(const std::vector &args) { + int num_key = *ParseInt(args[1], 10); + return {3, 2 + num_key, 1}; + } + + protected: + size_t numkeys_ = 0; + std::vector keys_; +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zcard", 2, "read-only", 1, 1, 1), MakeCmdAttr("zcount", 4, "read-only", 1, 1, 1), @@ -1388,6 +1481,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zmscore", -3, "read-only", 1, 1, 1), MakeCmdAttr("zscan", -3, "read-only", 1, 1, 1), MakeCmdAttr("zunionstore", -4, "write", CommandZUnionStore::Range), - MakeCmdAttr("zunion", -3, "read-only", CommandZUnion::Range), ) + MakeCmdAttr("zunion", -3, "read-only", CommandZUnion::Range), + MakeCmdAttr("zdiff", -3, "read-only", CommandZDiff::Range), + MakeCmdAttr("zdiffstore", -3, "read-only", CommandZDiffStore::Range), ) } // namespace redis diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc index f29765443c4..d231c5a3a62 100644 --- a/src/types/redis_zset.cc +++ b/src/types/redis_zset.cc @@ -851,4 +851,46 @@ rocksdb::Status ZSet::MGet(const Slice &user_key, const std::vector &memb return rocksdb::Status::OK(); } +rocksdb::Status ZSet::Diff(const std::vector &keys, MemberScores *members) { + std::vector lock_keys; + lock_keys.reserve(keys.size()); + for (const auto key : keys) { + std::string ns_key = AppendNamespacePrefix(key); + lock_keys.emplace_back(std::move(ns_key)); + } + MultiLockGuard guard(storage_->GetLockManager(), lock_keys); + + members->clear(); + MemberScores source_member_scores; + RangeScoreSpec spec; + uint64_t size = 0; + auto s = RangeByScore(keys[0], spec, &source_member_scores, &size); + if (!s.ok()) return s; + + std::map exclude_members; + MemberScores target_member_scores; + for (size_t i = 1; i < keys.size(); i++) { + uint64_t size = 0; + s = RangeByScore(keys[i], spec, &target_member_scores, &size); + if (!s.ok()) return s; + for (const auto &member_score : target_member_scores) { + exclude_members[member_score.member] = true; + } + } + for (const auto &member_score : source_member_scores) { + if (exclude_members.find(member_score.member) == exclude_members.end()) { + members->push_back(member_score); + } + } + return rocksdb::Status::OK(); +} + +rocksdb::Status ZSet::DiffStore(const Slice &dst, const std::vector &keys, uint64_t *stored_count) { + MemberScores mscores; + auto s = Diff(keys, &mscores); + if (!s.ok()) return s; + *stored_count = mscores.size(); + return Overwrite(dst, mscores); +} + } // namespace redis diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h index 3cd81622ece..4105bbbc13a 100644 --- a/src/types/redis_zset.h +++ b/src/types/redis_zset.h @@ -116,6 +116,8 @@ class ZSet : public SubKeyScanner { AggregateMethod aggregate_method, uint64_t *saved_cnt); rocksdb::Status Union(const std::vector &keys_weights, AggregateMethod aggregate_method, std::vector *members); + rocksdb::Status Diff(const std::vector &keys, MemberScores *members); + rocksdb::Status DiffStore(const Slice &dst, const std::vector &keys, uint64_t *stored_count); rocksdb::Status MGet(const Slice &user_key, const std::vector &members, std::map *scores); rocksdb::Status GetMetadata(const Slice &ns_key, ZSetMetadata *metadata); diff --git a/tests/cppunit/types/zset_test.cc b/tests/cppunit/types/zset_test.cc index 230aa4009ef..177b12ab910 100644 --- a/tests/cppunit/types/zset_test.cc +++ b/tests/cppunit/types/zset_test.cc @@ -433,3 +433,39 @@ TEST_F(RedisZSetTest, Rank) { } auto s = zset_->Del(key_); } + +TEST_F(RedisZSetTest, Diff) { + uint64_t ret = 0; + + std::string k1 = "key1"; + std::vector k1_mscores = {{"a", -100.1}, {"b", -100.1}, {"c", 0}, {"d", 1.234}}; + + std::string k2 = "key2"; + std::vector k2_mscores = {{"c", -150.1}}; + + std::string k3 = "key3"; + std::vector k3_mscores = {{"a", -1000.1}, {"c", -100.1}, {"e", 8000.9}}; + + auto s = zset_->Add(k1, ZAddFlags::Default(), &k1_mscores, &ret); + EXPECT_EQ(ret, 4); + zset_->Add(k2, ZAddFlags::Default(), &k2_mscores, &ret); + EXPECT_EQ(ret, 1); + zset_->Add(k3, ZAddFlags::Default(), &k3_mscores, &ret); + EXPECT_EQ(ret, 3); + + std::vector mscores; + zset_->Diff({k1, k2, k3}, &mscores); + + EXPECT_EQ(2, mscores.size()); + std::vector expected_mscores = {{"b", -100.1}, {"d", 1.234}}; + int index = 0; + for (const auto& mscore : expected_mscores) { + EXPECT_EQ(mscore.member, mscores[index].member); + EXPECT_EQ(mscore.score, mscores[index].score); + index++; + } + + s = zset_->Del(k1); + s = zset_->Del(k2); + s = zset_->Del(k3); +} diff --git a/tests/gocase/unit/type/zset/zset_test.go b/tests/gocase/unit/type/zset/zset_test.go index 860316b2035..b9970d9f25e 100644 --- a/tests/gocase/unit/type/zset/zset_test.go +++ b/tests/gocase/unit/type/zset/zset_test.go @@ -1331,6 +1331,93 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s ).Err(), ".*weight.*not.*double.*") }) } + + t.Run(fmt.Sprintf("ZDIFF with two sets - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + cmd := rdb.ZDiff(ctx, "zseta", "zsetb") + require.NoError(t, cmd.Err()) + sort.Strings(cmd.Val()) + require.EqualValues(t, []string{"a", "d", "e"}, cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with three sets - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 5, Member: "e"}, + }) + cmd := rdb.ZDiff(ctx, "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + sort.Strings(cmd.Val()) + require.EqualValues(t, []string{"a"}, cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with three sets with scores - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 4, Member: "c"}, + {Score: 5, Member: "e"}, + }) + cmd := rdb.ZDiffWithScores(ctx, "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), cmd.Val()) + }) + + t.Run("ZDIFFSTORE with three sets - ", func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 4, Member: "c"}, + {Score: 5, Member: "e"}, + }) + require.NoError(t, rdb.ZDiffStore(ctx, "setres", "zseta", "zsetb", "zsetc").Err()) + cmd := rdb.ZDiffWithScores(ctx, "zseta", "zsetb", "zsetc") + require.EqualValues(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), cmd.Val()) + }) } func stressTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding string) {