diff --git a/src/search/common_transformer.h b/src/search/common_transformer.h index 5ebbcff6eb4..18b2626d7ba 100644 --- a/src/search/common_transformer.h +++ b/src/search/common_transformer.h @@ -105,6 +105,26 @@ struct TreeTransformer { return result; } + + template + static StatusOr> Binary2Vector(std::string_view str) { + if (str.size() % sizeof(T) != 0) { + return {Status::NotOK, "data size is not a multiple of the target type size"}; + } + + std::vector values; + const size_t type_size = sizeof(T); + values.reserve(str.size() / type_size); + + while (!str.empty()) { + T value; + memcpy(&value, str.data(), type_size); + values.push_back(value); + str.remove_prefix(type_size); + } + + return values; + } }; } // namespace kqir diff --git a/src/search/ir.h b/src/search/ir.h index 116fe9a7eaa..3ba980dab4e 100644 --- a/src/search/ir.h +++ b/src/search/ir.h @@ -229,6 +229,63 @@ struct NumericCompareExpr : BoolAtomExpr { } }; +struct VectorLiteral : Literal { + std::vector values; + + explicit VectorLiteral(std::vector &&values) : values(std::move(values)){}; + + std::string_view Name() const override { return "VectorLiteral"; } + std::string Dump() const override { + return fmt::format("[{}]", util::StringJoin(values, [](auto v) { return std::to_string(v); })); + } + std::string Content() const override { return Dump(); } + + std::unique_ptr Clone() const override { return std::make_unique(*this); } +}; + +struct VectorRangeExpr : BoolAtomExpr { + std::unique_ptr field; + std::unique_ptr range; + std::unique_ptr vector; + + VectorRangeExpr(std::unique_ptr &&field, std::unique_ptr &&range, + std::unique_ptr &&vector) + : field(std::move(field)), range(std::move(range)), vector(std::move(vector)) {} + + std::string_view Name() const override { return "VectorRangeExpr"; } + std::string Dump() const override { + return fmt::format("{} <-> {} < {}", field->Dump(), vector->Dump(), range->Dump()); + } + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(field->Clone()), + Node::MustAs(range->Clone()), + Node::MustAs(vector->Clone())); + } +}; + +struct VectorKnnExpr : BoolAtomExpr { + // TODO: Support pre-filter for hybrid query + std::unique_ptr field; + std::unique_ptr k; + std::unique_ptr vector; + + VectorKnnExpr(std::unique_ptr &&field, std::unique_ptr &&k, + std::unique_ptr &&vector) + : field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {} + + std::string_view Name() const override { return "VectorKnnExpr"; } + std::string Dump() const override { + return fmt::format("KNN k={}, {} <-> {}", k->Dump(), field->Dump(), vector->Dump()); + } + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(field->Clone()), + Node::MustAs(k->Clone()), + Node::MustAs(vector->Clone())); + } +}; + struct BoolLiteral : BoolAtomExpr, Literal { bool val; @@ -336,18 +393,30 @@ struct LimitClause : Node { std::string Content() const override { return fmt::format("{}, {}", offset, count); } std::unique_ptr Clone() const override { return std::make_unique(*this); } + size_t Offset() const { return offset; } + + size_t Count() const { return count; } }; struct SortByClause : Node { enum Order { ASC, DESC } order = ASC; std::unique_ptr field; + std::unique_ptr vector = nullptr; SortByClause(Order order, std::unique_ptr &&field) : order(order), field(std::move(field)) {} + SortByClause(std::unique_ptr &&field, std::unique_ptr &&vector) + : field(std::move(field)), vector(std::move(vector)) {} static constexpr const char *OrderToString(Order order) { return order == ASC ? "asc" : "desc"; } + bool IsVectorField() const { return vector != nullptr; } std::string_view Name() const override { return "SortByClause"; } - std::string Dump() const override { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); } + std::string Dump() const override { + if (!IsVectorField()) { + return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); + } + return fmt::format("sortby {} <-> {}", field->Dump(), vector->Dump()); + } std::string Content() const override { return OrderToString(order); } NodeIterator ChildBegin() override { return NodeIterator(field.get()); }; diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h index a7a7618173d..43d722b4d0b 100644 --- a/src/search/ir_sema_checker.h +++ b/src/search/ir_sema_checker.h @@ -50,6 +50,9 @@ struct SemaChecker { GET_OR_RET(Check(v->query_expr.get())); if (v->limit) GET_OR_RET(Check(v->limit.get())); if (v->sort_by) GET_OR_RET(Check(v->sort_by.get())); + if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) { + return {Status::NotOK, "expect a LIMIT clause for vector field to construct a KNN search"}; + } } else { return {Status::NotOK, fmt::format("index `{}` not found", index_name)}; } @@ -60,8 +63,25 @@ struct SemaChecker { return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)}; } else if (!iter->second.IsSortable()) { return {Status::NotOK, fmt::format("field `{}` is not sortable", v->field->name)}; + } else if (auto is_vector = iter->second.MetadataAs() != nullptr; + is_vector != v->IsVectorField()) { + std::string not_str = is_vector ? "" : "not "; + return {Status::NotOK, + fmt::format("field `{}` is {}a vector field according to metadata and does {}expect a vector parameter", + v->field->name, not_str, not_str)}; } else { v->field->info = &iter->second; + if (v->IsVectorField()) { + auto meta = v->field->info->MetadataAs(); + if (!v->field->info->HasIndex()) { + return {Status::NotOK, + fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)}; + } + if (v->vector->values.size() != meta->dim) { + return {Status::NotOK, + fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)}; + } + } } } else if (auto v = dynamic_cast(node)) { for (const auto &n : v->inners) { @@ -97,6 +117,49 @@ struct SemaChecker { } else { v->field->info = &iter->second; } + } else if (auto v = dynamic_cast(node)) { + if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) { + return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)}; + } else if (!iter->second.MetadataAs()) { + return {Status::NotOK, fmt::format("field `{}` is not a vector field", v->field->name)}; + } else { + v->field->info = &iter->second; + + if (!v->field->info->HasIndex()) { + return {Status::NotOK, + fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)}; + } + if (v->k->val <= 0) { + return {Status::NotOK, fmt::format("KNN search parameter `k` must be greater than 0")}; + } + auto meta = v->field->info->MetadataAs(); + if (v->vector->values.size() != meta->dim) { + return {Status::NotOK, + fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)}; + } + } + } else if (auto v = dynamic_cast(node)) { + if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) { + return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)}; + } else if (!iter->second.MetadataAs()) { + return {Status::NotOK, fmt::format("field `{}` is not a vector field", v->field->name)}; + } else { + v->field->info = &iter->second; + + auto meta = v->field->info->MetadataAs(); + if (meta->distance_metric == redis::DistanceMetric::L2 && v->range->val < 0) { + return {Status::NotOK, "range cannot be a negative number for l2 distance metric"}; + } + + if (meta->distance_metric == redis::DistanceMetric::COSINE && (v->range->val < 0 || v->range->val > 2)) { + return {Status::NotOK, "range has to be between 0 and 2 for cosine distance metric"}; + } + + if (v->vector->values.size() != meta->dim) { + return {Status::NotOK, + fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)}; + } + } } else if (auto v = dynamic_cast(node)) { for (const auto &n : v->fields) { if (auto iter = current_index->fields.find(n->name); iter == current_index->fields.end()) { diff --git a/src/search/redis_query_parser.h b/src/search/redis_query_parser.h index 5fe03046ada..5b0f172c763 100644 --- a/src/search/redis_query_parser.h +++ b/src/search/redis_query_parser.h @@ -30,6 +30,11 @@ namespace redis_query { using namespace peg; +struct VectorRangeToken : string<'V', 'E', 'C', 'T', 'O', 'R', '_', 'R', 'A', 'N', 'G', 'E'> {}; +struct KnnToken : string<'K', 'N', 'N'> {}; +struct ArrowOp : string<'=', '>'> {}; +struct Wildcard : one<'*'> {}; + struct Field : seq, Identifier> {}; struct Param : seq, Identifier> {}; @@ -44,9 +49,10 @@ struct ExclusiveNumber : seq, NumberOrParam> {}; struct NumericRangePart : sor {}; struct NumericRange : seq, WSPad, WSPad, one<']'>> {}; -struct FieldQuery : seq, one<':'>, WSPad>> {}; +struct KnnSearch : seq, WSPad, WSPad, WSPad, WSPad, one<']'>> {}; +struct VectorRange : seq, WSPad, WSPad, WSPad, one<']'>> {}; -struct Wildcard : one<'*'> {}; +struct FieldQuery : seq, one<':'>, WSPad>> {}; struct QueryExpr; @@ -64,7 +70,11 @@ struct AndExprP : sor {}; struct OrExpr : seq, AndExprP>>> {}; struct OrExprP : sor {}; -struct QueryExpr : seq {}; +struct PrefilterExpr : seq, ArrowOp, WSPad> {}; + +struct QueryP : sor {}; + +struct QueryExpr : seq {}; } // namespace redis_query diff --git a/src/search/redis_query_transformer.h b/src/search/redis_query_transformer.h index 6ff1581bc3e..c81230e4ebf 100644 --- a/src/search/redis_query_transformer.h +++ b/src/search/redis_query_transformer.h @@ -35,10 +35,10 @@ namespace redis_query { namespace ir = kqir; template -using TreeSelector = - parse_tree::selector, - parse_tree::remove_content::on>; +using TreeSelector = parse_tree::selector< + Rule, parse_tree::store_content::on, + parse_tree::remove_content::on>; template StatusOr> ParseToTree(Input&& in) { @@ -53,7 +53,31 @@ StatusOr> ParseToTree(Input&& in) { struct Transformer : ir::TreeTransformer { explicit Transformer(const ParamMap& param_map) : TreeTransformer(param_map) {} + StatusOr> Transform2Vector(const TreeNode& node) { + std::string vector_str = GET_OR_RET(GetParam(node)); + + std::vector values = GET_OR_RET(Binary2Vector(vector_str)); + if (values.empty()) { + return {Status::NotOK, "empty vector is invalid"}; + } + return std::make_unique(std::move(values)); + }; + auto Transform(const TreeNode& node) -> StatusOr> { + auto number_or_param = [this](const TreeNode& node) -> StatusOr> { + if (Is(node)) { + return Node::MustAs(GET_OR_RET(Transform(node))); + } else if (Is(node)) { + auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node))) + .Prefixed(fmt::format("parameter {} is not a number", node->string_view()))); + + return std::make_unique(val); + } else { + return {Status::NotOK, + fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)}; + } + }; + if (Is(node)) { return Node::Create(*ParseFloat(node->string())); } else if (Is(node)) { @@ -88,26 +112,12 @@ struct Transformer : ir::TreeTransformer { } else { return std::make_unique(std::move(exprs)); } - } else { // NumericRange + } else if (Is(query)) { std::vector> exprs; const auto& lhs = query->children[0]; const auto& rhs = query->children[1]; - auto number_or_param = [this](const TreeNode& node) -> StatusOr> { - if (Is(node)) { - return Node::MustAs(GET_OR_RET(Transform(node))); - } else if (Is(node)) { - auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node))) - .Prefixed(fmt::format("parameter {} is not a number", node->string_view()))); - - return std::make_unique(val); - } else { - return {Status::NotOK, - fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)}; - } - }; - if (Is(lhs)) { exprs.push_back(std::make_unique(NumericCompareExpr::GT, std::make_unique(field), @@ -141,11 +151,27 @@ struct Transformer : ir::TreeTransformer { } else { return std::make_unique(std::move(exprs)); } + } else if (Is(query)) { + return std::make_unique(std::make_unique(field), + GET_OR_RET(number_or_param(query->children[1])), + GET_OR_RET(Transform2Vector(query->children[2]))); } } else if (Is(node)) { CHECK(node->children.size() == 1); return Node::Create(Node::MustAs(GET_OR_RET(Transform(node->children[0])))); + } else if (Is(node)) { + CHECK(node->children.size() == 3); + + // TODO(Beihao): Support Hybrid Query + // const auto& prefilter = node->children[0]; + const auto& knn_search = node->children[2]; + CHECK(knn_search->children.size() == 4); + + return std::make_unique(std::make_unique(knn_search->children[2]->string()), + GET_OR_RET(number_or_param(knn_search->children[1])), + GET_OR_RET(Transform2Vector(knn_search->children[3]))); + } else if (Is(node)) { std::vector> exprs; diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h index 2fbbde8c21e..26b442ca32c 100644 --- a/src/search/search_encoding.h +++ b/src/search/search_encoding.h @@ -373,6 +373,8 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata { HnswVectorFieldMetadata() : IndexFieldMetadata(IndexFieldType::VECTOR) {} + bool IsSortable() const override { return true; } + void Encode(std::string *dst) const override { IndexFieldMetadata::Encode(dst); PutFixed8(dst, uint8_t(vector_type)); diff --git a/src/search/sql_parser.h b/src/search/sql_parser.h index 22b985fd948..751b0b47aa2 100644 --- a/src/search/sql_parser.h +++ b/src/search/sql_parser.h @@ -41,7 +41,12 @@ struct NumericAtomExpr : WSPad> {}; struct NumericCompareOp : sor', '='>, one<'=', '<', '>'>> {}; struct NumericCompareExpr : seq {}; -struct BooleanAtomExpr : sor> {}; +struct VectorCompareOp : string<'<', '-', '>'> {}; +struct VectorLiteral : seq>, Number, star>>, Number>, WSPad>> {}; +struct VectorCompareExpr : seq, VectorCompareOp, WSPad> {}; +struct VectorRangeExpr : seq, WSPad> {}; + +struct BooleanAtomExpr : sor> {}; struct QueryExpr; @@ -84,7 +89,9 @@ struct Limit : string<'l', 'i', 'm', 'i', 't'> {}; struct WhereClause : seq {}; struct AscOrDesc : sor {}; -struct OrderByClause : seq, opt>> {}; +struct SortableFieldExpr : seq, opt> {}; +struct OrderByExpr : sor, WSPad> {}; +struct OrderByClause : seq {}; struct LimitClause : seq, one<','>>>, WSPad> {}; struct SearchStmt diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h index d2ed8c21677..01705107776 100644 --- a/src/search/sql_transformer.h +++ b/src/search/sql_transformer.h @@ -41,8 +41,9 @@ using TreeSelector = parse_tree::selector< Rule, parse_tree::store_content::on, - parse_tree::remove_content::on>; + parse_tree::remove_content::on>; template StatusOr> ParseToTree(Input&& in) { @@ -58,12 +59,32 @@ struct Transformer : ir::TreeTransformer { explicit Transformer(const ParamMap& param_map) : TreeTransformer(param_map) {} auto Transform(const TreeNode& node) -> StatusOr> { + auto number_or_param = [this](const TreeNode& node) -> StatusOr> { + if (Is(node)) { + return Node::MustAs(GET_OR_RET(Transform(node))); + } else if (Is(node)) { + auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node))) + .Prefixed(fmt::format("parameter {} is not a number", node->string_view()))); + + return std::make_unique(val); + } else { + return {Status::NotOK, + fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)}; + } + }; + if (Is(node)) { return Node::Create(node->string_view() == "true"); } else if (Is(node)) { return Node::Create(*ParseFloat(node->string())); } else if (Is(node)) { return Node::Create(GET_OR_RET(UnescapeString(node->string_view()))); + } else if (Is(node)) { + std::vector values; + for (const auto& child : node->children) { + values.push_back(*ParseFloat(child->string())); + } + return Node::Create(std::move(values)); } else if (Is(node)) { CHECK(node->children.size() == 2); @@ -85,20 +106,6 @@ struct Transformer : ir::TreeTransformer { const auto& lhs = node->children[0]; const auto& rhs = node->children[2]; - auto number_or_param = [this](const TreeNode& node) -> StatusOr> { - if (Is(node)) { - return Node::MustAs(GET_OR_RET(Transform(node))); - } else if (Is(node)) { - auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node))) - .Prefixed(fmt::format("parameter {} is not a number", node->string_view()))); - - return std::make_unique(val); - } else { - return {Status::NotOK, - fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)}; - } - }; - auto op = ir::NumericCompareExpr::FromOperator(node->children[1]->string_view()).value(); if (Is(lhs) && (Is(rhs) || Is(rhs))) { return Node::Create(op, std::make_unique(lhs->string()), @@ -110,6 +117,16 @@ struct Transformer : ir::TreeTransformer { } else { return {Status::NotOK, "the left and right side of numeric comparison should be an identifier and a number"}; } + } else if (Is(node)) { + // TODO(Beihao): Handle distance metrics for operator + CHECK(node->children.size() == 2); + const auto& vector_comp_expr = node->children[0]; + CHECK(vector_comp_expr->children.size() == 3); + + return Node::Create( + std::make_unique(vector_comp_expr->children[0]->string()), + GET_OR_RET(number_or_param(node->children[1])), + Node::MustAs(GET_OR_RET(Transform(vector_comp_expr->children[2])))); } else if (Is(node)) { CHECK(node->children.size() == 1); @@ -161,15 +178,24 @@ struct Transformer : ir::TreeTransformer { return Node::Create(offset, count); } else if (Is(node)) { - CHECK(node->children.size() == 1 || node->children.size() == 2); - - auto field = std::make_unique(node->children[0]->string()); - auto order = SortByClause::Order::ASC; - if (node->children.size() == 2 && node->children[1]->string_view() == "desc") { - order = SortByClause::Order::DESC; + CHECK(node->children.size() == 1); + const auto& order_by_expr = node->children[0]; + CHECK(order_by_expr->children.size() == 1 || order_by_expr->children.size() == 2); + + if (Is(order_by_expr->children[0])) { + const auto& vector_compare_expr = order_by_expr->children[0]; + CHECK(vector_compare_expr->children.size() == 3); + auto field = std::make_unique(vector_compare_expr->children[0]->string()); + return Node::Create( + std::move(field), Node::MustAs(GET_OR_RET(Transform(vector_compare_expr->children[2])))); + } else { + auto field = std::make_unique(order_by_expr->children[0]->string()); + auto order = SortByClause::Order::ASC; + if (order_by_expr->children.size() == 2 && order_by_expr->children[1]->string_view() == "desc") { + order = SortByClause::Order::DESC; + } + return Node::Create(order, std::move(field)); } - - return Node::Create(order, std::move(field)); } else if (Is(node)) { // root node CHECK(node->children.size() >= 2 && node->children.size() <= 5); diff --git a/tests/cppunit/ir_sema_checker_test.cc b/tests/cppunit/ir_sema_checker_test.cc index 3a15dde725c..df8076ce107 100644 --- a/tests/cppunit/ir_sema_checker_test.cc +++ b/tests/cppunit/ir_sema_checker_test.cc @@ -38,10 +38,26 @@ static IndexMap MakeIndexMap() { auto f1 = FieldInfo("f1", std::make_unique()); auto f2 = FieldInfo("f2", std::make_unique()); auto f3 = FieldInfo("f3", std::make_unique()); + + auto hnsw_field_meta = std::make_unique(); + hnsw_field_meta->vector_type = redis::VectorType::FLOAT64; + hnsw_field_meta->dim = 3; + hnsw_field_meta->distance_metric = redis::DistanceMetric::L2; + auto f4 = FieldInfo("f4", std::move(hnsw_field_meta)); + + hnsw_field_meta = std::make_unique(); + hnsw_field_meta->vector_type = redis::VectorType::FLOAT64; + hnsw_field_meta->dim = 3; + hnsw_field_meta->distance_metric = redis::DistanceMetric::COSINE; + auto f5 = FieldInfo("f5", std::move(hnsw_field_meta)); + f5.metadata->noindex = true; + auto ia = std::make_unique("ia", redis::IndexMetadata(), ""); ia->Add(std::move(f1)); ia->Add(std::move(f2)); ia->Add(std::move(f3)); + ia->Add(std::move(f4)); + ia->Add(std::move(f5)); IndexMap res; res.Insert(std::move(ia)); @@ -68,6 +84,25 @@ TEST(SemaCheckerTest, Simple) { ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag \",\"")->get()).Msg(), "tag cannot contain the separator `,`"); ASSERT_EQ(checker.Check(Parse("select f1 from ia order by a")->get()).Msg(), "field `a` not found in index `ia`"); + ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f4 <-> [3.6,4.7] limit 5")->get()).Msg(), + "vector should be of size `3` for field `f4`"); + ASSERT_EQ(checker.Check(Parse("select f4 from ia where f4 <-> [3.6,4.7] < 5")->get()).Msg(), + "vector should be of size `3` for field `f4`"); + ASSERT_EQ(checker.Check(Parse("select f4 from ia where f4 <-> [3.6,4.7,5.6] < -5")->get()).Msg(), + "range cannot be a negative number for l2 distance metric"); + ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f4 limit 5")->get()).Msg(), + "field `f4` is a vector field according to metadata and does expect a vector parameter"); + ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f1 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(), + "field `f1` is not sortable"); + ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f2 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(), + "field `f2` is not a vector field according to metadata and does not expect a vector parameter"); + ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f4 <-> [3.6,4.7,5.6]")->get()).Msg(), + "expect a LIMIT clause for vector field to construct a KNN search"); + ASSERT_EQ(checker.Check(Parse("select f5 from ia order by f5 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(), + "field `f5` is marked as NOINDEX and cannot be used for KNN search"); + ASSERT_EQ(checker.Check(Parse("select f5 from ia where f5 <-> [3.6,4.7,5.6] < 5")->get()).Msg(), + "range has to be between 0 and 2 for cosine distance metric"); + ASSERT_EQ(checker.Check(Parse("select f5 from ia where f5 <-> [3.6,4.7,5.6] < 0.5")->get()).Msg(), "ok"); } { diff --git a/tests/cppunit/redis_query_parser_test.cc b/tests/cppunit/redis_query_parser_test.cc index bd66d41a2ab..4fc25e49db2 100644 --- a/tests/cppunit/redis_query_parser_test.cc +++ b/tests/cppunit/redis_query_parser_test.cc @@ -101,3 +101,36 @@ TEST(RedisQueryParserTest, Params) { AssertIR(Parse("@c:{$y} @d:[$zzz inf]", {{"y", "hello"}, {"zzz", "3"}}), "(and c hastag \"hello\", d >= 3)"); ASSERT_EQ(Parse("@c:{$y}", {{"z", "hello"}}).Msg(), "parameter with name `y` not found"); } + +TEST(RedisQueryParserTest, Vector) { + std::vector vec = {1, 2, 3}; + std::string vec_str(reinterpret_cast(vec.data()), vec.size() * sizeof(double)); + + AssertSyntaxError(Parse("@field:[RANGE 10 $vector]", {{"vector", vec_str}})); + AssertSyntaxError(Parse("@field:[VECTOR_RANGE 10 not_param")); + AssertSyntaxError(Parse("@field:[VECTOR_RANGE $vector]", {{"vector", vec_str}})); + AssertSyntaxError(Parse("@field:[VECTOR_RANGE $vector 10]", {{"vector", vec_str}})); + AssertSyntaxError(Parse("* =>[knn 5 @field $BLOB]", {{"BLOB", vec_str}})); + AssertSyntaxError(Parse("* =>[KNN 5 @field not_param]")); + AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}})); + AssertSyntaxError(Parse("[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}})); + AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}})); + AssertSyntaxError(Parse("*=>[KNN 5 $vector_blob_param]", {{"vector_blob_param", vec_str}})); + + AssertIR(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", vec_str}}), + "field <-> [1.000000, 2.000000, 3.000000] < 10"); + AssertIR(Parse("*=>[KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}), + "KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]"); + AssertIR(Parse("(*) => [KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}), + "KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]"); + AssertIR(Parse("(@a:[1 2]) => [KNN 8 @vec_embedding $blob]", {{"blob", vec_str}}), + "KNN k=8, vec_embedding <-> [1.000000, 2.000000, 3.000000]"); + AssertIR(Parse("* =>[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}}), + "KNN k=5, vector <-> [1.000000, 2.000000, 3.000000]"); + + vec_str = vec_str.substr(0, 3); + ASSERT_EQ(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", vec_str}}).Msg(), + "data size is not a multiple of the target type size"); + vec_str = ""; + ASSERT_EQ(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", vec_str}}).Msg(), "empty vector is invalid"); +} diff --git a/tests/cppunit/sql_parser_test.cc b/tests/cppunit/sql_parser_test.cc index e85368ce135..9173d4c3f3b 100644 --- a/tests/cppunit/sql_parser_test.cc +++ b/tests/cppunit/sql_parser_test.cc @@ -146,3 +146,25 @@ TEST(SQLParserTest, Params) { "select a from b where (and c hastag \"hello\", d = 3)"); ASSERT_EQ(Parse("select a from b where c hastag @y", {{"z", "hello"}}).Msg(), "parameter with name `y` not found"); } + +TEST(SQLParserTest, Vector) { + AssertSyntaxError(Parse("select a from b where embedding <-> [3,1,2]")); + AssertSyntaxError(Parse("select a from b where embedding <-> [3,1,2] <")); + AssertSyntaxError(Parse("select a from b where embedding [3,1,2] < 3")); + AssertSyntaxError(Parse("select a from b where embedding <> [3,1,2] < 4")); + AssertSyntaxError(Parse("select a from b where embedding <- [3,1,2] < 3")); + AssertSyntaxError(Parse("select a from b order by embedding <-> [1,2,3] < 3")); + AssertSyntaxError(Parse("select a from b where embedding <-> [1,2,3] limit 5")); + AssertSyntaxError(Parse("select a from b where [3,1,2] <-> embedding < 5")); + AssertSyntaxError(Parse("select a from b where embedding <-> [] < 5")); + AssertSyntaxError(Parse("select a from b order by embedding <-> @vec limit 5", {{"vec", "[3.6,7.8]"}})); + AssertSyntaxError(Parse("select a from b where embedding <#> [3,1,2] < 5")); + AssertSyntaxError(Parse("select a from b order by embedding <-> [3,1,2] desc limit 5")); + + AssertIR(Parse("select a from b where embedding <-> [3,1,2] < 5"), + "select a from b where embedding <-> [3.000000, 1.000000, 2.000000] < 5"); + AssertIR(Parse("select a from b where embedding <-> [0.5,0.5] < 10 and c > 100"), + "select a from b where (and embedding <-> [0.500000, 0.500000] < 10, c > 100)"); + AssertIR(Parse("select a from b order by embedding <-> [3.6] limit 5"), + "select a from b where true sortby embedding <-> [3.600000] limit 0, 5"); +}