Skip to content

Commit

Permalink
Add the support of the LCS command
Browse files Browse the repository at this point in the history
  • Loading branch information
JoverZhang committed Feb 24, 2024
1 parent 7571034 commit ae22393
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 2 deletions.
192 changes: 190 additions & 2 deletions src/commands/cmd_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,195 @@ class CommandCAD : public Commander {
}
};

// LCS key1 key2 [LEN] [IDX] [MINMATCHLEN len] [WITHMATCHLEN]
//
// Computes the longest common subsequence of the string values stored at
// args_[1] and args_[2]. A missing key is treated as an empty string.
//
// Reply shape, depending on the options:
//   - no option: the LCS itself as a bulk string;
//   - LEN:       the length of the LCS as an integer;
//   - IDX:       a 4-element multi-bulk: "matches", the list of matched
//                ranges (walking from the end of the strings to the start),
//                "len", and the LCS length. MINMATCHLEN filters out ranges
//                shorter than the given length and WITHMATCHLEN appends the
//                length of every range.
class CommandLCS : public Commander {
 public:
  // Parses the optional flags that follow the two keys. LEN and IDX are
  // mutually exclusive; a negative MINMATCHLEN is clamped to 0.
  Status Parse(const std::vector<std::string> &args) override {
    CommandParser parser(args, 3);
    while (parser.Good()) {
      if (parser.EatEqICase("IDX")) {
        get_idx_ = true;
      } else if (parser.EatEqICase("LEN")) {
        get_len_ = true;
      } else if (parser.EatEqICase("WITHMATCHLEN")) {
        with_match_len_ = true;
      } else if (parser.EatEqICase("MINMATCHLEN")) {
        min_match_len_ = GET_OR_RET(parser.TakeInt<int64_t>());
        if (min_match_len_ < 0) {
          min_match_len_ = 0;
        }
      } else {
        return parser.InvalidSyntax();
      }
    }

    // Complain if the user passed ambiguous parameters.
    if (get_idx_ && get_len_) {
      return {Status::RedisParseErr,
              "If you want both the length and indexes, "
              "please just use IDX."};
    }

    return Status::OK();
  }

  // Loads both strings, fills the LCS dynamic-programming table, optionally
  // backtracks through it to rebuild the LCS / matched ranges, and writes the
  // reply selected by the parsed options into *output.
  Status Execute(Server *srv, Connection *conn, std::string *output) override {
    redis::String string_db(srv->storage, conn->GetNamespace());

    std::string a;
    std::string b;
    auto s1 = string_db.Get(args_[1], &a);
    auto s2 = string_db.Get(args_[2], &b);

    if (!s1.ok() && !s1.IsNotFound()) {
      return {Status::RedisExecErr, s1.ToString()};
    }
    if (!s2.ok() && !s2.IsNotFound()) {
      return {Status::RedisExecErr, s2.ToString()};
    }
    if (s1.IsNotFound()) a = "";
    if (s2.IsNotFound()) b = "";

    // Detect string truncation or later overflows.
    if (a.length() >= UINT32_MAX - 1 || b.length() >= UINT32_MAX - 1) {
      return {Status::RedisExecErr, "String too long for LCS"};
    }

    // Compute the LCS using the vanilla dynamic programming technique of
    // building a table of LCS(x, y) substrings.
    auto alen = static_cast<uint32_t>(a.length());
    auto blen = static_cast<uint32_t>(b.length());

    // Allocate the LCS table.
    // The product must be computed in 64 bits: doing it in uint32_t wraps
    // around as soon as (alen + 1) * (blen + 1) exceeds UINT32_MAX, which
    // would allocate a table that is far too small.
    const uint64_t dp_size = (static_cast<uint64_t>(alen) + 1) * (static_cast<uint64_t>(blen) + 1);
    std::vector<uint32_t> dp;
    // Refuse sizes the container cannot represent instead of silently
    // truncating the element count.
    if (dp_size > dp.max_size()) {
      return {Status::RedisExecErr, "String too long for LCS"};
    }
    // TODO: Maybe we need to check for insufficient memory
    dp.assign(static_cast<size_t>(dp_size), 0);

    // Row-major accessor for cell (i, j). The index is computed in size_t for
    // the same reason as dp_size above.
    const size_t row_len = static_cast<size_t>(blen) + 1;
    auto lcs = [&](const uint32_t i, const uint32_t j) -> uint32_t & {
      return dp[static_cast<size_t>(i) * row_len + j];
    };

    // Start building the LCS table.
    for (uint32_t i = 1; i <= alen; i++) {
      for (uint32_t j = 1; j <= blen; j++) {
        if (a[i - 1] == b[j - 1]) {
          // The len LCS (and the LCS itself) of two
          // sequences with the same final character, is the
          // LCS of the two sequences without the last char
          // plus that last char.
          lcs(i, j) = lcs(i - 1, j - 1) + 1;
        } else {
          // If the last character is different, take the longest
          // between the LCS of the first string and the second
          // minus the last char, and the reverse.
          lcs(i, j) = std::max(lcs(i - 1, j), lcs(i, j - 1));
        }
      }
    }

    // Store the actual LCS string if needed.
    std::string result;
    uint32_t idx = lcs(alen, blen);

    // Do we need to compute the actual LCS string? Allocate it in that case.
    bool compute_lcs = get_idx_ || !get_len_;
    if (compute_lcs) result.resize(idx);

    // Build an array if we have to emit the matched ranges.
    std::string matches;
    uint32_t matches_len = 0;

    uint32_t i = alen;
    uint32_t j = blen;
    uint32_t arange_start = alen;  // alen signals that values are not set.
    uint32_t arange_end = 0;
    uint32_t brange_start = 0;
    uint32_t brange_end = 0;
    while (compute_lcs && i > 0 && j > 0) {
      bool emit_range = false;
      if (a[i - 1] == b[j - 1]) {
        // If there is a match, store the character and reduce
        // the indexes to look for a new match.
        result[idx - 1] = a[i - 1];

        // Track the current range.
        if (arange_start == alen) {
          arange_start = i - 1;
          arange_end = i - 1;
          brange_start = j - 1;
          brange_end = j - 1;
        }
        // Let's see if we can extend the range backward since
        // it is contiguous.
        else if (arange_start == i && brange_start == j) {
          arange_start--;
          brange_start--;
        } else {
          emit_range = true;
        }

        // Emit the range if we matched with the first byte of
        // one of the two strings. We'll exit the loop ASAP.
        if (arange_start == 0 || brange_start == 0) {
          emit_range = true;
        }
        idx--;
        i--;
        j--;
      } else {
        // Otherwise reduce i and j depending on the largest
        // LCS between, to understand what direction we need to go.
        uint32_t lcs1 = lcs(i - 1, j);
        uint32_t lcs2 = lcs(i, j - 1);
        if (lcs1 > lcs2)
          i--;
        else
          j--;
        if (arange_start != alen) emit_range = true;
      }

      // Emit the current range if needed.
      uint32_t match_len = arange_end - arange_start + 1;
      if (emit_range) {
        if (get_idx_ && (min_match_len_ == 0 || match_len >= min_match_len_)) {
          matches += redis::MultiLen(with_match_len_ ? 3 : 2);
          matches += redis::MultiLen(2);
          matches += redis::BulkString(std::to_string(arange_start));
          matches += redis::BulkString(std::to_string(arange_end));
          matches += redis::MultiLen(2);
          matches += redis::BulkString(std::to_string(brange_start));
          matches += redis::BulkString(std::to_string(brange_end));
          if (with_match_len_) {
            matches += redis::BulkString(std::to_string(match_len));
          }
          matches_len++;
        }

        // Restart at the next match.
        arange_start = alen;
      }
    }

    // Build output by the given options.
    if (get_idx_) {
      *output = redis::MultiLen(4);
      *output += redis::BulkString("matches");
      *output += redis::MultiLen(matches_len);
      *output += matches;
      *output += redis::BulkString("len");
      *output += redis::Integer(lcs(alen, blen));
    } else if (get_len_) {
      *output = redis::Integer(lcs(alen, blen));
    } else {
      *output = redis::BulkString(result);
    }

    return Status::OK();
  }

 private:
  bool get_idx_ = false;         // IDX: reply with the matched ranges.
  bool get_len_ = false;         // LEN: reply with the LCS length only.
  bool with_match_len_ = false;  // WITHMATCHLEN: append each range's length.
  int64_t min_match_len_ = 0;    // MINMATCHLEN: 0 disables the filter.
};

REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandGet>("get", 2, "read-only", 1, 1, 1), MakeCmdAttr<CommandGetEx>("getex", -2, "write", 1, 1, 1),
MakeCmdAttr<CommandStrlen>("strlen", 2, "read-only", 1, 1, 1),
Expand All @@ -637,6 +826,5 @@ REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandIncrByFloat>("incrbyfloat", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandIncr>("incr", 2, "write", 1, 1, 1), MakeCmdAttr<CommandDecrBy>("decrby", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandDecr>("decr", 2, "write", 1, 1, 1), MakeCmdAttr<CommandCAS>("cas", -4, "write", 1, 1, 1),
MakeCmdAttr<CommandCAD>("cad", 3, "write", 1, 1, 1), )

MakeCmdAttr<CommandCAD>("cad", 3, "write", 1, 1, 1), MakeCmdAttr<CommandLCS>("lcs", -3, "read-only", 1, 1, 1), )
} // namespace redis
91 changes: 91 additions & 0 deletions tests/gocase/unit/type/strings/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -895,4 +895,95 @@ func TestString(t *testing.T) {
require.ErrorContains(t, rdb.Do(ctx, "CAD", "cad_key").Err(), "ERR wrong number of arguments")
require.ErrorContains(t, rdb.Do(ctx, "CAD", "cad_key", "123", "234").Err(), "ERR wrong number of arguments")
})

rna1 := "CACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTCACTCGGCTGCATGCTTAGTGCACTCACGCAGTATAATTAATAACTAATTACTGTCGTTGACAGGACACGAGTAACTCGTCTATCTTCTGCAGGCTGCTTACGGTTTCGTCCGTGTTGCAGCCGATCATCAGCACATCTAGGTTTCGTCCGGGTGTG"
rna2 := "ATTAAAGGTTTATACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTCACTCGGCTGCATGCTTAGTGCACTCACGCAGTATAATTAATAACTAATTACTGTCGTTGACAGGACACGAGTAACTCGTCTATCTTCTGCAGGCTGCTTACGGTTTCGTCCGTGTTGCAGCCGATCATCAGCACATCTAGGTTT"
rnalcs := "ACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTCACTCGGCTGCATGCTTAGTGCACTCACGCAGTATAATTAATAACTAATTACTGTCGTTGACAGGACACGAGTAACTCGTCTATCTTCTGCAGGCTGCTTACGGTTTCGTCCGTGTTGCAGCCGATCATCAGCACATCTAGGTTT"

t.Run("LCS basic", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", rna2, 0).Err())
require.Equal(t, rnalcs, rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2"}).Val().MatchString)
// require.Equal(t, rnalcs, rdb.Do(ctx, "LCS", "virus1", "virus2").Val())
})

t.Run("LCS len", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", rna2, 0).Err())
require.Equal(t, int64(len(rnalcs)), rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2", Len: true}).Val().Len)
})

t.Run("LCS indexes", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", rna2, 0).Err())
matches := rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2", Idx: true}).Val().Matches
require.Equal(t, []redis.LCSMatchedPosition{
{
Key1: redis.LCSPosition{Start: 238, End: 238},
Key2: redis.LCSPosition{Start: 239, End: 239},
},
{
Key1: redis.LCSPosition{Start: 236, End: 236},
Key2: redis.LCSPosition{Start: 238, End: 238},
},
{
Key1: redis.LCSPosition{Start: 229, End: 230},
Key2: redis.LCSPosition{Start: 236, End: 237},
},
{
Key1: redis.LCSPosition{Start: 224, End: 224},
Key2: redis.LCSPosition{Start: 235, End: 235},
},
{
Key1: redis.LCSPosition{Start: 1, End: 222},
Key2: redis.LCSPosition{Start: 13, End: 234},
},
}, matches)
})

t.Run("LCS indexes with match len", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", rna2, 0).Err())
matches := rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2", Idx: true, WithMatchLen: true}).Val().Matches
require.Equal(t, []redis.LCSMatchedPosition{
{
Key1: redis.LCSPosition{Start: 238, End: 238},
Key2: redis.LCSPosition{Start: 239, End: 239},
MatchLen: 1,
},
{
Key1: redis.LCSPosition{Start: 236, End: 236},
Key2: redis.LCSPosition{Start: 238, End: 238},
MatchLen: 1,
},
{
Key1: redis.LCSPosition{Start: 229, End: 230},
Key2: redis.LCSPosition{Start: 236, End: 237},
MatchLen: 2,
},
{
Key1: redis.LCSPosition{Start: 224, End: 224},
Key2: redis.LCSPosition{Start: 235, End: 235},
MatchLen: 1,
},
{
Key1: redis.LCSPosition{Start: 1, End: 222},
Key2: redis.LCSPosition{Start: 13, End: 234},
MatchLen: 222,
},
}, matches)
})

t.Run("LCS indexes with match len and minimum match len", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", rna2, 0).Err())
matches := rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2", Idx: true, WithMatchLen: true, MinMatchLen: 5}).Val().Matches
require.Equal(t, []redis.LCSMatchedPosition{
{
Key1: redis.LCSPosition{Start: 1, End: 222},
Key2: redis.LCSPosition{Start: 13, End: 234},
MatchLen: 222,
},
}, matches)
})
}

0 comments on commit ae22393

Please sign in to comment.