Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(keys, scan): Support arbitrary glob patterns #2608

Merged
merged 18 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/config/typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ extend-exclude = [
".git/",
"src/vendor/",
"tests/gocase/util/slot.go",

# Uses short strings for testing glob matching
"tests/cppunit/string_util_test.cc",
"tests/gocase/unit/keyspace/keyspace_test.go",
"tests/gocase/unit/scan/scan_test.go",
]
ignore-hidden = false

Expand Down
16 changes: 8 additions & 8 deletions src/commands/cmd_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
#include "commands/scan_base.h"
#include "common/io_util.h"
#include "common/rdb_stream.h"
#include "common/string_util.h"
#include "common/time_util.h"
#include "config/config.h"
#include "error_constants.h"
#include "server/redis_connection.h"
#include "server/redis_reply.h"
#include "server/server.h"
#include "stats/disk_stats.h"
#include "storage/rdb/rdb.h"
#include "string_util.h"
#include "time_util.h"

namespace redis {

Expand Down Expand Up @@ -114,15 +114,15 @@ class CommandNamespace : public Commander {
class CommandKeys : public Commander {
public:
Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
const std::string &prefix = args_[1];
const std::string &glob_pattern = args_[1];
std::vector<std::string> keys;
redis::Database redis(srv->storage, conn->GetNamespace());

if (prefix.empty() || prefix.find('*') != prefix.size() - 1) {
return {Status::RedisExecErr, "only keys prefix match was supported"};
if (const Status s = util::ValidateGlob(glob_pattern); !s.IsOK()) {
return {Status::RedisParseErr, "Invalid glob pattern: " + s.Msg()};
}

const rocksdb::Status s = redis.Keys(ctx, prefix.substr(0, prefix.size() - 1), &keys);
const auto [prefix, suffix_glob] = util::SplitGlob(glob_pattern);
const rocksdb::Status s = redis.Keys(ctx, prefix, suffix_glob, &keys);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
Expand Down Expand Up @@ -846,7 +846,7 @@ class CommandScan : public CommandScanBase {
std::vector<std::string> keys;
std::string end_key;

auto s = redis_db.Scan(ctx, key_name, limit_, prefix_, &keys, &end_key, type_);
const auto s = redis_db.Scan(ctx, key_name, limit_, prefix_, suffix_glob_, &keys, &end_key, type_);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
Expand Down
14 changes: 7 additions & 7 deletions src/commands/scan_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
#include "commander.h"
#include "commands/command_parser.h"
#include "error_constants.h"
#include "glob.h"
#include "parse_util.h"
#include "server/server.h"
#include "string_util.h"

namespace redis {

Expand All @@ -44,14 +46,11 @@ class CommandScanBase : public Commander {
Status ParseAdditionalFlags(Parser &parser) {
while (parser.Good()) {
if (parser.EatEqICase("match")) {
prefix_ = GET_OR_RET(parser.TakeStr());
// The match pattern should contain exactly one '*' at the end; remove the * to
// get the prefix to match.
if (!prefix_.empty() && prefix_.find('*') == prefix_.size() - 1) {
prefix_.pop_back();
} else {
return {Status::RedisParseErr, "currently only key prefix matching is supported"};
const std::string glob_pattern = GET_OR_RET(parser.TakeStr());
if (const Status s = util::ValidateGlob(glob_pattern); !s.IsOK()) {
return {Status::RedisParseErr, "Invalid glob pattern: " + s.Msg()};
}
std::tie(prefix_, suffix_glob_) = util::SplitGlob(glob_pattern);
} else if (parser.EatEqICase("count")) {
limit_ = GET_OR_RET(parser.TakeInt());
if (limit_ <= 0) {
Expand Down Expand Up @@ -100,6 +99,7 @@ class CommandScanBase : public Commander {
protected:
std::string cursor_;
std::string prefix_;
std::string suffix_glob_ = "*";
int limit_ = 20;
RedisType type_ = kRedisNone;
};
Expand Down
1 change: 0 additions & 1 deletion src/common/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

Expand Down
228 changes: 142 additions & 86 deletions src/common/string_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,118 +101,174 @@ bool HasPrefix(const std::string &str, const std::string &prefix) {
return !strncasecmp(str.data(), prefix.data(), prefix.size());
}

int StringMatch(const std::string &pattern, const std::string &in, int nocase) {
return StringMatchLen(pattern.c_str(), pattern.length(), in.c_str(), in.length(), nocase);
// Checks that [glob] is a well-formed glob pattern.
//
// Accepted syntax, as implemented by the loop below: the wildcards '*' and '?',
// literal characters, backslash escapes, and [...] groups that may contain
// escaped characters and a-b ranges.
//
// Returns Status::OK() for a well-formed pattern. Returns Status::NotOK with a
// descriptive message when the pattern contains:
//   - an unescaped ']' outside of a group,
//   - a backslash as the very last character,
//   - a range ("a-") with no end character,
//   - a '[' group that is never closed by ']'.
Status ValidateGlob(std::string_view glob) {
  for (size_t idx = 0; idx < glob.size(); ++idx) {
    switch (glob[idx]) {
      case '*':
      case '?':
        break;
      case ']':
        return {Status::NotOK, "Unmatched unescaped ]"};
      case '\\':
        if (idx == glob.size() - 1) {
          return {Status::NotOK, "Trailing unescaped backslash"};
        }
        // Skip the next character: this is a literal so nothing can go wrong
        idx++;
        break;
      case '[':
        idx++;  // Skip the opening bracket
        while (idx < glob.size() && glob[idx] != ']') {
          if (glob[idx] == '\\') {
            // Skip the backslash and the escaped character. If the backslash is
            // the last character this moves idx one past glob.size().
            idx += 2;
            continue;
          } else if (idx + 1 < glob.size() && glob[idx + 1] == '-') {
            if (idx + 2 >= glob.size()) {
              return {Status::NotOK, "Unterminated character range"};
            }
            // Skip the - and the end of the range
            idx += 2;
          }
          idx++;
        }
        // Use >= rather than ==: the escape handling above can advance idx beyond
        // glob.size() (e.g. for the pattern "[\"), which is still an unterminated group.
        if (idx >= glob.size()) {
          return {Status::NotOK, "Unterminated [ group"};
        }
        break;
      default:
        // This is a literal: nothing can go wrong
        break;
    }
  }
  return Status::OK();
}

// Glob-style pattern matching.
int StringMatchLen(const char *pattern, size_t pattern_len, const char *string, size_t string_len, int nocase) {
while (pattern_len && string_len) {
constexpr bool StringMatchImpl(std::string_view pattern, std::string_view string, bool ignore_case,
bool *skip_longer_matches, size_t recursion_depth = 0) {
// If we want to ignore case, this is equivalent to converting both the pattern and the string to lowercase
const auto canonicalize = [ignore_case](unsigned char c) -> unsigned char {
return ignore_case ? static_cast<unsigned char>(std::tolower(c)) : c;
};

if (recursion_depth > 1000) return false;

while (!pattern.empty() && !string.empty()) {
switch (pattern[0]) {
case '*':
while (pattern[1] == '*') {
pattern++;
pattern_len--;
// Optimization: collapse multiple * into one
while (pattern.size() >= 2 && pattern[1] == '*') {
pattern.remove_prefix(1);
}

if (pattern_len == 1) return 1; /* match */

while (string_len) {
if (StringMatchLen(pattern + 1, pattern_len - 1, string, string_len, nocase)) return 1; /* match */
string++;
string_len--;
// Optimization: If the '*' is the last character in the pattern, it can match anything
if (pattern.length() == 1) return true;
while (!string.empty()) {
if (StringMatchImpl(pattern.substr(1), string, ignore_case, skip_longer_matches, recursion_depth + 1))
return true;
if (*skip_longer_matches) return false;
string.remove_prefix(1);
}
return 0; /* no match */
// There was no match for the rest of the pattern starting
// from anywhere in the rest of the string. If there were
// any '*' earlier in the pattern, we can terminate the
// search early without trying to match them to longer
// substrings. This is because a longer match for the
// earlier part of the pattern would require the rest of the
// pattern to match starting later in the string, and we
// have just determined that there is no match for the rest
// of the pattern starting from anywhere in the current
// string.
*skip_longer_matches = true;
return false;
case '?':
string++;
string_len--;
if (string.empty()) return false;
string.remove_prefix(1);
break;
case '[': {
pattern++;
pattern_len--;
int not_symbol = pattern[0] == '^';
if (not_symbol) {
pattern++;
pattern_len--;
}
pattern.remove_prefix(1);
const bool invert = pattern[0] == '^';
if (invert) pattern.remove_prefix(1);

int match = 0;
bool match = false;
while (true) {
if (pattern[0] == '\\' && pattern_len >= 2) {
pattern++;
pattern_len--;
if (pattern[0] == string[0]) match = 1;
if (pattern.empty()) {
// unterminated [ group: reject invalid pattern
return false;
} else if (pattern[0] == ']') {
break;
} else if (pattern_len == 0) {
pattern--;
pattern_len++;
break;
} else if (pattern[1] == '-' && pattern_len >= 3) {
int start = pattern[0];
int end = pattern[2];
int c = string[0];
if (start > end) {
int t = start;
start = end;
end = t;
}
if (nocase) {
start = tolower(start);
end = tolower(end);
c = tolower(c);
}
pattern += 2;
pattern_len -= 2;
if (c >= start && c <= end) match = 1;
} else {
if (!nocase) {
if (pattern[0] == string[0]) match = 1;
} else {
if (tolower(static_cast<int>(pattern[0])) == tolower(static_cast<int>(string[0]))) match = 1;
}
} else if (pattern.length() >= 2 && pattern[0] == '\\') {
pattern.remove_prefix(1);
if (pattern[0] == string[0]) match = true;
} else if (pattern.length() >= 3 && pattern[1] == '-') {
unsigned char start = canonicalize(pattern[0]);
unsigned char end = canonicalize(pattern[2]);
if (start > end) std::swap(start, end);
const int c = canonicalize(string[0]);
pattern.remove_prefix(2);

if (c >= start && c <= end) match = true;
} else if (canonicalize(pattern[0]) == canonicalize(string[0])) {
match = true;
}
pattern++;
pattern_len--;
pattern.remove_prefix(1);
}

if (not_symbol) match = !match;

if (!match) return 0; /* no match */

string++;
string_len--;
if (invert) match = !match;
if (!match) return false;
string.remove_prefix(1);
break;
}
case '\\':
if (pattern_len >= 2) {
pattern++;
pattern_len--;
if (pattern.length() >= 2) {
pattern.remove_prefix(1);
}
/* fall through */
[[fallthrough]];
default:
if (!nocase) {
if (pattern[0] != string[0]) return 0; /* no match */
// Just a normal character
if (!ignore_case) {
if (pattern[0] != string[0]) return false;
} else {
if (tolower(static_cast<int>(pattern[0])) != tolower(static_cast<int>(string[0]))) return 0; /* no match */
if (std::tolower((int)pattern[0]) != std::tolower((int)string[0])) return false;
}
string++;
string_len--;
string.remove_prefix(1);
break;
}
pattern++;
pattern_len--;
if (string_len == 0) {
while (*pattern == '*') {
pattern++;
pattern_len--;
}
break;
}
pattern.remove_prefix(1);
}

if (pattern_len == 0 && string_len == 0) return 1;
return 0;
// Now that either the pattern is empty or the string is empty, this is a match iff
// the pattern consists only of '*', and the string is empty.
return string.empty() && std::all_of(pattern.begin(), pattern.end(), [](char c) { return c == '*'; });
}

// Returns true iff [str] matches the glob pattern [glob].
// When [ignore_case] is true, the comparison is case-insensitive.
bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case) {
  // Scratch flag handed to the recursive matcher by pointer; it starts out cleared.
  bool stop_backtracking = false;
  const bool matched = StringMatchImpl(glob, str, ignore_case, &stop_backtracking);
  return matched;
}

// Splits [glob] at its first unescaped wildcard character ('*', '?' or '[').
//
// Returns {prefix, rest} where:
//   - prefix is the literal text before that wildcard, with escaping backslashes removed;
//   - rest is the remainder of the pattern starting at the wildcard, copied verbatim.
// If the pattern contains no unescaped wildcard, rest is empty and prefix is the
// whole pattern with backslashes removed.
//
// Example: for [KEYS bla*bla] this yields {"bla", "*bla"}, so a caller can run a
// prefix scan on "bla" and then filter the candidates with the remaining glob.
std::pair<std::string, std::string> SplitGlob(std::string_view glob) {
  std::string literal_prefix;
  size_t pos = 0;
  while (pos < glob.size()) {
    const char ch = glob[pos];
    switch (ch) {
      case '*':
      case '?':
      case '[':
        // First real wildcard: everything from here on is the glob remainder
        return {literal_prefix, std::string(glob.substr(pos))};
      case '\\':
        // The character after a backslash is taken literally, even if it is a wildcard.
        // A backslash at the very end of the pattern contributes nothing.
        ++pos;
        if (pos < glob.size()) literal_prefix.push_back(glob[pos]);
        break;
      default:
        literal_prefix.push_back(ch);
        break;
    }
    ++pos;
  }
  // The whole pattern was literal
  return {literal_prefix, ""};
}

std::vector<std::string> RegexMatch(const std::string &str, const std::string &regex) {
Expand Down
15 changes: 12 additions & 3 deletions src/common/string_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@

#pragma once

#include "status.h"
#include <cstdint>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "common/status.h"

namespace util {

Expand All @@ -32,8 +38,11 @@ std::string Trim(std::string in, std::string_view chars);
std::vector<std::string> Split(std::string_view in, std::string_view delim);
std::vector<std::string> Split2KV(const std::string &in, const std::string &delim);
bool HasPrefix(const std::string &str, const std::string &prefix);
int StringMatch(const std::string &pattern, const std::string &in, int nocase);
int StringMatchLen(const char *p, size_t plen, const char *s, size_t slen, int nocase);

Status ValidateGlob(std::string_view glob);
bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case = false);
std::pair<std::string, std::string> SplitGlob(std::string_view glob);

std::vector<std::string> RegexMatch(const std::string &str, const std::string &regex);
std::string StringToHex(std::string_view input);
std::vector<std::string> TokenizeRedisProtocol(const std::string &value);
Expand Down
Loading
Loading