Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for the BITFIELD command #1901

Merged
merged 8 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions src/cluster/slot_migrate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,8 @@ Status SlotMigrator::checkSingleResponse(int sock_fd) { return checkMultipleResp

// Commands | Response | Instance
// ++++++++++++++++++++++++++++++++++++++++
// set Redis::Integer :1/r/n
// hset Redis::SimpleString +OK/r/n
// set Redis::Integer :1\r\n
// hset Redis::SimpleString +OK\r\n
// sadd Redis::Integer
// zadd Redis::Integer
// siadd Redis::Integer
Expand All @@ -497,6 +497,7 @@ Status SlotMigrator::checkSingleResponse(int sock_fd) { return checkMultipleResp
// sirem Redis::Integer
// del Redis::Integer
// xadd Redis::BulkString
// bitfield Redis::Array *1\r\n:0
Status SlotMigrator::checkMultipleResponses(int sock_fd, int total) {
if (sock_fd < 0 || total <= 0) {
return {Status::NotOK, fmt::format("invalid arguments: sock_fd={}, count={}", sock_fd, total)};
Expand All @@ -509,7 +510,7 @@ Status SlotMigrator::checkMultipleResponses(int sock_fd, int total) {
setsockopt(sock_fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));

// Start checking response
size_t bulk_len = 0;
size_t bulk_or_array_len = 0;
int cnt = 0;
parser_state_ = ParserState::ArrayLen;
UniqueEvbuf evbuf;
Expand All @@ -534,14 +535,20 @@ Status SlotMigrator::checkMultipleResponses(int sock_fd, int total) {

if (line[0] == '-') {
return {Status::NotOK, fmt::format("got invalid response of length {}: {}", line.length, line.get())};
} else if (line[0] == '$') {
} else if (line[0] == '$' || line[0] == '*') {
auto parse_result = ParseInt<uint64_t>(std::string(line.get() + 1, line.length - 1), 10);
if (!parse_result) {
return {Status::NotOK, "protocol error: expected integer value"};
}

bulk_len = *parse_result;
parser_state_ = bulk_len > 0 ? ParserState::BulkData : ParserState::OneRspEnd;
bulk_or_array_len = *parse_result;
if (bulk_or_array_len <= 0) {
parser_state_ = ParserState::OneRspEnd;
} else if (line[0] == '$') {
parser_state_ = ParserState::BulkData;
} else {
parser_state_ = ParserState::ArrayData;
}
} else if (line[0] == '+' || line[0] == ':') {
parser_state_ = ParserState::OneRspEnd;
} else {
Expand All @@ -552,17 +559,33 @@ Status SlotMigrator::checkMultipleResponses(int sock_fd, int total) {
}
// Handle bulk string response
case ParserState::BulkData: {
if (evbuffer_get_length(evbuf.get()) < bulk_len + 2) {
if (evbuffer_get_length(evbuf.get()) < bulk_or_array_len + 2) {
LOG(INFO) << "[migrate] Bulk data in event buffer is not complete, read socket again";
run = false;
break;
}
// TODO(chrisZMF): Check tail '\r\n'
evbuffer_drain(evbuf.get(), bulk_len + 2);
bulk_len = 0;
evbuffer_drain(evbuf.get(), bulk_or_array_len + 2);
bulk_or_array_len = 0;
parser_state_ = ParserState::OneRspEnd;
break;
}
case ParserState::ArrayData: {
while (run && bulk_or_array_len > 0) {
evbuffer_ptr ptr = evbuffer_search_eol(evbuf.get(), nullptr, nullptr, EVBUFFER_EOL_CRLF_STRICT);
if (ptr.pos < 0) {
LOG(INFO) << "[migrate] Array data in event buffer is not complete, read socket again";
run = false;
break;
}
evbuffer_drain(evbuf.get(), ptr.pos + 2);
--bulk_or_array_len;
}
if (run) {
parser_state_ = ParserState::OneRspEnd;
}
break;
}
case ParserState::OneRspEnd: {
cnt++;
if (cnt >= total) {
Expand Down
2 changes: 1 addition & 1 deletion src/cluster/slot_migrate.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class SlotMigrator : public redis::Database {

void resumeSyncCtx(const Status &migrate_result);

enum class ParserState { ArrayLen, BulkLen, BulkData, OneRspEnd };
enum class ParserState { ArrayLen, BulkLen, BulkData, ArrayData, OneRspEnd };
enum class ThreadState { Uninitialized, Running, Terminated };

static const int kDefaultMaxPipelineSize = 16;
Expand Down
153 changes: 152 additions & 1 deletion src/commands/cmd_bit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/

#include "commander.h"
#include "commands/command_parser.h"
#include "error_constants.h"
#include "server/server.h"
#include "types/redis_bitmap.h"
Expand Down Expand Up @@ -132,6 +133,8 @@ class CommandBitCount : public Commander {

class CommandBitPos : public Commander {
public:
using Commander::Parse;

Status Parse(const std::vector<std::string> &args) override {
if (args.size() >= 4) {
auto parse_start = ParseInt<int64_t>(args[3], 10);
Expand Down Expand Up @@ -225,10 +228,158 @@ class CommandBitOp : public Commander {
BitOpFlags op_flag_;
};

class CommandBitfield : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
BitfieldOperation cmd;

read_only_ = true;
// BITFIELD <key> [commands...]
for (CommandParser group(args, 2); group.Good();) {
auto remains = group.Remains();

std::string opcode = util::ToLower(group[0]);
if (opcode == "get") {
cmd.type = BitfieldOperation::Type::kGet;
} else if (opcode == "set") {
cmd.type = BitfieldOperation::Type::kSet;
read_only_ = false;
} else if (opcode == "incrby") {
cmd.type = BitfieldOperation::Type::kIncrBy;
read_only_ = false;
} else if (opcode == "overflow") {
constexpr auto kOverflowCmdSize = 2;
if (remains < kOverflowCmdSize) {
return {Status::RedisParseErr, errWrongNumOfArguments};
}
auto s = parseOverflowSubCommand(group[1], &cmd);
if (!s.IsOK()) {
return s;
}

group.Skip(kOverflowCmdSize);
continue;
} else {
return {Status::RedisParseErr, errUnknownSubcommandOrWrongArguments};
}

if (remains < 3) {
return {Status::RedisParseErr, errWrongNumOfArguments};
}

// parse encoding
auto encoding = parseBitfieldEncoding(group[1]);
if (!encoding.IsOK()) {
return encoding.ToStatus();
}
cmd.encoding = encoding.GetValue();

// parse offset
if (!GetBitOffsetFromArgument(group[2], &cmd.offset).IsOK()) {
return {Status::RedisParseErr, "bit offset is not an integer or out of range"};
}

if (cmd.type != BitfieldOperation::Type::kGet) {
if (remains < 4) {
return {Status::RedisParseErr, errWrongNumOfArguments};
}

auto value = ParseInt<int64_t>(group[3], 10);
if (!value.IsOK()) {
return value.ToStatus();
}
cmd.value = value.GetValue();

// SET|INCRBY <encoding> <offset> <value>
group.Skip(4);
} else {
// GET <encoding> <offset>
group.Skip(3);
}

cmds_.push_back(cmd);
}

return Commander::Parse(args);
}

Status Execute(Server *srv, Connection *conn, std::string *output) override {
redis::Bitmap bitmap_db(srv->storage, conn->GetNamespace());
std::vector<std::optional<BitfieldValue>> rets;
rocksdb::Status s;
if (read_only_) {
s = bitmap_db.BitfieldReadOnly(args_[1], cmds_, &rets);
} else {
s = bitmap_db.Bitfield(args_[1], cmds_, &rets);
}
std::vector<std::string> str_rets(rets.size());
for (size_t i = 0; i != rets.size(); ++i) {
if (rets[i].has_value()) {
if (rets[i]->Encoding().IsSigned()) {
str_rets[i] = redis::Integer(CastToSignedWithoutBitChanges(rets[i]->Value()));
} else {
str_rets[i] = redis::Integer(rets[i]->Value());
}
} else {
str_rets[i] = redis::NilString();
}
}
*output = redis::Array(str_rets);
return Status::OK();
}

private:
static Status parseOverflowSubCommand(const std::string &overflow, BitfieldOperation *cmd) {
std::string lower = util::ToLower(overflow);
if (lower == "wrap") {
cmd->overflow = BitfieldOverflowBehavior::kWrap;
} else if (lower == "sat") {
cmd->overflow = BitfieldOverflowBehavior::kSat;
} else if (lower == "fail") {
cmd->overflow = BitfieldOverflowBehavior::kFail;
} else {
return {Status::RedisParseErr, errUnknownSubcommandOrWrongArguments};
}
return Status::OK();
}

static StatusOr<BitfieldEncoding> parseBitfieldEncoding(const std::string &token) {
if (token.empty()) {
return {Status::RedisParseErr, errUnknownSubcommandOrWrongArguments};
}

auto sign = std::tolower(token[0]);
if (sign != 'u' && sign != 'i') {
return {Status::RedisParseErr, errUnknownSubcommandOrWrongArguments};
}

auto type = BitfieldEncoding::Type::kUnsigned;
if (sign == 'i') {
type = BitfieldEncoding::Type::kSigned;
}

auto bits_parse = ParseInt<uint8_t>(token.substr(1), 10);
if (!bits_parse.IsOK()) {
return bits_parse.ToStatus();
}
uint8_t bits = bits_parse.GetValue();

auto encoding = BitfieldEncoding::Create(type, bits);
if (!encoding.IsOK()) {
return {Status::RedisParseErr, errUnknownSubcommandOrWrongArguments};
}
return encoding.GetValue();
}

std::vector<BitfieldOperation> cmds_;
bool read_only_;
};

REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandGetBit>("getbit", 3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandSetBit>("setbit", 4, "write", 1, 1, 1),
MakeCmdAttr<CommandBitCount>("bitcount", -2, "read-only", 1, 1, 1),
MakeCmdAttr<CommandBitPos>("bitpos", -3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandBitOp>("bitop", -4, "write", 2, -1, 1), )
MakeCmdAttr<CommandBitOp>("bitop", -4, "write", 2, -1, 1),
MakeCmdAttr<CommandBitfield>("bitfield", -2, "write", 1, 1, 1), )

} // namespace redis
32 changes: 32 additions & 0 deletions src/commands/command_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ struct CommandParser {
public:
using ValueType = typename Iter::value_type;

static constexpr bool IsRandomAccessIter =
std::is_base_of_v<std::random_access_iterator_tag, typename std::iterator_traits<Iter>::iterator_category>;

CommandParser(Iter begin, Iter end) : begin_(std::move(begin)), end_(std::move(end)) {}

template <typename Container>
Expand All @@ -56,12 +59,41 @@ struct CommandParser {

decltype(auto) RawPeek() const { return *begin_; }

decltype(auto) operator[](size_t index) const {
Iter iter = begin_;
std::advance(iter, index);
return *iter;
}

decltype(auto) RawTake() { return *begin_++; }

decltype(auto) RawNext() { ++begin_; }

bool Good() const { return begin_ != end_; }

std::enable_if_t<IsRandomAccessIter, size_t> Remains() const {
// O(1) iff Iter is random access iterator.
auto d = std::distance(begin_, end_);
DCHECK(d >= 0);
return d;
}

size_t Skip(size_t count) {
if constexpr (IsRandomAccessIter) {
size_t steps = std::min(Remains(), count);
begin_ += steps;
return steps;
} else {
size_t steps = 0;
while (count != 0 && Good()) {
++begin_;
++steps;
--count;
}
return steps;
}
}

template <typename Pred>
bool EatPred(Pred&& pred) {
if (Good() && std::forward<Pred>(pred)(RawPeek())) {
Expand Down
Loading