diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index d595875d7bb..ae26c628464 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -1265,6 +1265,64 @@ class CommandDump : public Commander { } }; +class CommandPollUpdates : public Commander { + public: + Status Parse(const std::vector &args) override { + CommandParser parser(args, 1); + sequence_ = GET_OR_RET(parser.TakeInt()); + + while (parser.Good()) { + if (parser.EatEqICase("MAX")) { + max_ = GET_OR_RET(parser.TakeInt(NumericRange{1, 1000})); + } else if (parser.EatEqICase("STRICT")) { + is_strict_ = true; + } else if (parser.EatEqICase("FORMAT")) { + auto format = GET_OR_RET(parser.TakeStr()); + if (util::EqualICase(format, "RAW")) { + format_ = Format::Raw; + } else { + return {Status::RedisParseErr, "invalid FORMAT option, only support RAW"}; + } + } else { + return {Status::RedisParseErr, errInvalidSyntax}; + } + } + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + // sequence + 1 is for excluding the current sequence to avoid getting duplicate updates + auto batches = GET_OR_RET(srv->PollUpdates(sequence_ + 1, max_, is_strict_)); + + *output = redis::MultiLen(8); + *output += redis::BulkString("latest_sequence"); + *output += redis::Integer(srv->storage->LatestSeqNumber()); + *output += redis::BulkString("format"); + *output += redis::BulkString("RAW"); + *output += redis::BulkString("updates"); + *output += redis::MultiLen(batches.size()); + uint64_t next_sequence = sequence_; + for (const auto &batch : batches) { + *output += redis::BulkString(util::StringToHex(batch.writeBatchPtr->Data())); + // It might contain more than one sequence in a batch + next_sequence = batch.sequence + batch.writeBatchPtr->Count() - 1; + } + *output += redis::BulkString("next_sequence"); + *output += redis::Integer(next_sequence); + return Status::OK(); + } + + private: + enum class Format { + Raw, + }; + + uint64_t sequence_ = -1; + bool is_strict_ = false; + int64_t max_ = 16; + Format format_ = Format::Raw; +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("auth", 2, "read-only ok-loading", 0, 0, 0), MakeCmdAttr("ping", -1, "read-only", 0, 0, 0), MakeCmdAttr("select", 2, "read-only", 0, 0, 0), @@ -1302,5 +1360,6 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("auth", 2, "read-only ok-loadin MakeCmdAttr("rdb", -3, "write exclusive", 0, 0, 0), MakeCmdAttr("reset", 1, "ok-loading multi no-script pub-sub", 0, 0, 0), MakeCmdAttr("applybatch", -2, "write no-multi", 0, 0, 0), - MakeCmdAttr("dump", 2, "read-only", 0, 0, 0), ) + MakeCmdAttr("dump", 2, "read-only", 0, 0, 0), + MakeCmdAttr("pollupdates", -2, "read-only", 0, 0, 0), ) } // namespace redis diff --git a/src/server/server.cc b/src/server/server.cc index 5e3b7f1363c..7e4c4f2dafa 100644 --- a/src/server/server.cc +++ b/src/server/server.cc @@ -1556,6 +1556,36 @@ int64_t Server::GetLastScanTime(const std::string &ns) const { return 0; } +StatusOr> Server::PollUpdates(uint64_t next_sequence, int64_t count, + bool is_strict) const { + std::vector batches; + auto latest_sequence = storage->LatestSeqNumber(); + if (next_sequence == latest_sequence + 1) { + // return empty result if there is no new updates + return batches; + } else if (next_sequence > latest_sequence + 1) { + return {Status::NotOK, "next sequence is out of range"}; + } + + std::unique_ptr iter; + if (auto s = storage->GetWALIter(next_sequence, &iter); !s.IsOK()) return s; + if (!iter) { + return Status{Status::NotOK, "unable to get WAL iterator"}; + } + + for (int64_t i = 0; i < count && iter->Valid() && iter->status().ok(); ++i, iter->Next()) { + // The first batch should have the same sequence number as the next sequence number + // if it requires strictly matched. + auto batch = iter->GetBatch(); + if (i == 0 && is_strict && batch.sequence != next_sequence) { + return {Status::NotOK, + fmt::format("mismatched sequence number, expected {} but got {}", next_sequence, batch.sequence)}; + } + batches.emplace_back(std::move(batch)); + } + return batches; +} + void Server::SlowlogPushEntryIfNeeded(const std::vector *args, uint64_t duration, const redis::Connection *conn) { int64_t threshold = config_->slowlog_log_slower_than; diff --git a/src/server/server.h b/src/server/server.h index c1793e81ce6..1bb639ba6b0 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -253,6 +253,7 @@ class Server { Status AsyncScanDBSize(const std::string &ns); void GetLatestKeyNumStats(const std::string &ns, KeyNumStats *stats); int64_t GetLastScanTime(const std::string &ns) const; + StatusOr> PollUpdates(uint64_t next_sequence, int64_t count, bool is_strict) const; std::string GenerateCursorFromKeyName(const std::string &key_name, CursorType cursor_type, const char *prefix = ""); std::string GetKeyNameFromCursor(const std::string &cursor, CursorType cursor_type); diff --git a/tests/gocase/unit/server/poll_updates_test.go b/tests/gocase/unit/server/poll_updates_test.go new file mode 100644 index 00000000000..1bd60587223 --- /dev/null +++ b/tests/gocase/unit/server/poll_updates_test.go @@ -0,0 +1,169 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. + */ + +package server + +import ( + "context" + "encoding/hex" + "fmt" + "strconv" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/stretchr/testify/require" +) + +type PollUpdatesResult struct { + LatestSeq int64 + NextSeq int64 + Updates []string +} + +func sliceToPollUpdatesResult(t *testing.T, slice []interface{}) *PollUpdatesResult { + require.Len(t, slice, 8) + + require.Equal(t, "latest_sequence", slice[0]) + latestSeq, ok := slice[1].(int64) + require.True(t, ok) + + require.Equal(t, "format", slice[2]) + require.Equal(t, "RAW", slice[3]) + require.Equal(t, "updates", slice[4]) + updates := make([]string, 0) + if slice[5] != nil { + fields, ok := slice[5].([]interface{}) + require.True(t, ok) + for _, field := range fields { + str, ok := field.(string) + require.True(t, ok) + updates = append(updates, str) + } + } + + require.Equal(t, "next_sequence", slice[6]) + nextSeq, ok := slice[7].(int64) + require.True(t, ok) + + return &PollUpdatesResult{ + LatestSeq: latestSeq, + NextSeq: nextSeq, + Updates: updates, + } +} + +func TestPollUpdates_Basic(t *testing.T) { + ctx := context.Background() + + srv0 := util.StartServer(t, map[string]string{}) + defer srv0.Close() + rdb0 := srv0.NewClient() + defer func() { require.NoError(t, rdb0.Close()) }() + + srv1 := util.StartServer(t, map[string]string{}) + defer srv1.Close() + rdb1 := srv1.NewClient() + defer func() { require.NoError(t, rdb1.Close()) }() + + t.Run("Make sure the command POLLUPDATES works well", func(t *testing.T) { + for i := 0; i < 10; i++ { + rdb0.Set(ctx, fmt.Sprintf("key-%d", i), i, 0) + } + + updates := make([]string, 0) + slice, err := rdb0.Do(ctx, "POLLUPDATES", 0, "MAX", 6).Slice() + require.NoError(t, err) + pollUpdates := sliceToPollUpdatesResult(t, slice) + require.EqualValues(t, 10, pollUpdates.LatestSeq) + require.EqualValues(t, 6, pollUpdates.NextSeq) + require.Len(t, pollUpdates.Updates, 6) + updates = append(updates, pollUpdates.Updates...) + + slice, err = rdb0.Do(ctx, "POLLUPDATES", pollUpdates.NextSeq, "MAX", 6).Slice() + require.NoError(t, err) + pollUpdates = sliceToPollUpdatesResult(t, slice) + require.EqualValues(t, 10, pollUpdates.LatestSeq) + require.EqualValues(t, 10, pollUpdates.NextSeq) + require.Len(t, pollUpdates.Updates, 4) + updates = append(updates, pollUpdates.Updates...) + + for i := 0; i < 10; i++ { + batch, err := hex.DecodeString(updates[i]) + require.NoError(t, err) + applied, err := rdb1.Do(ctx, "APPLYBATCH", batch).Bool() + require.NoError(t, err) + require.True(t, applied) + require.EqualValues(t, strconv.Itoa(i), rdb1.Get(ctx, fmt.Sprintf("key-%d", i)).Val()) + } + }) + + t.Run("Runs POLLUPDATES with invalid arguments", func(t *testing.T) { + require.ErrorContains(t, rdb0.Do(ctx, "POLLUPDATES", 0, "MAX", -1).Err(), + "ERR out of numeric range") + require.ErrorContains(t, rdb0.Do(ctx, "POLLUPDATES", 0, "MAX", 1001).Err(), + "ERR out of numeric range") + require.ErrorContains(t, rdb0.Do(ctx, "POLLUPDATES", 0, "FORMAT", "COMMAND").Err(), + "ERR invalid FORMAT option, only support RAW") + require.ErrorContains(t, rdb0.Do(ctx, "POLLUPDATES", 12, "FORMAT", "RAW").Err(), + "ERR next sequence is out of range") + require.Error(t, rdb0.Do(ctx, "POLLUPDATES", 1, "FORMAT", "EXTRA").Err()) + }) +} + +func TestPollUpdates_WithStrict(t *testing.T) { + ctx := context.Background() + + srv0 := util.StartServer(t, map[string]string{}) + defer srv0.Close() + rdb0 := srv0.NewClient() + defer func() { require.NoError(t, rdb0.Close()) }() + + srv1 := util.StartServer(t, map[string]string{}) + defer srv1.Close() + rdb1 := srv1.NewClient() + defer func() { require.NoError(t, rdb1.Close()) }() + + // The latest sequence is 2 after running the HSET command, 1 for the metadata and 1 for the field + require.NoError(t, rdb0.HSet(ctx, "h0", "f0", "v0").Err()) + // The latest sequence is 3 after running the SET command + require.NoError(t, rdb0.Set(ctx, "k0", "v0", 0).Err()) + + // PollUpdates with strict mode should return an error if the sequence number is mismatched + err := rdb0.Do(ctx, "POLLUPDATES", 1, "MAX", 1, "STRICT").Err() + require.ErrorContains(t, err, "ERR mismatched sequence number") + + // Works well if the sequence number is mismatched but not in strict mode + require.NoError(t, rdb0.Do(ctx, "POLLUPDATES", 1, "MAX", 1).Err()) + + slice, err := rdb0.Do(ctx, "POLLUPDATES", 0, "MAX", 10, "STRICT").Slice() + require.NoError(t, err) + pollUpdates := sliceToPollUpdatesResult(t, slice) + require.EqualValues(t, 3, pollUpdates.LatestSeq) + require.EqualValues(t, 3, pollUpdates.NextSeq) + require.Len(t, pollUpdates.Updates, 2) + + for _, update := range pollUpdates.Updates { + batch, err := hex.DecodeString(update) + require.NoError(t, err) + require.NoError(t, rdb1.Do(ctx, "APPLYBATCH", batch).Err()) + } + + require.Equal(t, "v0", rdb1.Get(ctx, "k0").Val()) + require.Equal(t, "v0", rdb1.HGet(ctx, "h0", "f0").Val()) +}