From 4aa36ec5efd4d4cb6e8250a674715c119d1c3b64 Mon Sep 17 00:00:00 2001 From: Nathan <151768548+nathanlo-hrt@users.noreply.github.com> Date: Mon, 28 Oct 2024 22:22:36 -0400 Subject: [PATCH] feat(scan): Support arbitrary glob patterns (#2608) --- .github/config/typos.toml | 5 + src/commands/cmd_server.cc | 16 +- src/commands/scan_base.h | 14 +- src/common/status.h | 1 - src/common/string_util.cc | 228 ++++++++++++-------- src/common/string_util.h | 15 +- src/config/config.cc | 2 +- src/server/server.cc | 11 +- src/storage/redis_db.cc | 28 ++- src/storage/redis_db.h | 9 +- tests/cppunit/string_util_test.cc | 156 ++++++++++++++ tests/gocase/unit/keyspace/keyspace_test.go | 51 ++++- tests/gocase/unit/scan/scan_test.go | 51 ++++- 13 files changed, 448 insertions(+), 139 deletions(-) diff --git a/.github/config/typos.toml b/.github/config/typos.toml index daae57c8137..035185408e1 100644 --- a/.github/config/typos.toml +++ b/.github/config/typos.toml @@ -20,6 +20,11 @@ extend-exclude = [ ".git/", "src/vendor/", "tests/gocase/util/slot.go", + + # Uses short strings for testing glob matching + "tests/cppunit/string_util_test.cc", + "tests/gocase/unit/keyspace/keyspace_test.go", + "tests/gocase/unit/scan/scan_test.go", ] ignore-hidden = false diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index 3d98a5f56f8..600f23c07a2 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -23,6 +23,8 @@ #include "commands/scan_base.h" #include "common/io_util.h" #include "common/rdb_stream.h" +#include "common/string_util.h" +#include "common/time_util.h" #include "config/config.h" #include "error_constants.h" #include "server/redis_connection.h" @@ -30,8 +32,6 @@ #include "server/server.h" #include "stats/disk_stats.h" #include "storage/rdb/rdb.h" -#include "string_util.h" -#include "time_util.h" namespace redis { @@ -114,15 +114,15 @@ class CommandNamespace : public Commander { class CommandKeys : public Commander { public: Status 
Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { - const std::string &prefix = args_[1]; + const std::string &glob_pattern = args_[1]; std::vector keys; redis::Database redis(srv->storage, conn->GetNamespace()); - if (prefix.empty() || prefix.find('*') != prefix.size() - 1) { - return {Status::RedisExecErr, "only keys prefix match was supported"}; + if (const Status s = util::ValidateGlob(glob_pattern); !s.IsOK()) { + return {Status::RedisParseErr, "Invalid glob pattern: " + s.Msg()}; } - - const rocksdb::Status s = redis.Keys(ctx, prefix.substr(0, prefix.size() - 1), &keys); + const auto [prefix, suffix_glob] = util::SplitGlob(glob_pattern); + const rocksdb::Status s = redis.Keys(ctx, prefix, suffix_glob, &keys); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } @@ -846,7 +846,7 @@ class CommandScan : public CommandScanBase { std::vector keys; std::string end_key; - auto s = redis_db.Scan(ctx, key_name, limit_, prefix_, &keys, &end_key, type_); + const auto s = redis_db.Scan(ctx, key_name, limit_, prefix_, suffix_glob_, &keys, &end_key, type_); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } diff --git a/src/commands/scan_base.h b/src/commands/scan_base.h index b3773b94ff7..5a0d4ca17c1 100644 --- a/src/commands/scan_base.h +++ b/src/commands/scan_base.h @@ -23,8 +23,10 @@ #include "commander.h" #include "commands/command_parser.h" #include "error_constants.h" +#include "glob.h" #include "parse_util.h" #include "server/server.h" +#include "string_util.h" namespace redis { @@ -44,14 +46,11 @@ class CommandScanBase : public Commander { Status ParseAdditionalFlags(Parser &parser) { while (parser.Good()) { if (parser.EatEqICase("match")) { - prefix_ = GET_OR_RET(parser.TakeStr()); - // The match pattern should contain exactly one '*' at the end; remove the * to - // get the prefix to match. 
- if (!prefix_.empty() && prefix_.find('*') == prefix_.size() - 1) { - prefix_.pop_back(); - } else { - return {Status::RedisParseErr, "currently only key prefix matching is supported"}; + const std::string glob_pattern = GET_OR_RET(parser.TakeStr()); + if (const Status s = util::ValidateGlob(glob_pattern); !s.IsOK()) { + return {Status::RedisParseErr, "Invalid glob pattern: " + s.Msg()}; } + std::tie(prefix_, suffix_glob_) = util::SplitGlob(glob_pattern); } else if (parser.EatEqICase("count")) { limit_ = GET_OR_RET(parser.TakeInt()); if (limit_ <= 0) { @@ -100,6 +99,7 @@ class CommandScanBase : public Commander { protected: std::string cursor_; std::string prefix_; + std::string suffix_glob_ = "*"; int limit_ = 20; RedisType type_ = kRedisNone; }; diff --git a/src/common/status.h b/src/common/status.h index 2bd610ad904..aef74033fce 100644 --- a/src/common/status.h +++ b/src/common/status.h @@ -26,7 +26,6 @@ #include #include #include -#include #include #include diff --git a/src/common/string_util.cc b/src/common/string_util.cc index 476e3173bee..cce6440227a 100644 --- a/src/common/string_util.cc +++ b/src/common/string_util.cc @@ -101,118 +101,174 @@ bool HasPrefix(const std::string &str, const std::string &prefix) { return !strncasecmp(str.data(), prefix.data(), prefix.size()); } -int StringMatch(const std::string &pattern, const std::string &in, int nocase) { - return StringMatchLen(pattern.c_str(), pattern.length(), in.c_str(), in.length(), nocase); +Status ValidateGlob(std::string_view glob) { + for (size_t idx = 0; idx < glob.size(); ++idx) { + switch (glob[idx]) { + case '*': + case '?': + break; + case ']': + return {Status::NotOK, "Unmatched unescaped ]"}; + case '\\': + if (idx == glob.size() - 1) { + return {Status::NotOK, "Trailing unescaped backslash"}; + } + // Skip the next character: this is a literal so nothing can go wrong + idx++; + break; + case '[': + idx++; // Skip the opening bracket + while (idx < glob.size() && glob[idx] != ']') { + if 
(glob[idx] == '\\') { + idx += 2; + continue; + } else if (idx + 1 < glob.size() && glob[idx + 1] == '-') { + if (idx + 2 >= glob.size()) { + return {Status::NotOK, "Unterminated character range"}; + } + // Skip the - and the end of the range + idx += 2; + } + idx++; + } + if (idx == glob.size()) { + return {Status::NotOK, "Unterminated [ group"}; + } + break; + default: + // This is a literal: nothing can go wrong + break; + } + } + return Status::OK(); } -// Glob-style pattern matching. -int StringMatchLen(const char *pattern, size_t pattern_len, const char *string, size_t string_len, int nocase) { - while (pattern_len && string_len) { +constexpr bool StringMatchImpl(std::string_view pattern, std::string_view string, bool ignore_case, + bool *skip_longer_matches, size_t recursion_depth = 0) { + // If we want to ignore case, this is equivalent to converting both the pattern and the string to lowercase + const auto canonicalize = [ignore_case](unsigned char c) -> unsigned char { + return ignore_case ? 
static_cast(std::tolower(c)) : c; + }; + + if (recursion_depth > 1000) return false; + + while (!pattern.empty() && !string.empty()) { switch (pattern[0]) { case '*': - while (pattern[1] == '*') { - pattern++; - pattern_len--; + // Optimization: collapse multiple * into one + while (pattern.size() >= 2 && pattern[1] == '*') { + pattern.remove_prefix(1); } - - if (pattern_len == 1) return 1; /* match */ - - while (string_len) { - if (StringMatchLen(pattern + 1, pattern_len - 1, string, string_len, nocase)) return 1; /* match */ - string++; - string_len--; + // Optimization: If the '*' is the last character in the pattern, it can match anything + if (pattern.length() == 1) return true; + while (!string.empty()) { + if (StringMatchImpl(pattern.substr(1), string, ignore_case, skip_longer_matches, recursion_depth + 1)) + return true; + if (*skip_longer_matches) return false; + string.remove_prefix(1); } - return 0; /* no match */ + // There was no match for the rest of the pattern starting + // from anywhere in the rest of the string. If there were + // any '*' earlier in the pattern, we can terminate the + // search early without trying to match them to longer + // substrings. This is because a longer match for the + // earlier part of the pattern would require the rest of the + // pattern to match starting later in the string, and we + // have just determined that there is no match for the rest + // of the pattern starting from anywhere in the current + // string. 
+ *skip_longer_matches = true; + return false; case '?': - string++; - string_len--; + if (string.empty()) return false; + string.remove_prefix(1); break; case '[': { - pattern++; - pattern_len--; - int not_symbol = pattern[0] == '^'; - if (not_symbol) { - pattern++; - pattern_len--; - } + pattern.remove_prefix(1); + const bool invert = pattern[0] == '^'; + if (invert) pattern.remove_prefix(1); - int match = 0; + bool match = false; while (true) { - if (pattern[0] == '\\' && pattern_len >= 2) { - pattern++; - pattern_len--; - if (pattern[0] == string[0]) match = 1; + if (pattern.empty()) { + // unterminated [ group: reject invalid pattern + return false; } else if (pattern[0] == ']') { break; - } else if (pattern_len == 0) { - pattern--; - pattern_len++; - break; - } else if (pattern[1] == '-' && pattern_len >= 3) { - int start = pattern[0]; - int end = pattern[2]; - int c = string[0]; - if (start > end) { - int t = start; - start = end; - end = t; - } - if (nocase) { - start = tolower(start); - end = tolower(end); - c = tolower(c); - } - pattern += 2; - pattern_len -= 2; - if (c >= start && c <= end) match = 1; - } else { - if (!nocase) { - if (pattern[0] == string[0]) match = 1; - } else { - if (tolower(static_cast(pattern[0])) == tolower(static_cast(string[0]))) match = 1; - } + } else if (pattern.length() >= 2 && pattern[0] == '\\') { + pattern.remove_prefix(1); + if (pattern[0] == string[0]) match = true; + } else if (pattern.length() >= 3 && pattern[1] == '-') { + unsigned char start = canonicalize(pattern[0]); + unsigned char end = canonicalize(pattern[2]); + if (start > end) std::swap(start, end); + const int c = canonicalize(string[0]); + pattern.remove_prefix(2); + + if (c >= start && c <= end) match = true; + } else if (canonicalize(pattern[0]) == canonicalize(string[0])) { + match = true; } - pattern++; - pattern_len--; + pattern.remove_prefix(1); } - - if (not_symbol) match = !match; - - if (!match) return 0; /* no match */ - - string++; - 
string_len--; + if (invert) match = !match; + if (!match) return false; + string.remove_prefix(1); break; } case '\\': - if (pattern_len >= 2) { - pattern++; - pattern_len--; + if (pattern.length() >= 2) { + pattern.remove_prefix(1); } - /* fall through */ + [[fallthrough]]; default: - if (!nocase) { - if (pattern[0] != string[0]) return 0; /* no match */ + // Just a normal character + if (!ignore_case) { + if (pattern[0] != string[0]) return false; } else { - if (tolower(static_cast(pattern[0])) != tolower(static_cast(string[0]))) return 0; /* no match */ + if (std::tolower((int)pattern[0]) != std::tolower((int)string[0])) return false; } - string++; - string_len--; + string.remove_prefix(1); break; } - pattern++; - pattern_len--; - if (string_len == 0) { - while (*pattern == '*') { - pattern++; - pattern_len--; - } - break; - } + pattern.remove_prefix(1); } - if (pattern_len == 0 && string_len == 0) return 1; - return 0; + // Now that either the pattern is empty or the string is empty, this is a match iff + // the pattern consists only of '*', and the string is empty. + return string.empty() && std::all_of(pattern.begin(), pattern.end(), [](char c) { return c == '*'; }); +} + +// Given a glob [glob] and a string [str], return true iff the string matches the glob. +// If [ignore_case] is true, the match is case-insensitive. +bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case) { + bool skip_longer_matches = false; + return StringMatchImpl(glob, str, ignore_case, &skip_longer_matches); +} + +// Split a glob pattern into a literal prefix and a suffix containing wildcards. +// For example, if the user calls [KEYS bla*bla], this function will return {"bla", "*bla"}. +// This allows the caller of this function to optimize this call by performing a +// prefix-scan on "bla" and then filtering the results using the StringMatch function.
+std::pair SplitGlob(std::string_view glob) { + // Stores the prefix of the glob pattern, with backslashes removed + std::string prefix; + // Find the first un-escaped '*', '?' or '[' character in [glob] + for (size_t idx = 0; idx < glob.size(); ++idx) { + if (glob[idx] == '*' || glob[idx] == '?' || glob[idx] == '[') { + // Return a pair of strings: the unescaped literal prefix, and the rest of the glob starting at the wildcard + return {prefix, std::string(glob.substr(idx))}; + } else if (glob[idx] == '\\') { + // Skip checking whether the next character is a special character + ++idx; + // Append the escaped special character to the prefix + if (idx < glob.size()) prefix.push_back(glob[idx]); + } else { + prefix.push_back(glob[idx]); + } + } + // No wildcard found, return the entire string (without the backslashes) as the prefix + return {prefix, ""}; } std::vector RegexMatch(const std::string &str, const std::string &regex) { diff --git a/src/common/string_util.h b/src/common/string_util.h index 2dcb10800f5..f86590ad046 100644 --- a/src/common/string_util.h +++ b/src/common/string_util.h @@ -20,7 +20,13 @@ #pragma once -#include "status.h" +#include +#include +#include +#include +#include + +#include "common/status.h" namespace util { @@ -32,8 +38,11 @@ std::string Trim(std::string in, std::string_view chars); std::vector Split(std::string_view in, std::string_view delim); std::vector Split2KV(const std::string &in, const std::string &delim); bool HasPrefix(const std::string &str, const std::string &prefix); -int StringMatch(const std::string &pattern, const std::string &in, int nocase); -int StringMatchLen(const char *p, size_t plen, const char *s, size_t slen, int nocase); + +Status ValidateGlob(std::string_view glob); +bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case = false); +std::pair SplitGlob(std::string_view glob); + std::vector RegexMatch(const std::string &str, const std::string &regex); std::string StringToHex(std::string_view input); std::vector
TokenizeRedisProtocol(const std::string &value); diff --git a/src/config/config.cc b/src/config/config.cc index 57b2c7b08d6..f14dc78c614 100644 --- a/src/config/config.cc +++ b/src/config/config.cc @@ -904,7 +904,7 @@ Status Config::Load(const CLIOptions &opts) { void Config::Get(const std::string &key, std::vector *values) const { values->clear(); for (const auto &iter : fields_) { - if (util::StringMatch(key, iter.first, 1)) { + if (util::StringMatch(key, iter.first, true)) { if (iter.second->IsMultiConfig()) { for (const auto &p : util::Split(iter.second->ToString(), "\n")) { values->emplace_back(iter.first); diff --git a/src/server/server.cc b/src/server/server.cc index 1de5534dc54..e569d12e5f9 100644 --- a/src/server/server.cc +++ b/src/server/server.cc @@ -38,7 +38,7 @@ #include #include "commands/commander.h" -#include "config.h" +#include "common/string_util.h" #include "config/config.h" #include "fmt/format.h" #include "redis_connection.h" @@ -46,7 +46,6 @@ #include "storage/redis_db.h" #include "storage/scripting.h" #include "storage/storage.h" -#include "string_util.h" #include "thread_util.h" #include "time_util.h" #include "version.h" @@ -160,7 +159,7 @@ Status Server::Start() { if (!config_->cluster_enabled) { engine::Context no_txn_ctx = engine::Context::NoTransactionContext(storage); GET_OR_RET(index_mgr.Load(no_txn_ctx, kDefaultNamespace)); - for (auto [_, ns] : namespace_.List()) { + for (const auto &[_, ns] : namespace_.List()) { GET_OR_RET(index_mgr.Load(no_txn_ctx, ns)); } } @@ -391,7 +390,7 @@ int Server::PublishMessage(const std::string &channel, const std::string &msg) { std::vector patterns; std::vector to_publish_patterns_conn_ctxs; for (const auto &iter : pubsub_patterns_) { - if (util::StringMatch(iter.first, channel, 0)) { + if (util::StringMatch(iter.first, channel, false)) { for (const auto &conn_ctx : iter.second) { to_publish_patterns_conn_ctxs.emplace_back(conn_ctx); patterns.emplace_back(iter.first); @@ -463,7 +462,7 @@ void 
Server::GetChannelsByPattern(const std::string &pattern, std::vector guard(pubsub_channels_mu_); for (const auto &iter : pubsub_channels_) { - if (pattern.empty() || util::StringMatch(pattern, iter.first, 0)) { + if (pattern.empty() || util::StringMatch(pattern, iter.first, false)) { channels->emplace_back(iter.first); } } @@ -549,7 +548,7 @@ void Server::GetSChannelsByPattern(const std::string &pattern, std::vectoremplace_back(iter.first); } } diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index 7fe83477e9a..5eabd8d8c94 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -21,16 +21,15 @@ #include "redis_db.h" #include -#include #include #include "cluster/redis_slot.h" #include "common/scope_exit.h" +#include "common/string_util.h" #include "db_util.h" #include "parse_util.h" #include "rocksdb/iterator.h" #include "rocksdb/status.h" -#include "server/server.h" #include "storage/iterator.h" #include "storage/redis_metadata.h" #include "storage/storage.h" @@ -249,11 +248,11 @@ rocksdb::Status Database::GetExpireTime(engine::Context &ctx, const Slice &user_ } rocksdb::Status Database::GetKeyNumStats(engine::Context &ctx, const std::string &prefix, KeyNumStats *stats) { - return Keys(ctx, prefix, nullptr, stats); + return Keys(ctx, prefix, "*", nullptr, stats); } -rocksdb::Status Database::Keys(engine::Context &ctx, const std::string &prefix, std::vector *keys, - KeyNumStats *stats) { +rocksdb::Status Database::Keys(engine::Context &ctx, const std::string &prefix, const std::string &suffix_glob, + std::vector *keys, KeyNumStats *stats) { uint16_t slot_id = 0; std::string ns_prefix; if (namespace_ != kDefaultNamespace || keys != nullptr) { @@ -277,6 +276,10 @@ rocksdb::Status Database::Keys(engine::Context &ctx, const std::string &prefix, if (!ns_prefix.empty() && !iter->key().starts_with(ns_prefix)) { break; } + auto [_, user_key] = ExtractNamespaceKey(iter->key(), storage_->IsSlotIdEncoded()); + if (!util::StringMatch(suffix_glob, 
user_key.ToString().substr(prefix.size()))) { + continue; + } Metadata metadata(kRedisNone, false); auto s = metadata.Decode(iter->value()); if (!s.ok()) continue; @@ -293,7 +296,6 @@ rocksdb::Status Database::Keys(engine::Context &ctx, const std::string &prefix, } } if (keys) { - auto [_, user_key] = ExtractNamespaceKey(iter->key(), storage_->IsSlotIdEncoded()); keys->emplace_back(user_key.ToString()); } } @@ -319,8 +321,8 @@ rocksdb::Status Database::Keys(engine::Context &ctx, const std::string &prefix, } rocksdb::Status Database::Scan(engine::Context &ctx, const std::string &cursor, uint64_t limit, - const std::string &prefix, std::vector *keys, std::string *end_cursor, - RedisType type) { + const std::string &prefix, const std::string &suffix_glob, + std::vector *keys, std::string *end_cursor, RedisType type) { end_cursor->clear(); uint64_t cnt = 0; uint16_t slot_start = 0; @@ -366,6 +368,10 @@ rocksdb::Status Database::Scan(engine::Context &ctx, const std::string &cursor, if (metadata.Expired()) continue; std::tie(std::ignore, user_key) = ExtractNamespaceKey(iter->key(), storage_->IsSlotIdEncoded()); + + if (!util::StringMatch(suffix_glob, user_key.substr(prefix.size()))) { + continue; + } keys->emplace_back(user_key); cnt++; } @@ -395,7 +401,7 @@ rocksdb::Status Database::Scan(engine::Context &ctx, const std::string &cursor, if (iter->Valid()) { std::tie(std::ignore, user_key) = ExtractNamespaceKey(iter->key(), storage_->IsSlotIdEncoded()); auto res = std::mismatch(prefix.begin(), prefix.end(), user_key.begin()); - if (res.first == prefix.end()) { + if (res.first == prefix.end() && util::StringMatch(suffix_glob, user_key.substr(prefix.size()))) { keys->emplace_back(user_key); } @@ -420,13 +426,13 @@ rocksdb::Status Database::RandomKey(engine::Context &ctx, const std::string &cur std::string end_cursor; std::vector keys; - auto s = Scan(ctx, cursor, RANDOM_KEY_SCAN_LIMIT, "", &keys, &end_cursor); + auto s = Scan(ctx, cursor, RANDOM_KEY_SCAN_LIMIT, "", "*", 
&keys, &end_cursor); if (!s.ok()) { return s; } if (keys.empty() && !cursor.empty()) { // if reach the end, restart from beginning - s = Scan(ctx, "", RANDOM_KEY_SCAN_LIMIT, "", &keys, &end_cursor); + s = Scan(ctx, "", RANDOM_KEY_SCAN_LIMIT, "", "*", &keys, &end_cursor); if (!s.ok()) { return s; } diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index 7111fed1099..41ed3daeb24 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -20,7 +20,6 @@ #pragma once -#include #include #include #include @@ -29,7 +28,6 @@ #include "cluster/cluster_defs.h" #include "redis_metadata.h" -#include "server/redis_reply.h" #include "storage.h" namespace redis { @@ -119,11 +117,12 @@ class Database { [[nodiscard]] rocksdb::Status FlushDB(engine::Context &ctx); [[nodiscard]] rocksdb::Status FlushAll(engine::Context &ctx); [[nodiscard]] rocksdb::Status GetKeyNumStats(engine::Context &ctx, const std::string &prefix, KeyNumStats *stats); - [[nodiscard]] rocksdb::Status Keys(engine::Context &ctx, const std::string &prefix, + [[nodiscard]] rocksdb::Status Keys(engine::Context &ctx, const std::string &prefix, const std::string &suffix_glob, std::vector *keys = nullptr, KeyNumStats *stats = nullptr); [[nodiscard]] rocksdb::Status Scan(engine::Context &ctx, const std::string &cursor, uint64_t limit, - const std::string &prefix, std::vector *keys, - std::string *end_cursor = nullptr, RedisType type = kRedisNone); + const std::string &prefix, const std::string &suffix_glob, + std::vector *keys, std::string *end_cursor = nullptr, + RedisType type = kRedisNone); [[nodiscard]] rocksdb::Status RandomKey(engine::Context &ctx, const std::string &cursor, std::string *key); std::string AppendNamespacePrefix(const Slice &user_key); [[nodiscard]] rocksdb::Status ClearKeysOfSlotRange(engine::Context &ctx, const rocksdb::Slice &ns, diff --git a/tests/cppunit/string_util_test.cc b/tests/cppunit/string_util_test.cc index f95ccbff3dd..1d24cf594e4 100644 --- 
a/tests/cppunit/string_util_test.cc +++ b/tests/cppunit/string_util_test.cc @@ -22,6 +22,7 @@ #include +#include #include #include #include @@ -84,6 +85,161 @@ TEST(StringUtil, HasPrefix) { ASSERT_FALSE(util::HasPrefix("has", "has_prefix")); } +TEST(StringUtil, ValidateGlob) { + const auto expect_ok = [](std::string_view glob) { + const auto result = util::ValidateGlob(glob); + EXPECT_TRUE(result.IsOK()) << glob << ": " << result.Msg(); + }; + + const auto expect_error = [](std::string_view glob, std::string_view expected_error) { + const auto result = util::ValidateGlob(glob); + EXPECT_FALSE(result.IsOK()); + EXPECT_EQ(result.Msg(), expected_error) << glob; + }; + + expect_ok("a"); + expect_ok("\\*"); + expect_ok("\\?"); + expect_ok("\\["); + expect_ok("\\]"); + expect_ok("a*"); + expect_ok("a?"); + expect_ok("[ab]"); + expect_ok("[^ab]"); + expect_ok("[a-c]"); + // Surprisingly valid: this accepts the characters {a, b, c, e, f, g, -} + expect_ok("[a-c-e-g]"); + expect_ok("[^a-c]"); + expect_ok("[-]"); + expect_ok("[\\]]"); + expect_ok("[\\\\]"); + expect_ok("[\\?]"); + expect_ok("[\\*]"); + expect_ok("[\\[]"); + + expect_error("[", "Unterminated [ group"); + expect_error("]", "Unmatched unescaped ]"); + expect_error("[a", "Unterminated [ group"); + expect_error("\\", "Trailing unescaped backslash"); + + // Weird case: we open a character class, with the range 'a' to ']', but then never close it + expect_error("[a-]", "Unterminated [ group"); + expect_ok("[a-]]"); +} + +TEST(StringUtil, StringMatch) { + /* Some basic tests */ + EXPECT_TRUE(util::StringMatch("a", "a")); + EXPECT_FALSE(util::StringMatch("a", "b")); + EXPECT_FALSE(util::StringMatch("a", "aa")); + EXPECT_FALSE(util::StringMatch("a", "")); + EXPECT_TRUE(util::StringMatch("", "")); + EXPECT_FALSE(util::StringMatch("", "a")); + EXPECT_TRUE(util::StringMatch("*", "")); + EXPECT_TRUE(util::StringMatch("*", "a")); + + /* Simple character class tests */ + EXPECT_TRUE(util::StringMatch("[a]", "a")); + 
EXPECT_FALSE(util::StringMatch("[a]", "b")); + EXPECT_FALSE(util::StringMatch("[^a]", "a")); + EXPECT_TRUE(util::StringMatch("[^a]", "b")); + EXPECT_TRUE(util::StringMatch("[ab]", "a")); + EXPECT_TRUE(util::StringMatch("[ab]", "b")); + EXPECT_FALSE(util::StringMatch("[ab]", "c")); + EXPECT_TRUE(util::StringMatch("[^ab]", "c")); + EXPECT_TRUE(util::StringMatch("[a-c]", "b")); + EXPECT_FALSE(util::StringMatch("[a-c]", "d")); + + /* Corner cases in character class parsing */ + EXPECT_TRUE(util::StringMatch("[a-c-e-g]", "-")); + EXPECT_FALSE(util::StringMatch("[a-c-e-g]", "d")); + EXPECT_TRUE(util::StringMatch("[a-c-e-g]", "f")); + + /* Escaping */ + EXPECT_TRUE(util::StringMatch("\\?", "?")); + EXPECT_FALSE(util::StringMatch("\\?", "a")); + EXPECT_TRUE(util::StringMatch("\\*", "*")); + EXPECT_FALSE(util::StringMatch("\\*", "a")); + EXPECT_TRUE(util::StringMatch("\\[", "[")); + EXPECT_TRUE(util::StringMatch("\\]", "]")); + EXPECT_TRUE(util::StringMatch("\\\\", "\\")); + EXPECT_TRUE(util::StringMatch("[\\.]", ".")); + EXPECT_TRUE(util::StringMatch("[\\-]", "-")); + EXPECT_TRUE(util::StringMatch("[\\[]", "[")); + EXPECT_TRUE(util::StringMatch("[\\]]", "]")); + EXPECT_TRUE(util::StringMatch("[\\\\]", "\\")); + EXPECT_TRUE(util::StringMatch("[\\?]", "?")); + EXPECT_TRUE(util::StringMatch("[\\*]", "*")); + + /* Simple wild cards */ + EXPECT_TRUE(util::StringMatch("?", "a")); + EXPECT_FALSE(util::StringMatch("?", "aa")); + EXPECT_FALSE(util::StringMatch("??", "a")); + EXPECT_TRUE(util::StringMatch("?x?", "axb")); + EXPECT_FALSE(util::StringMatch("?x?", "abx")); + EXPECT_FALSE(util::StringMatch("?x?", "xab")); + + /* Asterisk wild cards (backtracking) */ + EXPECT_FALSE(util::StringMatch("*??", "a")); + EXPECT_TRUE(util::StringMatch("*??", "ab")); + EXPECT_TRUE(util::StringMatch("*??", "abc")); + EXPECT_TRUE(util::StringMatch("*??", "abcd")); + EXPECT_FALSE(util::StringMatch("??*", "a")); + EXPECT_TRUE(util::StringMatch("??*", "ab")); + EXPECT_TRUE(util::StringMatch("??*", 
"abc")); + EXPECT_TRUE(util::StringMatch("??*", "abcd")); + EXPECT_FALSE(util::StringMatch("?*?", "a")); + EXPECT_TRUE(util::StringMatch("?*?", "ab")); + EXPECT_TRUE(util::StringMatch("?*?", "abc")); + EXPECT_TRUE(util::StringMatch("?*?", "abcd")); + EXPECT_TRUE(util::StringMatch("*b", "b")); + EXPECT_TRUE(util::StringMatch("*b", "ab")); + EXPECT_FALSE(util::StringMatch("*b", "ba")); + EXPECT_TRUE(util::StringMatch("*b", "bb")); + EXPECT_TRUE(util::StringMatch("*b", "abb")); + EXPECT_TRUE(util::StringMatch("*b", "bab")); + EXPECT_TRUE(util::StringMatch("*bc", "abbc")); + EXPECT_TRUE(util::StringMatch("*bc", "bc")); + EXPECT_TRUE(util::StringMatch("*bc", "bbc")); + EXPECT_TRUE(util::StringMatch("*bc", "bcbc")); + + /* Multiple asterisks (complex backtracking) */ + EXPECT_TRUE(util::StringMatch("*ac*", "abacadaeafag")); + EXPECT_TRUE(util::StringMatch("*ac*ae*ag*", "abacadaeafag")); + EXPECT_TRUE(util::StringMatch("*a*b*[bc]*[ef]*g*", "abacadaeafag")); + EXPECT_FALSE(util::StringMatch("*a*b*[ef]*[cd]*g*", "abacadaeafag")); + EXPECT_TRUE(util::StringMatch("*abcd*", "abcabcabcabcdefg")); + EXPECT_TRUE(util::StringMatch("*ab*cd*", "abcabcabcabcdefg")); + EXPECT_TRUE(util::StringMatch("*abcd*abcdef*", "abcabcdabcdeabcdefg")); + EXPECT_FALSE(util::StringMatch("*abcd*", "abcabcabcabcefg")); + EXPECT_FALSE(util::StringMatch("*ab*cd*", "abcabcabcabcefg")); + + /* Robustness to exponential blow-ups with lots of non-collapsible asterisks */ + EXPECT_TRUE( + util::StringMatch("?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*a", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")); + EXPECT_FALSE( + util::StringMatch("?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*b", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")); +} + +TEST(StringUtil, SplitGlob) { + using namespace std::string_literals; + + // Basic functionality: no escaped characters + EXPECT_EQ(util::SplitGlob(""), std::make_pair(""s, ""s)); + EXPECT_EQ(util::SplitGlob("string"), std::make_pair("string"s, ""s)); + 
EXPECT_EQ(util::SplitGlob("string*"), std::make_pair("string"s, "*"s)); + EXPECT_EQ(util::SplitGlob("*string"), std::make_pair(""s, "*string"s)); + EXPECT_EQ(util::SplitGlob("str*ing"), std::make_pair("str"s, "*ing"s)); + EXPECT_EQ(util::SplitGlob("string?"), std::make_pair("string"s, "?"s)); + EXPECT_EQ(util::SplitGlob("?string"), std::make_pair(""s, "?string"s)); + EXPECT_EQ(util::SplitGlob("ab[cd]ef"), std::make_pair("ab"s, "[cd]ef"s)); + + // Escaped characters; also tests that prefix is trimmed of backslashes + EXPECT_EQ(util::SplitGlob("str\\*ing*"), std::make_pair("str*ing"s, "*"s)); + EXPECT_EQ(util::SplitGlob("str\\?ing?"), std::make_pair("str?ing"s, "?"s)); + EXPECT_EQ(util::SplitGlob("str\\[ing[a]"), std::make_pair("str[ing"s, "[a]"s)); +} + TEST(StringUtil, EscapeString) { std::unordered_map origin_to_escaped = { {"abc", "abc"}, diff --git a/tests/gocase/unit/keyspace/keyspace_test.go b/tests/gocase/unit/keyspace/keyspace_test.go index 6fbb84a7b94..37b86afd1d8 100644 --- a/tests/gocase/unit/keyspace/keyspace_test.go +++ b/tests/gocase/unit/keyspace/keyspace_test.go @@ -27,6 +27,7 @@ import ( "github.com/apache/kvrocks/tests/gocase/util" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" ) func TestKeyspace(t *testing.T) { @@ -65,10 +66,6 @@ func TestKeyspace(t *testing.T) { require.Equal(t, []string{"foo_a", "foo_b", "foo_c"}, keys) }) - t.Run("KEYS with invalid pattern", func(t *testing.T) { - require.Error(t, rdb.Keys(ctx, "*ab*").Err()) - }) - t.Run("KEYS to get all keys", func(t *testing.T) { keys := rdb.Keys(ctx, "*").Val() sort.Slice(keys, func(i, j int) bool { @@ -77,12 +74,58 @@ func TestKeyspace(t *testing.T) { require.Equal(t, []string{"foo_a", "foo_b", "foo_c", "key_x", "key_y", "key_z"}, keys) }) + t.Run("KEYS with invalid patterns", func(t *testing.T) { + require.Error(t, rdb.Keys(ctx, "[").Err()) + require.Error(t, rdb.Keys(ctx, "\\").Err()) + require.Error(t, rdb.Keys(ctx, "[a-]").Err()) + require.Error(t, rdb.Keys(ctx, 
"[a").Err()) + }) + t.Run("DBSize", func(t *testing.T) { require.NoError(t, rdb.Do(ctx, "dbsize", "scan").Err()) time.Sleep(100 * time.Millisecond) require.EqualValues(t, 6, rdb.Do(ctx, "dbsize").Val()) }) + t.Run("KEYS with non-trivial patterns", func(t *testing.T) { + require.NoError(t, rdb.FlushDB(ctx).Err()) + for _, key := range []string{"aa", "aab", "aabb", "ab", "abb"} { + require.NoError(t, rdb.Set(ctx, key, "hello", 0).Err()) + } + + keys := rdb.Keys(ctx, "a*").Val() + slices.Sort(keys) + require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"}, keys) + + keys = rdb.Keys(ctx, "aa").Val() + slices.Sort(keys) + require.Equal(t, []string{"aa"}, keys) + + keys = rdb.Keys(ctx, "aa*").Val() + slices.Sort(keys) + require.Equal(t, []string{"aa", "aab", "aabb"}, keys) + + keys = rdb.Keys(ctx, "a?").Val() + slices.Sort(keys) + require.Equal(t, []string{"aa", "ab"}, keys) + + keys = rdb.Keys(ctx, "a*?").Val() + slices.Sort(keys) + require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"}, keys) + + keys = rdb.Keys(ctx, "ab*").Val() + slices.Sort(keys) + require.Equal(t, []string{"ab", "abb"}, keys) + + keys = rdb.Keys(ctx, "*ab").Val() + slices.Sort(keys) + require.Equal(t, []string{"aab", "ab"}, keys) + + keys = rdb.Keys(ctx, "*ab*").Val() + slices.Sort(keys) + require.Equal(t, []string{"aab", "aabb", "ab", "abb"}, keys) + }) + t.Run("DEL all keys", func(t *testing.T) { vals := rdb.Keys(ctx, "*").Val() require.EqualValues(t, len(vals), rdb.Del(ctx, vals...).Val()) diff --git a/tests/gocase/unit/scan/scan_test.go b/tests/gocase/unit/scan/scan_test.go index cade5dcc5cc..5d2f9fb9a1f 100644 --- a/tests/gocase/unit/scan/scan_test.go +++ b/tests/gocase/unit/scan/scan_test.go @@ -77,7 +77,6 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx context.Context) { require.NoError(t, rdb.FlushDB(ctx).Err()) util.Populate(t, rdb, "", 1000, 10) keys := scanAll(t, rdb) - keys = slices.Compact(keys) require.Len(t, keys, 1000) }) @@ -85,7 +84,6 @@ func ScanTest(t *testing.T, 
rdb *redis.Client, ctx context.Context) { require.NoError(t, rdb.FlushDB(ctx).Err()) util.Populate(t, rdb, "", 1000, 10) keys := scanAll(t, rdb, "count", 5) - keys = slices.Compact(keys) require.Len(t, keys, 1000) }) @@ -93,15 +91,46 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx context.Context) { require.NoError(t, rdb.FlushDB(ctx).Err()) util.Populate(t, rdb, "key:", 1000, 10) keys := scanAll(t, rdb, "match", "key:*") - keys = slices.Compact(keys) require.Len(t, keys, 1000) }) - t.Run("SCAN MATCH invalid pattern", func(t *testing.T) { + t.Run("SCAN MATCH non-trivial pattern", func(t *testing.T) { require.NoError(t, rdb.FlushDB(ctx).Err()) - util.Populate(t, rdb, "*ab", 1000, 10) - // SCAN MATCH with invalid pattern should return an error - require.Error(t, rdb.Do(context.Background(), "SCAN", "match", "*ab*").Err()) + + for _, key := range []string{"aa", "aab", "aabb", "ab", "abb", "ba"} { + require.NoError(t, rdb.Set(ctx, key, "hello", 0).Err()) + } + + keys := scanAll(t, rdb, "match", "a*") + require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"}, keys) + + keys = scanAll(t, rdb, "match", "aa") + require.Equal(t, []string{"aa"}, keys) + + keys = scanAll(t, rdb, "match", "aa*") + require.Equal(t, []string{"aa", "aab", "aabb"}, keys) + + keys = scanAll(t, rdb, "match", "a?") + require.Equal(t, []string{"aa", "ab"}, keys) + + keys = scanAll(t, rdb, "match", "a*?") + require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"}, keys) + + keys = scanAll(t, rdb, "match", "ab*") + require.Equal(t, []string{"ab", "abb"}, keys) + + keys = scanAll(t, rdb, "match", "*ab") + require.Equal(t, []string{"aab", "ab"}, keys) + + keys = scanAll(t, rdb, "match", "*ab*") + require.Equal(t, []string{"aab", "aabb", "ab", "abb"}, keys) + + // Special case: using [b]* instead of b* forces a full scan of the keyspace, + // matching every result with the pattern. We ask for exactly one key, but the + // first 5 keys don't match the pattern.
This tests that SCAN returns a valid + // cursor even when the first [limit] keys don't satisfy the pattern. + keys = scanAll(t, rdb, "match", "[b]*", "count", "1") + require.Equal(t, []string{"ba"}, keys) }) t.Run("SCAN guarantees check under write load", func(t *testing.T) { @@ -226,6 +255,7 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx context.Context) { require.NoError(t, rdb.SAdd(ctx, "set", elements...).Err()) keys, _, err := rdb.SScan(ctx, "set", 0, "", 10000).Result() require.NoError(t, err) + slices.Sort(keys) keys = slices.Compact(keys) require.Len(t, keys, 100) }) @@ -307,6 +337,11 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx context.Context) { require.NoError(t, rdb.Do(ctx, "SCAN", "0", "match", "a*", "count", "1").Err()) util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "count", "1", "match", "a*", "hello").Err(), ".*syntax error.*") util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "count", "1", "match", "a*", "hello", "hi").Err(), ".*syntax error.*") + + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match", "[").Err(), ".*Invalid glob pattern.*") + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match", "\\").Err(), ".*Invalid glob pattern.*") + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match", "[a").Err(), ".*Invalid glob pattern.*") + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match", "[a-]").Err(), ".*Invalid glob pattern.*") }) t.Run("SCAN with type args ", func(t *testing.T) { @@ -406,6 +441,8 @@ func scanAll(t testing.TB, rdb *redis.Client, args ...interface{}) (keys []strin keys = append(keys, keyList...) if c == "0" { + slices.Sort(keys) + keys = slices.Compact(keys) return } }