diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp index c9ff38102b5bb7..35880336580ab0 100644 --- a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp +++ b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp @@ -24,6 +24,7 @@ #include "common/config.h" #include "io/io_common.h" #include "olap/rowset/segment_v2/ann_index/ann_index.h" +#include "olap/rowset/segment_v2/ann_index/ann_index_writer.h" #include "olap/rowset/segment_v2/ann_index/ann_search_params.h" #include "olap/rowset/segment_v2/ann_index/faiss_ann_index.h" #include "olap/rowset/segment_v2/index_file_reader.h" @@ -58,6 +59,9 @@ AnnIndexReader::AnnIndexReader(const TabletIndex* index_meta, it = index_properties.find("metric_type"); DCHECK(it != index_properties.end()); _metric_type = string_to_metric(it->second); + it = index_properties.find(AnnIndexColumnWriter::DIM); + DCHECK(it != index_properties.end()); + _dim = std::stoi(it->second); } Status AnnIndexReader::new_iterator(std::unique_ptr* iterator) { @@ -225,4 +229,8 @@ Status AnnIndexReader::range_search(const AnnRangeSearchParams& params, return Status::OK(); } +size_t AnnIndexReader::get_dimension() const { + return _dim; +} + } // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.h b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.h index 8d9aba5239ce09..68e62f7f347c78 100644 --- a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.h +++ b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.h @@ -59,13 +59,15 @@ class AnnIndexReader : public IndexReader { AnnIndexMetric get_metric_type() const { return _metric_type; } + size_t get_dimension() const; + private: TabletIndex _index_meta; std::shared_ptr _index_file_reader; std::unique_ptr _vector_index; AnnIndexType _index_type; AnnIndexMetric _metric_type; - + size_t _dim; DorisCallOnce 
_load_index_once; }; diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.cpp b/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.cpp index 416af7826d2a89..f8e0e4af8c30bc 100644 --- a/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.cpp +++ b/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.cpp @@ -35,7 +35,8 @@ namespace doris::segment_v2 { */ AnnRangeSearchParams AnnRangeSearchRuntime::to_range_search_params() const { AnnRangeSearchParams params; - params.query_value = query_value.get(); + const auto* query = assert_cast(query_value.get()); + params.query_value = query->get_data().data(); params.radius = static_cast(radius); params.roaring = nullptr; params.is_le_or_lt = is_le_or_lt; @@ -58,6 +59,7 @@ std::string AnnRangeSearchRuntime::to_string() const { "dst_col_idx: {}, metric_type {}, radius: {}, user params: {}, query_vector is null: " "{}", is_ann_range_search, is_le_or_lt, src_col_idx, dst_col_idx, - metric_to_string(metric_type), radius, user_params.to_string(), query_value == nullptr); + metric_to_string(metric_type), radius, user_params.to_string(), + query_value.get() == nullptr); } } // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.h b/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.h index c3d112c9bf2412..113bdf8786f446 100644 --- a/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.h +++ b/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.h @@ -22,6 +22,11 @@ #include #include "olap/rowset/segment_v2/ann_index/ann_index.h" +#include "runtime/define_primitive_type.h" +#include "runtime/primitive_type.h" +#include "vec/columns/column.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" #include "vec/runtime/vector_search_user_params.h" namespace doris::segment_v2 { @@ -81,16 +86,8 @@ struct AnnRangeSearchRuntime { 
dst_col_idx(other.dst_col_idx), radius(other.radius), metric_type(other.metric_type), - user_params(other.user_params) { - // Do deep copy to query_value. - if (other.query_value) { - query_value = std::make_unique(other.dim); - std::copy(other.query_value.get(), other.query_value.get() + other.dim, - query_value.get()); - } else { - query_value = nullptr; - } - } + user_params(other.user_params), + query_value(other.query_value) {} /** * @brief Assignment operator with deep copy semantics. @@ -110,14 +107,8 @@ struct AnnRangeSearchRuntime { metric_type = other.metric_type; user_params = other.user_params; dim = other.dim; - // Do deep copy to query_value. - if (other.query_value) { - query_value = std::make_unique(other.dim); - std::copy(other.query_value.get(), other.query_value.get() + other.dim, - query_value.get()); - } else { - query_value = nullptr; - } + query_value = other.query_value; + return *this; } @@ -142,7 +133,7 @@ struct AnnRangeSearchRuntime { double radius = 0.0; ///< Search radius/distance threshold AnnIndexMetric metric_type; ///< Distance metric (L2, Inner Product, etc.) 
doris::VectorSearchUserParams user_params; ///< User-defined search parameters - std::unique_ptr query_value; ///< Query vector data (deep copied) + vectorized::IColumn::Ptr query_value; ///< Query vector data (shared column, not deep copied) }; #include "common/compile_check_end.h" } // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h b/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h index c38c8d2138dbe3..b2d9c758659b23 100644 --- a/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h +++ b/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h @@ -94,7 +94,7 @@ struct AnnTopNParam { struct AnnRangeSearchParams { bool is_le_or_lt = true; - float* query_value = nullptr; + const float* query_value = nullptr; float radius = -1; roaring::Roaring* roaring; // roaring from segment_iterator std::string to_string() const { diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp b/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp index 62dac6e70951d4..dc99192d7f1d0c 100644 --- a/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp +++ b/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp @@ -22,16 +22,19 @@ #include #include +#include "common/exception.h" #include "common/logging.h" +#include "common/status.h" #include "olap/rowset/segment_v2/ann_index/ann_index_iterator.h" #include "olap/rowset/segment_v2/ann_index/ann_search_params.h" +#include "olap/rowset/segment_v2/inverted_index_query_type.h" +#include "runtime/primitive_type.h" #include "runtime/runtime_state.h" #include "vec/columns/column.h" #include "vec/columns/column_array.h" -#include "vec/columns/column_const.h" #include "vec/columns/column_nullable.h" -#include "vec/common/assert_cast.h" #include "vec/exprs/varray_literal.h" +#include "vec/exprs/vcast_expr.h" #include "vec/exprs/vexpr_context.h" #include "vec/exprs/vexpr_fwd.h" #include "vec/exprs/virtual_slot_ref.h" @@ -40,6
+43,81 @@ namespace doris::segment_v2 { #include "common/compile_check_begin.h" + +Result extract_query_vector(std::shared_ptr arg_expr) { + if (arg_expr->is_constant() == false) { + return ResultError(Status::InvalidArgument("Ann topn expr must be constant, got\n{}", + arg_expr->debug_string())); + } + + // Accept either ArrayLiteral([..]) or CAST('..' AS Nullable(Array(Nullable(Float32)))) + // First, check the expr node type for clarity. + + bool is_array_literal = + std::dynamic_pointer_cast(arg_expr) != nullptr; + bool is_cast_expr = std::dynamic_pointer_cast(arg_expr) != nullptr; + if (!is_array_literal && !is_cast_expr) { + return ResultError( + Status::InvalidArgument("Constant must be ArrayLiteral or CAST to array, got\n{}", + arg_expr->debug_string())); + } + + // We'll validate shape by inspecting the materialized constant column below. + + std::shared_ptr column_wrapper; + auto st = arg_expr->get_const_col(nullptr, &column_wrapper); + if (!st.ok()) { + return ResultError(Status::InvalidArgument("Failed to get constant column, error: {}", + st.to_string())); + } + + // Execute the constant array literal and extract its float elements into _query_array + vectorized::IColumn::Ptr col_ptr = + column_wrapper->column_ptr->convert_to_full_column_if_const(); + + // The expected runtime column layout for the literal is: + // Nullable(ColumnArray(Nullable(ColumnFloat32))) with exactly 1 row (one array literal) + const vectorized::IColumn* top_col = col_ptr.get(); + const vectorized::IColumn* array_holder_col = top_col; + // Handle outer Nullable and remember result nullability preference + if (auto* nullable_col = + vectorized::check_and_get_column(*top_col)) { + if (nullable_col->has_null()) { + return ResultError(Status::InvalidArgument("Ann query vector cannot be NULL")); + } + array_holder_col = &nullable_col->get_nested_column(); + } + + // Must be an array column with single row + const auto* array_col = + 
vectorized::check_and_get_column(*array_holder_col); + if (array_col == nullptr || array_col->size() != 1) { + return ResultError(Status::InvalidArgument( + "Ann topn expr constant should be an Array literal, got column: {}", + array_holder_col->get_name())); + } + + // Fetch nested data column: Nullable(ColumnFloat32) or ColumnFloat32 + const vectorized::IColumn& nested_data_any = array_col->get_data(); + vectorized::IColumn::Ptr values_holder_col = array_col->get_data_ptr(); + size_t value_count = array_col->get_offsets()[0]; + + if (value_count == 0) { + return ResultError(Status::InvalidArgument("Ann topn query vector cannot be empty")); + } + + if (auto* value_nullable_col = + vectorized::check_and_get_column(nested_data_any)) { + if (value_nullable_col->has_null(0, value_count)) { + return ResultError(Status::InvalidArgument( + "Ann topn query vector elements cannot contain NULL values")); + } + values_holder_col = value_nullable_col->get_nested_column_ptr(); + } + + return values_holder_col; +} + Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_desc) { RETURN_IF_ERROR(_order_by_expr_ctx->prepare(state, row_desc)); RETURN_IF_ERROR(_order_by_expr_ctx->open(state)); @@ -54,13 +132,13 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_des |---------------- | | | | - SlotRef ArrayLiteral + SlotRef CAST(String as Nullable) OR ArrayLiteral */ std::shared_ptr vir_slot_ref = std::dynamic_pointer_cast(_order_by_expr_ctx->root()); DCHECK(vir_slot_ref != nullptr); if (vir_slot_ref == nullptr) { - return Status::InternalError( + return Status::InvalidArgument( "root of order by expr of ann topn must be a vectorized::VirtualSlotRef, got\n{}", _order_by_expr_ctx->root()->debug_string()); } @@ -71,27 +149,35 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_des std::dynamic_pointer_cast(vir_col_expr); if (distance_fn_call == nullptr) { - return Status::InternalError("Ann topn expr expect 
FuncationCall, got\n{}", - vir_col_expr->debug_string()); + return Status::InvalidArgument("Ann topn expr expect FuncationCall, got\n{}", + vir_col_expr->debug_string()); } std::shared_ptr slot_ref = std::dynamic_pointer_cast(distance_fn_call->children()[0]); if (slot_ref == nullptr) { - return Status::InternalError("Ann topn expr expect SlotRef, got\n{}", - distance_fn_call->children()[0]->debug_string()); + return Status::InvalidArgument("Ann topn expr expect SlotRef, got\n{}", + distance_fn_call->children()[0]->debug_string()); } // slot_ref->column_id() is acutually the columnd idx in block. _src_column_idx = slot_ref->column_id(); - std::shared_ptr array_literal = - std::dynamic_pointer_cast(distance_fn_call->children()[1]); - if (array_literal == nullptr) { - return Status::InternalError("Ann topn expr expect ArrayLiteral, got\n{}", - distance_fn_call->children()[1]->debug_string()); + if (distance_fn_call->children()[1]->is_constant() == false) { + return Status::InvalidArgument("Ann topn expr expect constant ArrayLiteral, got\n{}", + distance_fn_call->children()[1]->debug_string()); } - _query_array = array_literal->get_column_ptr(); + + // Accept either ArrayLiteral([..]) or CAST('..' AS Nullable(Array(Nullable(Float32)))) + // First, check the expr node type for clarity. 
+ auto arg_expr = distance_fn_call->children()[1]; + + auto query_array_result = extract_query_vector(arg_expr); + if (!query_array_result.has_value()) { + return query_array_result.error(); + } + _query_array = query_array_result.value(); + _user_params = state->get_vector_search_params(); std::set distance_func_names = {vectorized::L2DistanceApproximate::name, @@ -121,27 +207,26 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator* ann DCHECK(ann_index_iterator_casted != nullptr); DCHECK(_order_by_expr_ctx != nullptr); DCHECK(_order_by_expr_ctx->root() != nullptr); + size_t query_array_size = _query_array->size(); + if (_query_array.get() == nullptr || query_array_size == 0) { + return Status::InternalError("Ann topn query vector is not initialized"); + } - const vectorized::ColumnConst* const_column = - assert_cast(_query_array.get()); - const vectorized::ColumnArray* column_array = - assert_cast(const_column->get_data_column_ptr().get()); - const vectorized::ColumnNullable* column_nullable = - assert_cast(column_array->get_data_ptr().get()); - const vectorized::ColumnFloat32* cf32 = assert_cast( - column_nullable->get_nested_column_ptr().get()); - - const float* query_value = cf32->get_data().data(); - const size_t query_value_size = cf32->get_data().size(); + // TODO:(zhiqiang) Maybe we can move this dimension check to prepare phase. 
- std::unique_ptr query_value_f32 = std::make_unique(query_value_size); - for (size_t i = 0; i < query_value_size; ++i) { - query_value_f32[i] = static_cast(query_value[i]); + auto index_reader = ann_index_iterator_casted->get_reader(AnnIndexReaderType::ANN); + auto ann_index_reader = std::dynamic_pointer_cast(index_reader); + DCHECK(ann_index_reader != nullptr); + if (ann_index_reader->get_dimension() != query_array_size) { + return Status::InvalidArgument( + "Ann topn query vector dimension {} does not match index dimension {}", + query_array_size, ann_index_reader->get_dimension()); } - + const vectorized::ColumnFloat32* query = + assert_cast(_query_array.get()); segment_v2::AnnTopNParam ann_query_params { - .query_value = query_value_f32.get(), - .query_value_size = query_value_size, + .query_value = query->get_data().data(), + .query_value_size = query_array_size, .limit = _limit, ._user_params = _user_params, .roaring = roaring, @@ -157,11 +242,9 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator* ann size_t num_results = ann_query_params.distance->size(); auto result_column_float = vectorized::ColumnFloat32::create(num_results); - for (size_t i = 0; i < num_results; ++i) { result_column_float->get_data()[i] = (*ann_query_params.distance)[i]; } - result_column = std::move(result_column_float); row_ids = std::move(ann_query_params.row_ids); ann_index_stats = *ann_query_params.stats; diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.h b/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.h index 8fd4dcee8a69ca..121901ff9188fc 100644 --- a/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.h +++ b/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.h @@ -35,6 +35,7 @@ #pragma once +#include "runtime/primitive_type.h" #include "runtime/runtime_state.h" #include "vec/columns/column.h" #include "vec/exprs/varray_literal.h" @@ -49,6 +50,8 @@ namespace doris::segment_v2 { #include 
"common/compile_check_begin.h" struct AnnIndexStats; +Result extract_query_vector(std::shared_ptr arg_expr); + /** * @brief Runtime execution engine for ANN (Approximate Nearest Neighbor) Top-N queries. * @@ -161,7 +164,7 @@ class AnnTopNRuntime { size_t _src_column_idx = -1; ///< Source vector column index size_t _dest_column_idx = -1; ///< Destination distance column index segment_v2::AnnIndexMetric _metric_type; ///< Distance metric type - vectorized::IColumn::Ptr _query_array; ///< Query vector data + vectorized::IColumn::Ptr _query_array; ///< Query vector data (contiguous float buffer) doris::VectorSearchUserParams _user_params; ///< User-defined search parameters }; #include "common/compile_check_end.h" diff --git a/be/src/vec/exec/scan/olap_scanner.cpp b/be/src/vec/exec/scan/olap_scanner.cpp index 606d1757b1b149..1e8a7f9321291e 100644 --- a/be/src/vec/exec/scan/olap_scanner.cpp +++ b/be/src/vec/exec/scan/olap_scanner.cpp @@ -155,6 +155,7 @@ Status OlapScanner::prepare() { _score_runtime = local_state->_score_runtime; _score_runtime = local_state->_score_runtime; + // All scanners share the same ann_topn_runtime. 
_ann_topn_runtime = local_state->_ann_topn_runtime; // set limit to reduce end of rowset and segment mem use diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp index 786674eadf3e6a..6ee7a12a251a2f 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -42,6 +42,7 @@ #include "vec/columns/column_array.h" #include "vec/columns/column_nullable.h" #include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" #include "vec/core/block.h" #include "vec/core/column_numbers.h" #include "vec/core/types.h" @@ -345,7 +346,7 @@ bool VectorizedFnCall::equals(const VExpr& other) { |---------------- | | | | - SlotRef ArrayLiteral + SlotRef ArrayLiteral/Cast(String as Array) */ void VectorizedFnCall::prepare_ann_range_search( @@ -425,44 +426,44 @@ void VectorizedFnCall::prepare_ann_range_search( range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name); } - UInt16 idx_of_slot_ref = 0; - UInt16 idx_of_array_literal = 0; + // Identify the slot ref child and the constant query array child (ArrayLiteral or CAST to array) + Int32 idx_of_slot_ref = -1; + Int32 idx_of_array_expr = -1; for (UInt16 i = 0; i < function_call->get_num_children(); ++i) { auto child = function_call->get_child(i); - if (std::dynamic_pointer_cast(child) != nullptr) { + if (idx_of_slot_ref == -1 && std::dynamic_pointer_cast(child) != nullptr) { idx_of_slot_ref = i; - } else if (std::dynamic_pointer_cast(child) != nullptr) { - idx_of_array_literal = i; + continue; + } + // Accept either ArrayLiteral or Cast-to-array constant + if (idx_of_array_expr == -1 && + (std::dynamic_pointer_cast(child) != nullptr || + std::dynamic_pointer_cast(child) != nullptr)) { + idx_of_array_expr = i; } } - std::shared_ptr slot_ref = - std::dynamic_pointer_cast(function_call->get_child(idx_of_slot_ref)); - std::shared_ptr array_literal = std::dynamic_pointer_cast( - function_call->get_child(idx_of_array_literal)); - 
- if (slot_ref == nullptr || array_literal == nullptr) { + if (idx_of_slot_ref == -1 || idx_of_array_expr == -1) { suitable_for_ann_index = false; - // slot ref or array literal is null. + // slot ref or array literal/cast is missing. return; } + auto slot_ref = std::dynamic_pointer_cast( + function_call->get_child(static_cast(idx_of_slot_ref))); range_search_runtime.src_col_idx = slot_ref->column_id(); range_search_runtime.dst_col_idx = vir_slot_ref == nullptr ? -1 : vir_slot_ref->column_id(); - auto col_const = array_literal->get_column_ptr(); - auto col_array = col_const->convert_to_full_column_if_const(); - const ColumnArray* array_col = assert_cast(col_array.get()); - DCHECK(array_col->size() == 1); - size_t dim = array_col->get_offsets()[0]; - range_search_runtime.dim = dim; - range_search_runtime.query_value = std::make_unique(dim); - - const ColumnNullable* cn = assert_cast(array_col->get_data_ptr().get()); - const ColumnFloat32* cf32 = - assert_cast(cn->get_nested_column_ptr().get()); - for (size_t i = 0; i < dim; ++i) { - range_search_runtime.query_value[i] = cf32->get_data()[i]; + + // Materialize the constant array expression and validate its shape and types + std::shared_ptr column_wrapper; + auto array_expr = function_call->get_child(static_cast(idx_of_array_expr)); + auto extract_result = extract_query_vector(array_expr); + if (!extract_result.has_value()) { + suitable_for_ann_index = false; + return; } + range_search_runtime.query_value = extract_result.value(); + range_search_runtime.dim = range_search_runtime.query_value->size(); range_search_runtime.is_ann_range_search = true; range_search_runtime.user_params = user_params; VLOG_DEBUG << fmt::format("Ann range search params: {}", range_search_runtime.to_string()); @@ -513,6 +514,14 @@ Status VectorizedFnCall::evaluate_ann_range_search( return Status::OK(); } + // Check dimension if available (>0) + const size_t index_dim = ann_index_reader->get_dimension(); + if (index_dim > 0 && index_dim != 
range_search_runtime.dim) { + return Status::InvalidArgument( + "Ann range search query dimension {} does not match index dimension {}", + range_search_runtime.dim, index_dim); + } + AnnRangeSearchParams params = range_search_runtime.to_range_search_params(); params.roaring = &row_bitmap; diff --git a/be/test/olap/vector_search/ann_index_reader_test.cpp b/be/test/olap/vector_search/ann_index_reader_test.cpp index 7f914038982bbd..8af387ab2301a8 100644 --- a/be/test/olap/vector_search/ann_index_reader_test.cpp +++ b/be/test/olap/vector_search/ann_index_reader_test.cpp @@ -423,7 +423,9 @@ TEST_F(AnnIndexReaderTest, AnnIndexReaderRangeSearch) { for (size_t i = 0; i < iterations; ++i) { std::map index_properties; index_properties["index_type"] = "hnsw"; - index_properties["metric_type"] = "l2"; + // Use canonical metric name and include required dimension property + index_properties["metric_type"] = "l2_distance"; + index_properties["dim"] = "128"; std::unique_ptr index_meta = std::make_unique(); index_meta->_properties = index_properties; auto mock_index_file_reader = std::make_shared(); diff --git a/be/test/olap/vector_search/ann_range_search_test.cpp b/be/test/olap/vector_search/ann_range_search_test.cpp index d9541d6cbb8526..080e3c76f9e206 100644 --- a/be/test/olap/vector_search/ann_range_search_test.cpp +++ b/be/test/olap/vector_search/ann_range_search_test.cpp @@ -140,6 +140,7 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) { std::map properties; properties["index_type"] = "hnsw"; properties["metric_type"] = "l2_distance"; + properties["dim"] = "8"; // match query vector size from thrift auto pair = vector_search_utils::create_tmp_ann_index_reader(properties); mock_ann_index_iter->_ann_reader = pair.second; @@ -230,6 +231,7 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) { std::map properties; properties["index_type"] = "hnsw"; properties["metric_type"] = "l2_distance"; + properties["dim"] = "8"; // match query vector size from thrift auto pair = 
vector_search_utils::create_tmp_ann_index_reader(properties); mock_ann_index_iter->_ann_reader = pair.second; @@ -303,11 +305,12 @@ TEST_F(VectorSearchTest, TestRangeSearchRuntimeInfoToString) { runtime_info2.radius = 15.5; runtime_info2.metric_type = doris::segment_v2::AnnIndexMetric::L2; runtime_info2.dim = 4; - runtime_info2.query_value = std::make_unique(4); - runtime_info2.query_value[0] = 1.0f; - runtime_info2.query_value[1] = 2.0f; - runtime_info2.query_value[2] = 3.0f; - runtime_info2.query_value[3] = 4.0f; + auto f32 = ColumnFloat32::create(4); + f32->get_data()[0] = 1.0f; + f32->get_data()[1] = 2.0f; + f32->get_data()[2] = 3.0f; + f32->get_data()[3] = 4.0f; + runtime_info2.query_value = std::move(f32); doris::VectorSearchUserParams user_params; user_params.hnsw_ef_search = 100; @@ -692,6 +695,56 @@ TEST_F(VectorSearchTest, TestAnnIndexReader_NewIterator) { EXPECT_NE(ann_iterator, nullptr); } +TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch_DimensionMismatch) { + // Prepare a valid range search expr from thrift + TExpr texpr = read_from_json(ann_range_search_thrift); + TDescriptorTable table1 = read_from_json(thrift_table_desc); + std::unique_ptr pool = std::make_unique(); + auto desc_tbl = std::make_unique(); + DescriptorTbl* desc_tbl_ptr = desc_tbl.get(); + ASSERT_TRUE(DescriptorTbl::create(pool.get(), table1, &(desc_tbl_ptr)).ok()); + RowDescriptor row_desc = RowDescriptor(*desc_tbl_ptr, {0}, {false}); + std::unique_ptr state = std::make_unique(); + state->set_desc_tbl(desc_tbl_ptr); + + VExprContextSPtr range_search_ctx; + ASSERT_TRUE(vectorized::VExpr::create_expr_tree(texpr, range_search_ctx).ok()); + ASSERT_TRUE(range_search_ctx->prepare(state.get(), row_desc).ok()); + ASSERT_TRUE(range_search_ctx->open(state.get()).ok()); + doris::VectorSearchUserParams user_params; + range_search_ctx->prepare_ann_range_search(user_params); + ASSERT_TRUE(range_search_ctx->_ann_range_search_runtime.is_ann_range_search); + // Force a dimension mismatch: query dim 
is 8 in thrift; set index dim to 4 + std::vector idx_to_cid(4); + idx_to_cid[0] = 0; + idx_to_cid[1] = 1; // embedding + idx_to_cid[2] = 2; + idx_to_cid[3] = 3; // virtual dist + + std::vector> cid_to_index_iterators(4); + auto mock_iter = std::make_unique(); + + // Back its reader with a real AnnIndexReader but with dim=4 + std::map properties; + properties["index_type"] = "hnsw"; + properties["metric_type"] = "l2_distance"; + properties["dim"] = "4"; // mismatch + auto pair = vector_search_utils::create_tmp_ann_index_reader(properties); + mock_iter->_ann_reader = pair.second; + cid_to_index_iterators[1] = std::move(mock_iter); + + std::vector> column_iterators(4); + column_iterators[3] = std::make_unique(); + + roaring::Roaring row_bitmap; + segment_v2::AnnIndexStats stats; + + auto st = range_search_ctx->evaluate_ann_range_search(cid_to_index_iterators, idx_to_cid, + column_iterators, row_bitmap, stats); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(st.is()); +} + TEST_F(VectorSearchTest, TestAnnIndexIterator_ReadFromIndex_NullParam) { // Test AnnIndexIterator::read_from_index with null parameter std::map properties; diff --git a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp index 03d6c8800cbd44..530f1f1a038876 100644 --- a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp +++ b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp @@ -30,6 +30,7 @@ #include "olap/rowset/segment_v2/ann_index/ann_topn_runtime.h" #include "runtime/primitive_type.h" #include "vec/columns/column_nullable.h" +#include "vec/columns/column_vector.h" #include "vec/exprs/virtual_slot_ref.h" #include "vector_search_utils.h" @@ -119,17 +120,10 @@ TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) { ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(), predicate->get_order_by_expr_ctx()->root()->debug_string()); - const ColumnConst* const_column = - assert_cast(predicate->_query_array.get()); - const 
ColumnArray* column_array = - assert_cast(const_column->get_data_column_ptr().get()); - const ColumnNullable* column_nullable = - assert_cast(column_array->get_data_ptr().get()); - const ColumnFloat32* cf32 = - assert_cast(column_nullable->get_nested_column_ptr().get()); - - const float* query_value = cf32->get_data().data(); - const size_t query_value_size = cf32->get_data().size(); + const vectorized::ColumnFloat32* query_column = + assert_cast(predicate->_query_array.get()); + const float* query_value = query_column->get_data().data(); + const size_t query_value_size = predicate->_query_array->size(); ASSERT_EQ(query_value_size, 8); std::vector query_value_f32; for (size_t i = 0; i < query_value_size; ++i) { @@ -153,6 +147,16 @@ TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) { std::cout << "query_vector: " << fmt::format("[{}]", fmt::join(*query_vector, ",")) << std::endl; + // Attach a valid ANN reader to the mock iterator so runtime can fetch reader and check dim + { + std::map properties; + properties["index_type"] = "hnsw"; + properties["metric_type"] = "l2_distance"; + properties["dim"] = "8"; // match the query vector dimension + auto pair = vector_search_utils::create_tmp_ann_index_reader(properties); + _ann_index_iterator->_ann_reader = pair.second; + } + EXPECT_CALL(*_ann_index_iterator, read_from_index(testing::_)) .Times(1) .WillOnce(testing::Invoke([](const segment_v2::IndexParam& value) { diff --git a/be/test/olap/vector_search/ann_topn_runtime_negative_test.cpp b/be/test/olap/vector_search/ann_topn_runtime_negative_test.cpp new file mode 100644 index 00000000000000..a84d99ec7831d7 --- /dev/null +++ b/be/test/olap/vector_search/ann_topn_runtime_negative_test.cpp @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include + +#include "common/status.h" +#include "olap/rowset/segment_v2/ann_index/ann_topn_runtime.h" +#include "vec/exprs/vectorized_fn_call.h" +#include "vec/exprs/vexpr.h" +#include "vec/exprs/virtual_slot_ref.h" +#include "vector_search_utils.h" + +using ::testing::HasSubstr; + +namespace doris::vectorized { + +// These tests target uncovered error branches in AnnTopNRuntime::prepare and +// evaluate_vector_ann_search using the existing VectorSearchTest fixture setup. 
+ +TEST_F(VectorSearchTest, AnnTopNRuntimePrepare_NoFunctionCall) { + // Build a DescriptorTbl where the virtual column expr is a constant literal (not a function call) + doris::ObjectPool obj_pool_local; + TDescriptorTable thrift_tbl; + { + TTableDescriptor table_desc; + table_desc.id = 1000; + thrift_tbl.tableDescriptors.push_back(table_desc); + + TTupleDescriptor tuple_desc; + tuple_desc.__isset.tableId = true; + tuple_desc.id = 2000; + tuple_desc.tableId = 1000; + thrift_tbl.tupleDescriptors.push_back(tuple_desc); + + // Slot 0: materialized, with a virtual column expr set to FLOAT_LITERAL (not a function) + TSlotDescriptor slot0; + slot0.id = 3000; + slot0.parent = 2000; + slot0.isMaterialized = true; + slot0.need_materialize = true; + slot0.__isset.need_materialize = true; + // type: DOUBLE (matches fixture) + TTypeNode type_node; + type_node.type = TTypeNodeType::type::SCALAR; + TScalarType scalar_type; + scalar_type.__set_type(TPrimitiveType::DOUBLE); + type_node.__set_scalar_type(scalar_type); + slot0.slotType.types.push_back(type_node); + // Provide a simple FLOAT_LITERAL as the virtual column expr + doris::TExpr vexpr; + doris::TExprNode node; + node.node_type = TExprNodeType::FLOAT_LITERAL; + node.type = TTypeDesc(); + node.type.types.push_back(type_node); + doris::TFloatLiteral flit; + flit.value = 1.0; + node.__set_float_literal(flit); + node.__isset.float_literal = true; + vexpr.nodes.push_back(node); + slot0.virtual_column_expr = vexpr; + slot0.__isset.virtual_column_expr = true; + thrift_tbl.slotDescriptors.push_back(slot0); + + // Slot 1: a normal slot to satisfy references + TSlotDescriptor slot1 = slot0; + slot1.id = 3001; + slot1.__isset.virtual_column_expr = false; + thrift_tbl.slotDescriptors.push_back(slot1); + thrift_tbl.__isset.slotDescriptors = true; + } + + doris::DescriptorTbl* desc_tbl_local = nullptr; + ASSERT_TRUE(DescriptorTbl::create(&obj_pool_local, thrift_tbl, &desc_tbl_local).ok()); + RowDescriptor 
row_desc_local(*desc_tbl_local, {2000}, {false}); + + // Create a VirtualSlotRef root expr that points to the local descriptor's slot id (3000) + doris::TExpr local_virtual_slot_ref_expr = _virtual_slot_ref_expr; + ASSERT_TRUE(local_virtual_slot_ref_expr.nodes.size() == 1); + local_virtual_slot_ref_expr.nodes[0].slot_ref.slot_id = 3000; + std::shared_ptr vslot_ctx; + ASSERT_TRUE(VExpr::create_expr_tree(local_virtual_slot_ref_expr, vslot_ctx).ok()); + + doris::RuntimeState state_local; + state_local.set_desc_tbl(desc_tbl_local); + + auto runtime = segment_v2::AnnTopNRuntime::create_shared(true, 10, vslot_ctx); + Status st = runtime->prepare(&state_local, row_desc_local); + ASSERT_FALSE(st.ok()); + EXPECT_THAT(st.to_string(), HasSubstr("expect FuncationCall")); +} + +// Note: We intentionally avoid testing a non-VirtualSlotRef root since it triggers DCHECK. + +// Removed additional negative prepare tests that rely on internal descriptor mutations. + +TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluate_DimensionMismatch) { + // Prepare a valid runtime first. + std::shared_ptr dist_ctx; + auto fn_thrift = read_from_json(_distance_function_call_thrift); + ASSERT_TRUE(VExpr::create_expr_tree(fn_thrift, dist_ctx).ok()); + + std::shared_ptr vslot_ctx; + ASSERT_TRUE(VExpr::create_expr_tree(_virtual_slot_ref_expr, vslot_ctx).ok()); + auto vir_slot = std::dynamic_pointer_cast(vslot_ctx->root()); + ASSERT_TRUE(vir_slot != nullptr); + vir_slot->set_virtual_column_expr(dist_ctx->root()); + + auto runtime = segment_v2::AnnTopNRuntime::create_shared(true, 10, vslot_ctx); + ASSERT_TRUE(runtime->prepare(&_runtime_state, _row_desc).ok()); + + // Attach an ANN reader with a different dimension to trigger the mismatch branch. 
+ { + std::map props; + props["index_type"] = "hnsw"; + props["metric_type"] = "l2_distance"; + props["dim"] = "4"; // runtime query vector dimension is 8 from fixture JSON + auto pair = vector_search_utils::create_tmp_ann_index_reader(props); + _ann_index_iterator->_ann_reader = pair.second; + } + + roaring::Roaring bitmap; + vectorized::IColumn::MutablePtr result_col = ColumnFloat32::create(0); + std::unique_ptr> row_ids; + doris::segment_v2::AnnIndexStats stats; + Status st = runtime->evaluate_vector_ann_search(_ann_index_iterator.get(), &bitmap, 10, + result_col, row_ids, stats); + ASSERT_FALSE(st.ok()); + EXPECT_THAT(st.to_string(), HasSubstr("dimension")); +} + +} // namespace doris::vectorized diff --git a/be/test/olap/vector_search/vector_search_utils.cpp b/be/test/olap/vector_search/vector_search_utils.cpp index 506d0e4ea8d5cc..cb02b464d6a02f 100644 --- a/be/test/olap/vector_search/vector_search_utils.cpp +++ b/be/test/olap/vector_search/vector_search_utils.cpp @@ -265,6 +265,16 @@ float get_radius_from_matrix(const float* vector, int dim, std::pair, std::shared_ptr> create_tmp_ann_index_reader(std::map properties) { + // Ensure required properties exist for tests + if (properties.find("index_type") == properties.end()) { + properties["index_type"] = "hnsw"; + } + if (properties.find("metric_type") == properties.end()) { + properties["metric_type"] = "l2_distance"; + } + if (properties.find("dim") == properties.end()) { + properties["dim"] = "4"; // default small dimension for tests + } auto mock_tablet_index = std::make_unique(); mock_tablet_index->_properties = properties; auto mock_index_file_reader = std::make_shared(); diff --git a/regression-test/data/ann_index_p0/cast_string_as_array.out b/regression-test/data/ann_index_p0/cast_string_as_array.out new file mode 100644 index 00000000000000..dbbe421d76a81f --- /dev/null +++ b/regression-test/data/ann_index_p0/cast_string_as_array.out @@ -0,0 +1,47 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !sql_0 -- +1 +2 +4 + +-- !sql_1 -- +1 +2 +4 + +-- !sql_3 -- +3 +2 +1 + +-- !sql_rs_l2_le -- +1 +2 + +-- !sql_rs_l2_ge -- +3 +4 + +-- !sql_rs_ip_ge -- +2 +3 + +-- !sql_rs_ip_lt -- +1 + +-- !sql_fall_back -- +0.0 +0.0 +0.0 +0.0 + +-- !sql_rs_l2_nonconst_le -- +1 +2 +3 +4 + +-- !sql_rs_ip_nonconst_ge -- +2 +3 + diff --git a/regression-test/suites/ann_index_p0/cast_string_as_array.groovy b/regression-test/suites/ann_index_p0/cast_string_as_array.groovy new file mode 100644 index 00000000000000..9d0ea331ef2ee7 --- /dev/null +++ b/regression-test/suites/ann_index_p0/cast_string_as_array.groovy @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +suite("cast_string_as_array") { + sql "unset variable all;" + sql "set enable_common_expr_pushdown=true;" + + // L2 table: dim=3 + sql "drop table if exists ann_cast_rhs_l2" + sql """ + CREATE TABLE ann_cast_rhs_l2 ( + id INT NOT NULL, + embedding ARRAY NOT NULL, + INDEX idx_emb (`embedding`) USING ANN PROPERTIES( + "index_type"="hnsw", + "metric_type"="l2_distance", + "dim"="3" + ) + ) ENGINE=OLAP + DUPLICATE KEY(id) + DISTRIBUTED BY HASH(id) BUCKETS AUTO + PROPERTIES ("replication_num" = "1"); + """ + sql """ + INSERT INTO ann_cast_rhs_l2 VALUES + (1, [1.0, 2.0, 3.0]), + (2, [0.5, 2.1, 2.9]), + (3, [10.0, 10.0, 10.0]), + (4, [2.0, 3.0, 4.0]); + """ + + // Success: CAST(string AS array) on RHS + qt_sql_0 "select id from ann_cast_rhs_l2 order by l2_distance_approximate(embedding, cast('[1.0,2.0,3.0]' as array)) limit 3;" + + // Failure: with extra spaces and integer literals in the string, the query is expected to be rejected (presumably the cast evaluates to NULL) + test { + sql "select id from ann_cast_rhs_l2 order by l2_distance_approximate(embedding, cast(' [1, 2 , 3 ] ' as array)) limit 3;" + + exception "Ann query vector cannot be NULL" + } + + // Success: nested cast(string->string->array) should also work + qt_sql_1 "select id from ann_cast_rhs_l2 order by l2_distance_approximate(embedding, cast(cast('[1.0,2.0,3.0]' as string) as array)) limit 3;" + + // Failure: empty array is not allowed for ANN query vector + test { + sql "select id from ann_cast_rhs_l2 order by l2_distance_approximate(embedding, cast('[]' as array)) limit 1;" + exception "Ann topn query vector cannot be empty" + } + + // A special case. + // Constant propagation may optimize l2_distance_approximate(embedding, NULL) to NULL before reaching the + // runtime of ANN topn. So instead of the ANN NULL-vector error, the query is expected to fail with the exception asserted below. 
+ test { + sql "select id from ann_cast_rhs_l2 order by l2_distance_approximate(embedding, cast(NULL as array)) limit 1;" + exception "Constant must be ArrayLiteral or CAST to array" + } + + + // Failure: dim mismatch (2 vs table dim=3) + test { + sql "select id from ann_cast_rhs_l2 order by l2_distance_approximate(embedding, cast('[1.0,2.0]' as array)) limit 1;" + exception "[INVALID_ARGUMENT]" + } + + // Inner product table: dim=4 + sql "drop table if exists ann_cast_rhs_ip" + sql """ + CREATE TABLE ann_cast_rhs_ip ( + id INT NOT NULL, + embedding ARRAY NOT NULL, + INDEX idx_emb (`embedding`) USING ANN PROPERTIES( + "index_type"="hnsw", + "metric_type"="inner_product", + "dim"="4" + ) + ) ENGINE=OLAP + DUPLICATE KEY(id) + DISTRIBUTED BY HASH(id) BUCKETS AUTO + PROPERTIES ("replication_num" = "1"); + """ + + sql "truncate table ann_cast_rhs_ip" + sql """ + INSERT INTO ann_cast_rhs_ip VALUES + (1, [0.1, 0.2, 0.3, 0.4]), + (2, [0.5, 0.6, 0.7, 0.8]), + (3, [1.0, 1.0, 1.0, 1.0]); + """ + + // Success: DESC for inner_product + qt_sql_3 "select id from ann_cast_rhs_ip order by inner_product_approximate(embedding, cast('[0.1,0.2,0.3,0.4]' as array)) desc limit 3;" + + // Failure: dim mismatch (3 vs table dim=4) + test { + sql "select id from ann_cast_rhs_ip order by inner_product_approximate(embedding, cast('[0.1,0.2,0.3]' as array)) desc limit 1;" + exception "[INVALID_ARGUMENT]" + } + + // ---------------------- + // Range search cases (CAST string -> array) + // ---------------------- + + // L2 range search with <= radius: expect ids 1 and 2 (distance to [1,2,3] is <= 1.0) + qt_sql_rs_l2_le "select id from ann_cast_rhs_l2 where l2_distance_approximate(embedding, cast('[1,2,3]' as array)) <= 1.0 order by id;" + + // L2 range search with >= radius: expect ids 3 and 4 (distance to [1,2,3] is >= 1.0) + qt_sql_rs_l2_ge "select id from ann_cast_rhs_l2 where l2_distance_approximate(embedding, cast('[1,2,3]' as array)) >= 1.0 order by id;" + + // L2 range search: dim mismatch 
should error + test { + sql "select id from ann_cast_rhs_l2 where l2_distance_approximate(embedding, cast('[1,2]' as array)) <= 1.0 order by id;" + exception "[INVALID_ARGUMENT]" + } + + // Inner product range search with >= threshold: expect ids 2 and 3 + qt_sql_rs_ip_ge "select id from ann_cast_rhs_ip where inner_product_approximate(embedding, cast('[0.1,0.2,0.3,0.4]' as array)) >= 0.6 order by id;" + + // Inner product range search with < threshold: expect id 1 only + qt_sql_rs_ip_lt "select id from ann_cast_rhs_ip where inner_product_approximate(embedding, cast('[0.1,0.2,0.3,0.4]' as array)) < 0.6 order by id;" + + // Inner product range search: dim mismatch should error + test { + sql "select id from ann_cast_rhs_ip where inner_product_approximate(embedding, cast('[0.1,0.2,0.3]' as array)) >= 0.6 order by id;" + exception "[INVALID_ARGUMENT]" + } + + // ---------------------- + // Non-constant RHS behavior + // ---------------------- + + // Fall back to full scan if RHS is not constant + qt_sql_fall_back "select l2_distance_approximate(embedding, embedding) from ann_cast_rhs_l2 order by l2_distance_approximate(embedding, embedding) limit 10;" + + // Range search with non-constant RHS should execute without index pushdown + // L2: distance(embedding, embedding) == 0, so <= 0 selects all rows + qt_sql_rs_l2_nonconst_le "select id from ann_cast_rhs_l2 where l2_distance_approximate(embedding, embedding) <= 0.0 order by id;" + + // IP: inner_product(embedding, embedding) is sum of squares; with threshold 1.5 expect ids 2 and 3 + qt_sql_rs_ip_nonconst_ge "select id from ann_cast_rhs_ip where inner_product_approximate(embedding, embedding) >= 1.5 order by id;" +} \ No newline at end of file