Skip to content

Commit

Permalink
feat(scan): Support arbitrary glob patterns (apache#2608)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanlo-hrt authored Oct 29, 2024
1 parent c7b6b22 commit 4aa36ec
Show file tree
Hide file tree
Showing 13 changed files with 448 additions and 139 deletions.
5 changes: 5 additions & 0 deletions .github/config/typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ extend-exclude = [
".git/",
"src/vendor/",
"tests/gocase/util/slot.go",

# Uses short strings for testing glob matching
"tests/cppunit/string_util_test.cc",
"tests/gocase/unit/keyspace/keyspace_test.go",
"tests/gocase/unit/scan/scan_test.go",
]
ignore-hidden = false

Expand Down
16 changes: 8 additions & 8 deletions src/commands/cmd_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
#include "commands/scan_base.h"
#include "common/io_util.h"
#include "common/rdb_stream.h"
#include "common/string_util.h"
#include "common/time_util.h"
#include "config/config.h"
#include "error_constants.h"
#include "server/redis_connection.h"
#include "server/redis_reply.h"
#include "server/server.h"
#include "stats/disk_stats.h"
#include "storage/rdb/rdb.h"
#include "string_util.h"
#include "time_util.h"

namespace redis {

Expand Down Expand Up @@ -114,15 +114,15 @@ class CommandNamespace : public Commander {
class CommandKeys : public Commander {
public:
Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
const std::string &prefix = args_[1];
const std::string &glob_pattern = args_[1];
std::vector<std::string> keys;
redis::Database redis(srv->storage, conn->GetNamespace());

if (prefix.empty() || prefix.find('*') != prefix.size() - 1) {
return {Status::RedisExecErr, "only keys prefix match was supported"};
if (const Status s = util::ValidateGlob(glob_pattern); !s.IsOK()) {
return {Status::RedisParseErr, "Invalid glob pattern: " + s.Msg()};
}

const rocksdb::Status s = redis.Keys(ctx, prefix.substr(0, prefix.size() - 1), &keys);
const auto [prefix, suffix_glob] = util::SplitGlob(glob_pattern);
const rocksdb::Status s = redis.Keys(ctx, prefix, suffix_glob, &keys);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
Expand Down Expand Up @@ -846,7 +846,7 @@ class CommandScan : public CommandScanBase {
std::vector<std::string> keys;
std::string end_key;

auto s = redis_db.Scan(ctx, key_name, limit_, prefix_, &keys, &end_key, type_);
const auto s = redis_db.Scan(ctx, key_name, limit_, prefix_, suffix_glob_, &keys, &end_key, type_);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
Expand Down
14 changes: 7 additions & 7 deletions src/commands/scan_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
#include "commander.h"
#include "commands/command_parser.h"
#include "error_constants.h"
#include "glob.h"
#include "parse_util.h"
#include "server/server.h"
#include "string_util.h"

namespace redis {

Expand All @@ -44,14 +46,11 @@ class CommandScanBase : public Commander {
Status ParseAdditionalFlags(Parser &parser) {
while (parser.Good()) {
if (parser.EatEqICase("match")) {
prefix_ = GET_OR_RET(parser.TakeStr());
// The match pattern should contain exactly one '*' at the end; remove the * to
// get the prefix to match.
if (!prefix_.empty() && prefix_.find('*') == prefix_.size() - 1) {
prefix_.pop_back();
} else {
return {Status::RedisParseErr, "currently only key prefix matching is supported"};
const std::string glob_pattern = GET_OR_RET(parser.TakeStr());
if (const Status s = util::ValidateGlob(glob_pattern); !s.IsOK()) {
return {Status::RedisParseErr, "Invalid glob pattern: " + s.Msg()};
}
std::tie(prefix_, suffix_glob_) = util::SplitGlob(glob_pattern);
} else if (parser.EatEqICase("count")) {
limit_ = GET_OR_RET(parser.TakeInt());
if (limit_ <= 0) {
Expand Down Expand Up @@ -100,6 +99,7 @@ class CommandScanBase : public Commander {
protected:
std::string cursor_;
std::string prefix_;
std::string suffix_glob_ = "*";
int limit_ = 20;
RedisType type_ = kRedisNone;
};
Expand Down
1 change: 0 additions & 1 deletion src/common/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

Expand Down
228 changes: 142 additions & 86 deletions src/common/string_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,118 +101,174 @@ bool HasPrefix(const std::string &str, const std::string &prefix) {
return !strncasecmp(str.data(), prefix.data(), prefix.size());
}

int StringMatch(const std::string &pattern, const std::string &in, int nocase) {
return StringMatchLen(pattern.c_str(), pattern.length(), in.c_str(), in.length(), nocase);
// Check that [glob] is a syntactically well-formed glob pattern.
//
// Returns Status::OK() when the pattern is well-formed, otherwise a Status::NotOK carrying a
// message that describes the first problem found while scanning left to right:
//   - an unescaped ']' outside of a [ group,
//   - a backslash as the very last character,
//   - a character range (e.g. "a-") that runs off the end of the pattern,
//   - a [ group that is never closed by an unescaped ']'.
Status ValidateGlob(std::string_view glob) {
  for (size_t idx = 0; idx < glob.size(); ++idx) {
    switch (glob[idx]) {
      case '*':
      case '?':
        // Wildcards are always valid on their own
        break;
      case ']':
        return {Status::NotOK, "Unmatched unescaped ]"};
      case '\\':
        if (idx == glob.size() - 1) {
          return {Status::NotOK, "Trailing unescaped backslash"};
        }
        // Skip the next character: this is a literal so nothing can go wrong
        idx++;
        break;
      case '[':
        idx++;  // Skip the opening bracket
        while (idx < glob.size() && glob[idx] != ']') {
          if (glob[idx] == '\\') {
            // Skip the backslash and the escaped character. When the backslash is the last
            // character of the pattern this moves idx one past glob.size(); the check after
            // the loop accounts for that.
            idx += 2;
            continue;
          } else if (idx + 1 < glob.size() && glob[idx + 1] == '-') {
            if (idx + 2 >= glob.size()) {
              return {Status::NotOK, "Unterminated character range"};
            }
            // Skip the - and the end of the range
            idx += 2;
          }
          idx++;
        }
        // Use >= rather than ==: an escape at the very end of the pattern (e.g. "[\") leaves
        // idx beyond glob.size(), and such a group is unterminated as well.
        if (idx >= glob.size()) {
          return {Status::NotOK, "Unterminated [ group"};
        }
        break;
      default:
        // This is a literal: nothing can go wrong
        break;
    }
  }
  return Status::OK();
}

// Glob-style pattern matching.
int StringMatchLen(const char *pattern, size_t pattern_len, const char *string, size_t string_len, int nocase) {
while (pattern_len && string_len) {
constexpr bool StringMatchImpl(std::string_view pattern, std::string_view string, bool ignore_case,
bool *skip_longer_matches, size_t recursion_depth = 0) {
// If we want to ignore case, this is equivalent to converting both the pattern and the string to lowercase
const auto canonicalize = [ignore_case](unsigned char c) -> unsigned char {
return ignore_case ? static_cast<unsigned char>(std::tolower(c)) : c;
};

if (recursion_depth > 1000) return false;

while (!pattern.empty() && !string.empty()) {
switch (pattern[0]) {
case '*':
while (pattern[1] == '*') {
pattern++;
pattern_len--;
// Optimization: collapse multiple * into one
while (pattern.size() >= 2 && pattern[1] == '*') {
pattern.remove_prefix(1);
}

if (pattern_len == 1) return 1; /* match */

while (string_len) {
if (StringMatchLen(pattern + 1, pattern_len - 1, string, string_len, nocase)) return 1; /* match */
string++;
string_len--;
// Optimization: If the '*' is the last character in the pattern, it can match anything
if (pattern.length() == 1) return true;
while (!string.empty()) {
if (StringMatchImpl(pattern.substr(1), string, ignore_case, skip_longer_matches, recursion_depth + 1))
return true;
if (*skip_longer_matches) return false;
string.remove_prefix(1);
}
return 0; /* no match */
// There was no match for the rest of the pattern starting
// from anywhere in the rest of the string. If there were
// any '*' earlier in the pattern, we can terminate the
// search early without trying to match them to longer
// substrings. This is because a longer match for the
// earlier part of the pattern would require the rest of the
// pattern to match starting later in the string, and we
// have just determined that there is no match for the rest
// of the pattern starting from anywhere in the current
// string.
*skip_longer_matches = true;
return false;
case '?':
string++;
string_len--;
if (string.empty()) return false;
string.remove_prefix(1);
break;
case '[': {
pattern++;
pattern_len--;
int not_symbol = pattern[0] == '^';
if (not_symbol) {
pattern++;
pattern_len--;
}
pattern.remove_prefix(1);
const bool invert = pattern[0] == '^';
if (invert) pattern.remove_prefix(1);

int match = 0;
bool match = false;
while (true) {
if (pattern[0] == '\\' && pattern_len >= 2) {
pattern++;
pattern_len--;
if (pattern[0] == string[0]) match = 1;
if (pattern.empty()) {
// unterminated [ group: reject invalid pattern
return false;
} else if (pattern[0] == ']') {
break;
} else if (pattern_len == 0) {
pattern--;
pattern_len++;
break;
} else if (pattern[1] == '-' && pattern_len >= 3) {
int start = pattern[0];
int end = pattern[2];
int c = string[0];
if (start > end) {
int t = start;
start = end;
end = t;
}
if (nocase) {
start = tolower(start);
end = tolower(end);
c = tolower(c);
}
pattern += 2;
pattern_len -= 2;
if (c >= start && c <= end) match = 1;
} else {
if (!nocase) {
if (pattern[0] == string[0]) match = 1;
} else {
if (tolower(static_cast<int>(pattern[0])) == tolower(static_cast<int>(string[0]))) match = 1;
}
} else if (pattern.length() >= 2 && pattern[0] == '\\') {
pattern.remove_prefix(1);
if (pattern[0] == string[0]) match = true;
} else if (pattern.length() >= 3 && pattern[1] == '-') {
unsigned char start = canonicalize(pattern[0]);
unsigned char end = canonicalize(pattern[2]);
if (start > end) std::swap(start, end);
const int c = canonicalize(string[0]);
pattern.remove_prefix(2);

if (c >= start && c <= end) match = true;
} else if (canonicalize(pattern[0]) == canonicalize(string[0])) {
match = true;
}
pattern++;
pattern_len--;
pattern.remove_prefix(1);
}

if (not_symbol) match = !match;

if (!match) return 0; /* no match */

string++;
string_len--;
if (invert) match = !match;
if (!match) return false;
string.remove_prefix(1);
break;
}
case '\\':
if (pattern_len >= 2) {
pattern++;
pattern_len--;
if (pattern.length() >= 2) {
pattern.remove_prefix(1);
}
/* fall through */
[[fallthrough]];
default:
if (!nocase) {
if (pattern[0] != string[0]) return 0; /* no match */
// Just a normal character
if (!ignore_case) {
if (pattern[0] != string[0]) return false;
} else {
if (tolower(static_cast<int>(pattern[0])) != tolower(static_cast<int>(string[0]))) return 0; /* no match */
if (std::tolower((int)pattern[0]) != std::tolower((int)string[0])) return false;
}
string++;
string_len--;
string.remove_prefix(1);
break;
}
pattern++;
pattern_len--;
if (string_len == 0) {
while (*pattern == '*') {
pattern++;
pattern_len--;
}
break;
}
pattern.remove_prefix(1);
}

if (pattern_len == 0 && string_len == 0) return 1;
return 0;
// Now that either the pattern is empty or the string is empty, this is a match iff
// the pattern consists only of '*', and the string is empty.
return string.empty() && std::all_of(pattern.begin(), pattern.end(), [](char c) { return c == '*'; });
}

// Public entry point for glob matching: returns true iff [str] matches the pattern [glob].
// When [ignore_case] is true the comparison is case-insensitive.
bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case) {
  // Scratch flag shared by all levels of the recursive matcher; it starts out cleared and
  // the matcher sets it once it has established that trying longer '*' expansions is pointless.
  bool prune_longer_matches = false;
  const bool matched = StringMatchImpl(glob, str, ignore_case, &prune_longer_matches);
  return matched;
}

// Break a glob pattern into {literal prefix, remaining glob}.
//
// The first element is everything before the first unescaped wildcard ('*', '?' or '['), with
// escaping backslashes stripped so it can be compared against raw key bytes. The second element
// is the rest of the pattern, starting at that wildcard and left untouched. If the pattern
// contains no unescaped wildcard, the second element is empty.
//
// Example: for [KEYS bla*bla] this returns {"bla", "*bla"}, so a caller can run a prefix scan
// on "bla" and then filter the candidates with StringMatch.
std::pair<std::string, std::string> SplitGlob(std::string_view glob) {
  std::string literal;
  size_t pos = 0;
  while (pos < glob.size()) {
    // Jump straight to the next character that needs special handling
    const size_t special = glob.find_first_of("*?[\\", pos);
    if (special == std::string_view::npos) {
      // Only plain literals remain
      literal.append(glob.data() + pos, glob.size() - pos);
      break;
    }

    // Everything up to the special character is a plain literal
    literal.append(glob.data() + pos, special - pos);

    if (glob[special] != '\\') {
      // First unescaped wildcard: the rest of the pattern is returned verbatim
      return {literal, std::string(glob.substr(special))};
    }

    // Escape sequence: keep the escaped character (if there is one) and drop the backslash
    if (special + 1 < glob.size()) literal.push_back(glob[special + 1]);
    pos = special + 2;
  }
  // No wildcard found: the whole pattern (minus backslashes) is the prefix
  return {literal, ""};
}

std::vector<std::string> RegexMatch(const std::string &str, const std::string &regex) {
Expand Down
15 changes: 12 additions & 3 deletions src/common/string_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@

#pragma once

#include "status.h"
#include <cstdint>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "common/status.h"

namespace util {

Expand All @@ -32,8 +38,11 @@ std::string Trim(std::string in, std::string_view chars);
std::vector<std::string> Split(std::string_view in, std::string_view delim);
std::vector<std::string> Split2KV(const std::string &in, const std::string &delim);
bool HasPrefix(const std::string &str, const std::string &prefix);
int StringMatch(const std::string &pattern, const std::string &in, int nocase);
int StringMatchLen(const char *p, size_t plen, const char *s, size_t slen, int nocase);

Status ValidateGlob(std::string_view glob);
bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case = false);
std::pair<std::string, std::string> SplitGlob(std::string_view glob);

std::vector<std::string> RegexMatch(const std::string &str, const std::string &regex);
std::string StringToHex(std::string_view input);
std::vector<std::string> TokenizeRedisProtocol(const std::string &value);
Expand Down
Loading

0 comments on commit 4aa36ec

Please sign in to comment.