Skip to content

Commit

Permalink
Add the support of the LCS command (apache#2116)
Browse files Browse the repository at this point in the history
Co-authored-by: 纪华裕 <jihuayu123@gmail.com>
Co-authored-by: hulk <hulk.website@gmail.com>
  • Loading branch information
3 people authored Mar 3, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent be072b4 commit 8a4457a
Showing 7 changed files with 428 additions and 7 deletions.
86 changes: 84 additions & 2 deletions src/commands/cmd_string.cc
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
#include "commands/command_parser.h"
#include "error_constants.h"
#include "server/redis_reply.h"
#include "server/redis_request.h"
#include "server/server.h"
#include "storage/redis_db.h"
#include "time_util.h"
@@ -620,6 +621,88 @@ class CommandCAD : public Commander {
}
};

class CommandLCS : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
CommandParser parser(args, 3);
bool get_idx = false;
bool get_len = false;
while (parser.Good()) {
if (parser.EatEqICase("IDX")) {
get_idx = true;
} else if (parser.EatEqICase("LEN")) {
get_len = true;
} else if (parser.EatEqICase("WITHMATCHLEN")) {
with_match_len_ = true;
} else if (parser.EatEqICase("MINMATCHLEN")) {
min_match_len_ = GET_OR_RET(parser.TakeInt<int64_t>());
if (min_match_len_ < 0) {
min_match_len_ = 0;
}
} else {
return parser.InvalidSyntax();
}
}

// Complain if the user passed ambiguous parameters.
if (get_idx && get_len) {
return {Status::RedisParseErr,
"If you want both the length and indexes, "
"please just use IDX."};
}

if (get_len) {
type_ = StringLCSType::LEN;
} else if (get_idx) {
type_ = StringLCSType::IDX;
}

return Status::OK();
}

Status Execute(Server *srv, Connection *conn, std::string *output) override {
redis::String string_db(srv->storage, conn->GetNamespace());

StringLCSResult rst;
auto s = string_db.LCS(args_[1], args_[2], {type_, min_match_len_}, &rst);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}

// Build output by the rst type.
if (auto lcs = std::get_if<std::string>(&rst)) {
*output = redis::BulkString(*lcs);
} else if (auto len = std::get_if<uint32_t>(&rst)) {
*output = redis::Integer(*len);
} else if (auto result = std::get_if<StringLCSIdxResult>(&rst)) {
*output = conn->HeaderOfMap(2);
*output += redis::BulkString("matches");
*output += redis::MultiLen(result->matches.size());
for (const auto &match : result->matches) {
*output += redis::MultiLen(with_match_len_ ? 3 : 2);
*output += redis::MultiLen(2);
*output += redis::Integer(match.a.start);
*output += redis::Integer(match.a.end);
*output += redis::MultiLen(2);
*output += redis::Integer(match.b.start);
*output += redis::Integer(match.b.end);
if (with_match_len_) {
*output += redis::Integer(match.match_len);
}
}
*output += redis::BulkString("len");
*output += redis::Integer(result->len);
}

return Status::OK();
}

private:
StringLCSType type_ = StringLCSType::NONE;
bool with_match_len_ = false;
int64_t min_match_len_ = 0;
};

REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandGet>("get", 2, "read-only", 1, 1, 1), MakeCmdAttr<CommandGetEx>("getex", -2, "write", 1, 1, 1),
MakeCmdAttr<CommandStrlen>("strlen", 2, "read-only", 1, 1, 1),
@@ -637,6 +720,5 @@ REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandIncrByFloat>("incrbyfloat", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandIncr>("incr", 2, "write", 1, 1, 1), MakeCmdAttr<CommandDecrBy>("decrby", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandDecr>("decr", 2, "write", 1, 1, 1), MakeCmdAttr<CommandCAS>("cas", -4, "write", 1, 1, 1),
MakeCmdAttr<CommandCAD>("cad", 3, "write", 1, 1, 1), )

MakeCmdAttr<CommandCAD>("cad", 3, "write", 1, 1, 1), MakeCmdAttr<CommandLCS>("lcs", -3, "read-only", 1, 2, 1), )
} // namespace redis
4 changes: 0 additions & 4 deletions src/server/redis_request.cc
Original file line number Diff line number Diff line change
@@ -36,10 +36,6 @@

namespace redis {

const size_t PROTO_INLINE_MAX_SIZE = 16 * 1024L;
const size_t PROTO_BULK_MAX_SIZE = 512 * 1024L * 1024L;
const size_t PROTO_MULTI_MAX_SIZE = 1024 * 1024L;

Status Request::Tokenize(evbuffer *input) {
size_t pipeline_size = 0;

4 changes: 4 additions & 0 deletions src/server/redis_request.h
Original file line number Diff line number Diff line change
@@ -32,6 +32,10 @@ class Server;

namespace redis {

constexpr size_t PROTO_INLINE_MAX_SIZE = 16 * 1024L;
constexpr size_t PROTO_BULK_MAX_SIZE = 512 * 1024L * 1024L;
constexpr size_t PROTO_MULTI_MAX_SIZE = 1024 * 1024L;

using CommandTokens = std::vector<std::string>;

class Connection;
153 changes: 153 additions & 0 deletions src/types/redis_string.cc
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@
#include <string>

#include "parse_util.h"
#include "server/redis_request.h"
#include "storage/redis_metadata.h"
#include "time_util.h"

@@ -530,4 +531,156 @@ rocksdb::Status String::CAD(const std::string &user_key, const std::string &valu
return rocksdb::Status::OK();
}

rocksdb::Status String::LCS(const std::string &user_key1, const std::string &user_key2, StringLCSArgs args,
StringLCSResult *rst) {
if (args.type == StringLCSType::LEN) {
*rst = static_cast<uint32_t>(0);
} else if (args.type == StringLCSType::IDX) {
*rst = StringLCSIdxResult{{}, 0};
} else {
*rst = std::string{};
}

std::string a;
std::string b;
std::string ns_key1 = AppendNamespacePrefix(user_key1);
std::string ns_key2 = AppendNamespacePrefix(user_key2);
auto s1 = getValue(ns_key1, &a);
auto s2 = getValue(ns_key2, &b);

if (!s1.ok() && !s1.IsNotFound()) {
return s1;
}
if (!s2.ok() && !s2.IsNotFound()) {
return s2;
}
if (s1.IsNotFound()) a = "";
if (s2.IsNotFound()) b = "";

// Detect string truncation or later overflows.
if (a.length() >= UINT32_MAX - 1 || b.length() >= UINT32_MAX - 1) {
return rocksdb::Status::InvalidArgument("String too long for LCS");
}

// Compute the LCS using the vanilla dynamic programming technique of
// building a table of LCS(x, y) substrings.
auto alen = static_cast<uint32_t>(a.length());
auto blen = static_cast<uint32_t>(b.length());

// Allocate the LCS table.
uint64_t dp_size = (alen + 1) * (blen + 1);
uint64_t bulk_size = dp_size * sizeof(uint32_t);
if (bulk_size > PROTO_BULK_MAX_SIZE || bulk_size / dp_size != sizeof(uint32_t)) {
return rocksdb::Status::Aborted("Insufficient memory, transient memory for LCS exceeds proto-max-bulk-len");
}
std::vector<uint32_t> dp(dp_size, 0);
auto lcs = [&dp, blen](const uint32_t i, const uint32_t j) -> uint32_t & { return dp[i * (blen + 1) + j]; };

// Start building the LCS table.
for (uint32_t i = 1; i <= alen; i++) {
for (uint32_t j = 1; j <= blen; j++) {
if (a[i - 1] == b[j - 1]) {
// The len LCS (and the LCS itself) of two
// sequences with the same final character, is the
// LCS of the two sequences without the last char
// plus that last char.
lcs(i, j) = lcs(i - 1, j - 1) + 1;
} else {
// If the last character is different, take the longest
// between the LCS of the first string and the second
// minus the last char, and the reverse.
lcs(i, j) = std::max(lcs(i - 1, j), lcs(i, j - 1));
}
}
}

uint32_t idx = lcs(alen, blen);

// Only compute the length of LCS.
if (auto result = std::get_if<uint32_t>(rst)) {
*result = idx;
return rocksdb::Status::OK();
}

// Store the length of the LCS first if needed.
if (auto result = std::get_if<StringLCSIdxResult>(rst)) {
result->len = idx;
}

// Allocate when we need to compute the actual LCS string.
if (auto result = std::get_if<std::string>(rst)) {
result->resize(idx);
}

uint32_t i = alen;
uint32_t j = blen;
uint32_t arange_start = alen; // alen signals that values are not set.
uint32_t arange_end = 0;
uint32_t brange_start = 0;
uint32_t brange_end = 0;
while (i > 0 && j > 0) {
bool emit_range = false;
if (a[i - 1] == b[j - 1]) {
// If there is a match, store the character if needed.
// And reduce the indexes to look for a new match.
if (auto result = std::get_if<std::string>(rst)) {
result->at(idx - 1) = a[i - 1];
}

// Track the current range.
if (arange_start == alen) {
arange_start = i - 1;
arange_end = i - 1;
brange_start = j - 1;
brange_end = j - 1;
}
// Let's see if we can extend the range backward since
// it is contiguous.
else if (arange_start == i && brange_start == j) {
arange_start--;
brange_start--;
} else {
emit_range = true;
}

// Emit the range if we matched with the first byte of
// one of the two strings. We'll exit the loop ASAP.
if (arange_start == 0 || brange_start == 0) {
emit_range = true;
}
idx--;
i--;
j--;
} else {
// Otherwise reduce i and j depending on the largest
// LCS between, to understand what direction we need to go.
uint32_t lcs1 = lcs(i - 1, j);
uint32_t lcs2 = lcs(i, j - 1);
if (lcs1 > lcs2)
i--;
else
j--;
if (arange_start != alen) emit_range = true;
}

// Emit the current range if needed.
if (emit_range) {
if (auto result = std::get_if<StringLCSIdxResult>(rst)) {
uint32_t match_len = arange_end - arange_start + 1;

// Always emit the range when the `min_match_len` is not set.
if (args.min_match_len == 0 || match_len >= args.min_match_len) {
result->matches.emplace_back(StringLCSRange{arange_start, arange_end},
StringLCSRange{brange_start, brange_end}, match_len);
}
}

// Restart at the next match.
arange_start = alen;
}
}

return rocksdb::Status::OK();
}

} // namespace redis
33 changes: 32 additions & 1 deletion src/types/redis_string.h
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

#include "storage/redis_db.h"
@@ -42,8 +43,36 @@ struct StringSetArgs {
bool keep_ttl;
};

namespace redis {
enum class StringLCSType { NONE, LEN, IDX };

struct StringLCSArgs {
StringLCSType type;
int64_t min_match_len;
};

struct StringLCSRange {
uint32_t start;
uint32_t end;
};

struct StringLCSMatchedRange {
StringLCSRange a;
StringLCSRange b;
uint32_t match_len;

StringLCSMatchedRange(StringLCSRange ra, StringLCSRange rb, uint32_t len) : a(ra), b(rb), match_len(len) {}
};

struct StringLCSIdxResult {
// Matched ranges.
std::vector<StringLCSMatchedRange> matches;
// LCS length.
uint32_t len;
};

using StringLCSResult = std::variant<std::string, uint32_t, StringLCSIdxResult>;

namespace redis {
class String : public Database {
public:
explicit String(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {}
@@ -68,6 +97,8 @@ class String : public Database {
rocksdb::Status CAS(const std::string &user_key, const std::string &old_value, const std::string &new_value,
uint64_t ttl, int *flag);
rocksdb::Status CAD(const std::string &user_key, const std::string &value, int *flag);
rocksdb::Status LCS(const std::string &user_key1, const std::string &user_key2, StringLCSArgs args,
StringLCSResult *rst);

private:
rocksdb::Status getValue(const std::string &ns_key, std::string *value);
Loading

0 comments on commit 8a4457a

Please sign in to comment.