Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(search): implement vector query for sql/redisearch parser & transformer #2450

Merged
merged 12 commits into from
Aug 2, 2024
26 changes: 26 additions & 0 deletions src/common/parse_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,29 @@ StatusOr<T> ParseFloat(const std::string &str) {

return result;
}

// ParseFloatArray parses a string into an array of floating-point numbers,
// e.g. ParseFloatArray("1,2,3") -> {1.0, 2.0, 3.0}
//
// Elements are parsed one at a time with TryParseFloat<T>. After each element
// at most one ',' is consumed, followed by any run of whitespace. An empty
// string yields an empty vector. A failure reported by TryParseFloat is handed
// back to the caller through GET_OR_RET.
template <typename T = double>
StatusOr<std::vector<T>> ParseFloatArray(const std::string &str) {
  std::vector<T> result;
  const char *current = str.c_str();
  const char *end = current + str.size();

  while (current < end) {
    auto [value, next_pos] = GET_OR_RET(TryParseFloat<T>(current));

    result.push_back(value);

    // Step over a single separator, if present.
    if (next_pos < end && *next_pos == ',') {
      next_pos++;
    }
    current = next_pos;

    // std::isspace is undefined for negative arguments other than EOF, so a
    // (possibly signed) char has to be converted to unsigned char first.
    while (current < end && std::isspace(static_cast<unsigned char>(*current))) {
      current++;
    }
  }

  return result;
}
48 changes: 48 additions & 0 deletions src/search/common_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,54 @@ struct TreeTransformer {

return result;
}

// Decodes a textual byte string into raw bytes.
//
// A four-character sequence `\xHH` (H = hexadecimal digit, either case) is
// decoded into the single byte 0xHH; every other character is copied verbatim.
//
// Errors:
//  - "invalid hexadecimal character" when `\x` is followed by a non-hex digit;
//  - "input string does not align with expected length" when `\x` appears
//    with fewer than two characters left after it.
//
// Unlike the earlier loop (which stopped three characters before the end and
// therefore rejected inputs such as "\x01ab" while accepting "ab\x01"), raw
// bytes in the last three positions are now copied like any other raw byte.
static StatusOr<std::vector<char>> Binary2Chars(std::string_view str) {
  auto hex_char_to_binary = [](char c) -> StatusOr<char> {
    if (c >= '0' && c <= '9') return c - '0';
    if (c >= 'A' && c <= 'F') return 10 + (c - 'A');
    if (c >= 'a' && c <= 'f') return 10 + (c - 'a');
    return {Status::NotOK, "invalid hexadecimal character"};
  };

  std::vector<char> data;
  size_t i = 0;

  while (i < str.size()) {
    const bool starts_escape = str[i] == '\\' && i + 1 < str.size() && str[i + 1] == 'x';
    if (!starts_escape) {
      data.push_back(str[i]);
      i++;
      continue;
    }

    // An escape needs two hex digits after "\x".
    if (i + 3 >= str.size()) {
      return {Status::NotOK, "input string does not align with expected length"};
    }

    auto high = GET_OR_RET(hex_char_to_binary(str[i + 2]));
    auto low = GET_OR_RET(hex_char_to_binary(str[i + 3]));
    data.push_back(static_cast<char>((high << 4) | low));
    i += 4;
  }

  return data;
}

// Reinterprets a byte buffer as consecutive values of type T.
//
// The buffer length must be an exact multiple of sizeof(T); otherwise a NotOK
// status is returned. Each value is produced by copying sizeof(T) bytes with
// std::memcpy, so the bytes are taken exactly as stored in `data`.
template <typename T>
StatusOr<std::vector<T>> CharsToVector(const std::vector<char>& data) {
  constexpr size_t elem_size = sizeof(T);

  if (data.size() % elem_size != 0) {
    return {Status::NotOK, "Data size is not a multiple of the target type size"};
  }

  const size_t elem_count = data.size() / elem_size;

  std::vector<T> converted_data;
  converted_data.reserve(elem_count);

  for (size_t idx = 0; idx < elem_count; ++idx) {
    T value;
    std::memcpy(&value, data.data() + idx * elem_size, elem_size);
    converted_data.push_back(value);
  }

  return converted_data;
}
};

} // namespace kqir
68 changes: 68 additions & 0 deletions src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,63 @@ struct NumericCompareExpr : BoolAtomExpr {
}
};

// Literal node that carries a vector of doubles.
struct VectorLiteral : Literal {
  std::vector<double> values;

  explicit VectorLiteral(std::vector<double> &&values) : values(std::move(values)) {}

  std::string_view Name() const override { return "VectorLiteral"; }

  // Renders the elements (each via std::to_string, combined by
  // util::StringJoin) wrapped in square brackets.
  std::string Dump() const override {
    const auto joined = util::StringJoin(values, [](auto elem) { return std::to_string(elem); });
    return fmt::format("[{}]", joined);
  }

  // Same text as Dump().
  std::string Content() const override { return Dump(); }

  // Copy-constructs a new node from this one.
  std::unique_ptr<Node> Clone() const override { return std::make_unique<VectorLiteral>(*this); }
};

// Boolean atom made of a field reference, a numeric `range` literal and a
// vector literal. The node owns all three children.
struct VectorRangeExpr : BoolAtomExpr {
  std::unique_ptr<FieldRef> field;
  std::unique_ptr<NumericLiteral> range;
  std::unique_ptr<VectorLiteral> vector;

  VectorRangeExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&range,
                  std::unique_ptr<VectorLiteral> &&vector)
      : field(std::move(field)), range(std::move(range)), vector(std::move(vector)) {}

  std::string_view Name() const override { return "VectorRangeExpr"; }

  // Text form: "<field> vector_range <range> <vector>".
  std::string Dump() const override {
    return fmt::format("{} vector_range {} {}", field->Dump(), range->Dump(), vector->Dump());
  }

  // Deep copy: every child is cloned and cast back to its concrete type.
  std::unique_ptr<Node> Clone() const override {
    auto field_copy = Node::MustAs<FieldRef>(field->Clone());
    auto range_copy = Node::MustAs<NumericLiteral>(range->Clone());
    auto vector_copy = Node::MustAs<VectorLiteral>(vector->Clone());
    return std::make_unique<VectorRangeExpr>(std::move(field_copy), std::move(range_copy), std::move(vector_copy));
  }
};

// Boolean atom made of a field reference, a numeric literal `k` and a vector
// literal. The node owns all three children.
struct VectorSearchExpr : BoolAtomExpr {
  // TODO: Support pre-filter for hybrid query
  std::unique_ptr<FieldRef> field;
  std::unique_ptr<NumericLiteral> k;
  std::unique_ptr<VectorLiteral> vector;

  VectorSearchExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&k,
                   std::unique_ptr<VectorLiteral> &&vector)
      : field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {}

  std::string_view Name() const override { return "VectorSearchExpr"; }

  // Text form: "<field> vector_search <k> <vector>".
  std::string Dump() const override {
    return fmt::format("{} vector_search {} {}", field->Dump(), k->Dump(), vector->Dump());
  }

  // Deep copy. The clone must be a VectorSearchExpr; building a
  // VectorRangeExpr here would silently change the node type (and hence
  // Name()/Dump()) of every cloned tree.
  std::unique_ptr<Node> Clone() const override {
    return std::make_unique<VectorSearchExpr>(Node::MustAs<FieldRef>(field->Clone()),
                                              Node::MustAs<NumericLiteral>(k->Clone()),
                                              Node::MustAs<VectorLiteral>(vector->Clone()));
  }
};

struct BoolLiteral : BoolAtomExpr, Literal {
bool val;

Expand Down Expand Up @@ -336,15 +393,22 @@ struct LimitClause : Node {
// Text content of the clause: "<offset>, <count>".
std::string Content() const override { return fmt::format("{}, {}", offset, count); }

// Copy-constructs a new LimitClause from this one.
std::unique_ptr<Node> Clone() const override { return std::make_unique<LimitClause>(*this); }
// Returns the stored `offset` value.
size_t Offset() const { return offset; }

// Returns the stored `count` value.
size_t Count() const { return count; }
};

struct SortByClause : Node {
enum Order { ASC, DESC } order = ASC;
std::unique_ptr<FieldRef> field;
std::unique_ptr<VectorLiteral> vector = nullptr;

// Plain sort: order by `field` in the given direction. `vector` keeps its
// default (nullptr).
SortByClause(Order order, std::unique_ptr<FieldRef> &&field) : order(order), field(std::move(field)) {}
// Vector sort: takes ownership of `field` and `vector`. `order` is not set
// here and therefore keeps its in-class default (ASC).
SortByClause(std::unique_ptr<FieldRef> &&field, std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), vector(std::move(vector)) {}

// Maps ASC to "asc" and any other value to "desc".
static constexpr const char *OrderToString(Order order) { return order == ASC ? "asc" : "desc"; }
// True when this clause holds a vector literal (i.e. `vector` is non-null).
bool IsKnn() const { return vector != nullptr; }
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved

// Node type name.
std::string_view Name() const override { return "SortByClause"; }
// Text form: "sortby <field>, <asc|desc>".
// NOTE(review): `vector` is not part of the output, so a clause built with the
// (field, vector) constructor dumps exactly like a plain ascending sort —
// confirm this is intended.
std::string Dump() const override { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); }
Expand All @@ -356,6 +420,10 @@ struct SortByClause : Node {
std::unique_ptr<Node> Clone() const override {
return std::make_unique<SortByClause>(order, Node::MustAs<FieldRef>(field->Clone()));
}

// Transfers ownership of `field` to the caller. After this call the member
// `field` is null, so methods that dereference it (Dump, Clone) must not be
// used on this clause any more.
std::unique_ptr<FieldRef> GetFieldRef() { return std::move(field); }

// Transfers ownership of `vector` to the caller. After this call the member
// `vector` is null, so IsKnn() returns false from then on.
std::unique_ptr<VectorLiteral> GetVectorLiteral() { return std::move(vector); }
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
};

struct SelectClause : Node {
Expand Down
16 changes: 13 additions & 3 deletions src/search/redis_query_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ namespace redis_query {

using namespace peg;

// Keyword rule: the exact character sequence `VECTOR_RANGE`.
struct VectorRangeKey : string<'V', 'E', 'C', 'T', 'O', 'R', '_', 'R', 'A', 'N', 'G', 'E'> {};
// Keyword rule: the exact character sequence `KNN`.
struct KnnKey : string<'K', 'N', 'N'> {};
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
// The two-character token `=>`; PrefilterExpr uses it to separate the filter
// expression from the KNN clause.
struct ArrowOp : string<'=', '>'> {};
// A single `*` character.
struct Wildcard : one<'*'> {};

// Field reference: `@` immediately followed by an Identifier.
struct Field : seq<one<'@'>, Identifier> {};

// Parameter reference: `$` immediately followed by an Identifier.
struct Param : seq<one<'$'>, Identifier> {};
Expand All @@ -44,9 +49,10 @@ struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
// One bound of a numeric range: Inf, an ExclusiveNumber, or a NumberOrParam
// (alternatives are listed in that order).
struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
// `[` <bound> <bound> `]` — two bounds, each wrapped in WSPad.
struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>, WSPad<NumericRangePart>, one<']'>> {};

// NOTE(review): `FieldQuery` is defined again a few lines below with
// `VectorRange` added to the alternatives. This line looks like the
// pre-change side of a diff — confirm only one definition is kept.
struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<TagList, NumericRange>>> {};
// `[` KNN <number|$param> @field $param `]` — the last Param names the query
// vector (see how Transform2Vector is applied to it in the transformer).
struct KnnSearch : seq<one<'['>, WSPad<KnnKey>, WSPad<NumberOrParam>, WSPad<Field>, WSPad<Param>, one<']'>> {};
// `[` VECTOR_RANGE <number|$param> $param `]` — used as the value part of a
// FieldQuery.
struct VectorRange : seq<one<'['>, WSPad<VectorRangeKey>, WSPad<NumberOrParam>, WSPad<Param>, one<']'>> {};

// NOTE(review): `Wildcard` is also defined earlier in this file, right after
// `ArrowOp`. This line looks like the pre-change side of a diff — confirm
// only one definition is kept.
struct Wildcard : one<'*'> {};
// `@field : <value>` where the value is a VectorRange, a TagList or a
// NumericRange (alternatives are listed in that order).
struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<VectorRange, TagList, NumericRange>>> {};

struct QueryExpr;

Expand All @@ -64,7 +70,11 @@ struct AndExprP : sor<AndExpr, BooleanExpr> {};
// An AndExprP followed by one or more `| AndExprP` groups.
struct OrExpr : seq<AndExprP, plus<seq<one<'|'>, AndExprP>>> {};
// Either a full OrExpr or, failing that, a single AndExprP.
struct OrExprP : sor<OrExpr, AndExprP> {};

// NOTE(review): `QueryExpr` is defined again a few lines below as
// `seq<QueryP>`. This line looks like the pre-change side of a diff — confirm
// only one definition is kept.
struct QueryExpr : seq<OrExprP> {};
// <BooleanExpr> `=>` <KnnSearch>: a filter expression followed by a KNN clause.
struct PrefilterExpr : seq<WSPad<BooleanExpr>, ArrowOp, WSPad<KnnSearch>> {};

// Top-level alternatives: PrefilterExpr is attempted before OrExprP.
struct QueryP : sor<PrefilterExpr, OrExprP> {};

// Entry rule of the query grammar; wraps QueryP.
struct QueryExpr : seq<QueryP> {};

} // namespace redis_query

Expand Down
65 changes: 46 additions & 19 deletions src/search/redis_query_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ namespace redis_query {
namespace ir = kqir;

// NOTE(review): a second `using TreeSelector = ...` with additional rules
// follows immediately below. This alias looks like the pre-change side of a
// diff — confirm only one definition is kept.
template <typename Rule>
using TreeSelector =
parse_tree::selector<Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, ExclusiveNumber, FieldQuery, NotExpr,
AndExpr, OrExpr, Wildcard>>;
// Parse-tree selector for the query grammar. Rules listed under
// `store_content` and `remove_content` are the ones selected for the tree;
// presumably the former keep their matched text and the latter do not (per the
// PEGTL parse_tree helpers) — confirm against the PEGTL documentation.
// The vector-query rules (VectorRange, PrefilterExpr, KnnSearch,
// VectorRangeKey, KnnKey, ArrowOp) are listed under `remove_content`.
using TreeSelector = parse_tree::selector<
Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, VectorRange, ExclusiveNumber, FieldQuery, NotExpr, AndExpr,
OrExpr, PrefilterExpr, KnnSearch, Wildcard, VectorRangeKey, KnnKey, ArrowOp>>;

template <typename Input>
StatusOr<std::unique_ptr<parse_tree::node>> ParseToTree(Input&& in) {
Expand All @@ -53,7 +53,32 @@ StatusOr<std::unique_ptr<parse_tree::node>> ParseToTree(Input&& in) {
struct Transformer : ir::TreeTransformer {
explicit Transformer(const ParamMap& param_map) : TreeTransformer(param_map) {}

// Builds a VectorLiteral from the parameter referenced by `node`.
//
// The parameter text is decoded with Binary2Chars and the resulting bytes are
// reinterpreted as doubles with CharsToVector<double>. Errors from GetParam,
// Binary2Chars and CharsToVector are returned through GET_OR_RET; a vector
// with no elements is rejected.
StatusOr<std::unique_ptr<VectorLiteral>> Transform2Vector(const TreeNode& node) {
  std::string raw_param = GET_OR_RET(GetParam(node));
  auto raw_bytes = GET_OR_RET(Binary2Chars(raw_param));

  std::vector<double> elements = GET_OR_RET(CharsToVector<double>(raw_bytes));
  if (elements.empty()) return {Status::NotOK, "empty vector is invalid"};

  return std::make_unique<ir::VectorLiteral>(std::move(elements));
}

auto Transform(const TreeNode& node) -> StatusOr<std::unique_ptr<Node>> {
auto number_or_param = [this](const TreeNode& node) -> StatusOr<std::unique_ptr<NumericLiteral>> {
if (Is<Number>(node)) {
return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
} else if (Is<Param>(node)) {
auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
.Prefixed(fmt::format("parameter {} is not a number", node->string_view())));

return std::make_unique<ir::NumericLiteral>(val);
} else {
return {Status::NotOK,
fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)};
}
};

if (Is<Number>(node)) {
return Node::Create<ir::NumericLiteral>(*ParseFloat(node->string()));
} else if (Is<Wildcard>(node)) {
Expand Down Expand Up @@ -88,26 +113,12 @@ struct Transformer : ir::TreeTransformer {
} else {
return std::make_unique<ir::OrExpr>(std::move(exprs));
}
} else { // NumericRange
} else if (Is<NumericRange>(query)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

const auto& lhs = query->children[0];
const auto& rhs = query->children[1];

auto number_or_param = [this](const TreeNode& node) -> StatusOr<std::unique_ptr<NumericLiteral>> {
if (Is<Number>(node)) {
return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
} else if (Is<Param>(node)) {
auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
.Prefixed(fmt::format("parameter {} is not a number", node->string_view())));

return std::make_unique<ir::NumericLiteral>(val);
} else {
return {Status::NotOK,
fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)};
}
};

if (Is<ExclusiveNumber>(lhs)) {
exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GT,
std::make_unique<FieldRef>(field),
Expand Down Expand Up @@ -141,11 +152,27 @@ struct Transformer : ir::TreeTransformer {
} else {
return std::make_unique<ir::AndExpr>(std::move(exprs));
}
} else if (Is<VectorRange>(query)) {
return std::make_unique<VectorRangeExpr>(std::make_unique<FieldRef>(field),
GET_OR_RET(number_or_param(query->children[1])),
GET_OR_RET(Transform2Vector(query->children[2])));
}
} else if (Is<NotExpr>(node)) {
CHECK(node->children.size() == 1);

return Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
} else if (Is<PrefilterExpr>(node)) {
CHECK(node->children.size() == 3);

// TODO(Beihao): Support Hybrid Query
// const auto& prefilter = node->children[0];
const auto& knn_search = node->children[2];
CHECK(knn_search->children.size() == 4);

return std::make_unique<VectorSearchExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(number_or_param(knn_search->children[1])),
GET_OR_RET(Transform2Vector(knn_search->children[3])));

} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

Expand Down
9 changes: 7 additions & 2 deletions src/search/sql_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ struct NumericAtomExpr : WSPad<sor<NumberOrParam, Identifier>> {};
// Comparison operators: `!=`, `<=`, `>=`, then the single characters `=`, `<`, `>`.
// The two-character forms are listed first so they are tried before `<` / `>`.
struct NumericCompareOp : sor<string<'!', '='>, string<'<', '='>, string<'>', '='>, one<'=', '<', '>'>> {};
// <NumericAtomExpr> <NumericCompareOp> <NumericAtomExpr>
struct NumericCompareExpr : seq<NumericAtomExpr, NumericCompareOp, NumericAtomExpr> {};

// NOTE(review): `BooleanAtomExpr` is defined again a few lines below with
// `VectorRangeExpr` added. This line looks like the pre-change side of a
// diff — confirm only one definition is kept.
struct BooleanAtomExpr : sor<HasTagExpr, NumericCompareExpr, WSPad<Boolean>> {};
// The three-character operator `<->`. The `sor<>` wrapper currently holds this
// single alternative only.
struct VectorCompareOp : sor<string<'<', '-', '>'>> {};
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
// <identifier> `<->` <string|param>: an identifier paired with a vector
// operand given as a StringOrParam.
struct VectorCompareExpr : seq<WSPad<Identifier>, VectorCompareOp, WSPad<StringOrParam>> {};
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
// <VectorCompareExpr> `<` <number|param>, i.e. `field <-> vector < value`.
struct VectorRangeExpr : seq<VectorCompareExpr, one<'<'>, WSPad<NumberOrParam>> {};

// Atomic boolean expressions, tried in the listed order: HasTagExpr,
// NumericCompareExpr, VectorRangeExpr, then a Boolean literal.
struct BooleanAtomExpr : sor<HasTagExpr, NumericCompareExpr, VectorRangeExpr, WSPad<Boolean>> {};

struct QueryExpr;

Expand Down Expand Up @@ -84,7 +88,8 @@ struct Limit : string<'l', 'i', 'm', 'i', 't'> {};

// The Where keyword followed by a QueryExpr.
struct WhereClause : seq<Where, QueryExpr> {};
// Sort direction keyword: Asc or Desc.
struct AscOrDesc : sor<Asc, Desc> {};
// NOTE(review): `OrderByClause` is defined again a few lines below in terms of
// `OrderByExpr`. This line looks like the pre-change side of a diff — confirm
// only one definition is kept.
struct OrderByClause : seq<OrderBy, WSPad<Identifier>, opt<WSPad<AscOrDesc>>> {};
// Sort key: either a VectorCompareExpr (`field <-> vector`), tried first, or
// an identifier with an optional Asc/Desc direction.
struct OrderByExpr : sor<WSPad<VectorCompareExpr>, seq<WSPad<Identifier>, opt<WSPad<AscOrDesc>>>> {};
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
// The OrderBy keyword followed by an OrderByExpr.
struct OrderByClause : seq<OrderBy, OrderByExpr> {};
// The `limit` keyword, then an optional `<unsigned> ,` prefix, then a required
// <unsigned>.
struct LimitClause : seq<Limit, opt<seq<WSPad<UnsignedInteger>, one<','>>>, WSPad<UnsignedInteger>> {};

struct SearchStmt
Expand Down
Loading
Loading