Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "common/config.h"
#include "io/io_common.h"
#include "olap/rowset/segment_v2/ann_index/ann_index.h"
#include "olap/rowset/segment_v2/ann_index/ann_index_writer.h"
#include "olap/rowset/segment_v2/ann_index/ann_search_params.h"
#include "olap/rowset/segment_v2/ann_index/faiss_ann_index.h"
#include "olap/rowset/segment_v2/index_file_reader.h"
Expand Down Expand Up @@ -58,6 +59,9 @@ AnnIndexReader::AnnIndexReader(const TabletIndex* index_meta,
it = index_properties.find("metric_type");
DCHECK(it != index_properties.end());
_metric_type = string_to_metric(it->second);
it = index_properties.find(AnnIndexColumnWriter::DIM);
DCHECK(it != index_properties.end());
_dim = std::stoi(it->second);
}

Status AnnIndexReader::new_iterator(std::unique_ptr<IndexIterator>* iterator) {
Expand Down Expand Up @@ -225,4 +229,8 @@ Status AnnIndexReader::range_search(const AnnRangeSearchParams& params,
return Status::OK();
}

size_t AnnIndexReader::get_dimension() const {
return _dim;
}

} // namespace doris::segment_v2
4 changes: 3 additions & 1 deletion be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ class AnnIndexReader : public IndexReader {

AnnIndexMetric get_metric_type() const { return _metric_type; }

size_t get_dimension() const;

private:
TabletIndex _index_meta;
std::shared_ptr<IndexFileReader> _index_file_reader;
std::unique_ptr<VectorIndex> _vector_index;
AnnIndexType _index_type;
AnnIndexMetric _metric_type;

size_t _dim;
DorisCallOnce<Status> _load_index_once;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ namespace doris::segment_v2 {
*/
AnnRangeSearchParams AnnRangeSearchRuntime::to_range_search_params() const {
AnnRangeSearchParams params;
params.query_value = query_value.get();
const auto* query = assert_cast<const vectorized::ColumnFloat32*>(query_value.get());
params.query_value = query->get_data().data();
params.radius = static_cast<float>(radius);
params.roaring = nullptr;
params.is_le_or_lt = is_le_or_lt;
Expand All @@ -58,6 +59,7 @@ std::string AnnRangeSearchRuntime::to_string() const {
"dst_col_idx: {}, metric_type {}, radius: {}, user params: {}, query_vector is null: "
"{}",
is_ann_range_search, is_le_or_lt, src_col_idx, dst_col_idx,
metric_to_string(metric_type), radius, user_params.to_string(), query_value == nullptr);
metric_to_string(metric_type), radius, user_params.to_string(),
query_value.get() == nullptr);
}
} // namespace doris::segment_v2
29 changes: 10 additions & 19 deletions be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
#include <string>

#include "olap/rowset/segment_v2/ann_index/ann_index.h"
#include "runtime/define_primitive_type.h"
#include "runtime/primitive_type.h"
#include "vec/columns/column.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/runtime/vector_search_user_params.h"

namespace doris::segment_v2 {
Expand Down Expand Up @@ -81,16 +86,8 @@ struct AnnRangeSearchRuntime {
dst_col_idx(other.dst_col_idx),
radius(other.radius),
metric_type(other.metric_type),
user_params(other.user_params) {
// Do deep copy to query_value.
if (other.query_value) {
query_value = std::make_unique<float[]>(other.dim);
std::copy(other.query_value.get(), other.query_value.get() + other.dim,
query_value.get());
} else {
query_value = nullptr;
}
}
user_params(other.user_params),
query_value(other.query_value) {}

/**
* @brief Assignment operator with deep copy semantics.
Expand All @@ -110,14 +107,8 @@ struct AnnRangeSearchRuntime {
metric_type = other.metric_type;
user_params = other.user_params;
dim = other.dim;
// Do deep copy to query_value.
if (other.query_value) {
query_value = std::make_unique<float[]>(other.dim);
std::copy(other.query_value.get(), other.query_value.get() + other.dim,
query_value.get());
} else {
query_value = nullptr;
}
query_value = other.query_value;

return *this;
}

Expand All @@ -142,7 +133,7 @@ struct AnnRangeSearchRuntime {
double radius = 0.0; ///< Search radius/distance threshold
AnnIndexMetric metric_type; ///< Distance metric (L2, Inner Product, etc.)
doris::VectorSearchUserParams user_params; ///< User-defined search parameters
std::unique_ptr<float[]> query_value; ///< Query vector data (deep copied)
vectorized::IColumn::Ptr query_value; ///< Query vector data (deep copied)
};
#include "common/compile_check_end.h"
} // namespace doris::segment_v2
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ struct AnnTopNParam {

struct AnnRangeSearchParams {
bool is_le_or_lt = true;
float* query_value = nullptr;
const float* query_value = nullptr;
float radius = -1;
roaring::Roaring* roaring; // roaring from segment_iterator
std::string to_string() const {
Expand Down
149 changes: 116 additions & 33 deletions be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,19 @@
#include <string>
#include <utility>

#include "common/exception.h"
#include "common/logging.h"
#include "common/status.h"
#include "olap/rowset/segment_v2/ann_index/ann_index_iterator.h"
#include "olap/rowset/segment_v2/ann_index/ann_search_params.h"
#include "olap/rowset/segment_v2/inverted_index_query_type.h"
#include "runtime/primitive_type.h"
#include "runtime/runtime_state.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_const.h"
#include "vec/columns/column_nullable.h"
#include "vec/common/assert_cast.h"
#include "vec/exprs/varray_literal.h"
#include "vec/exprs/vcast_expr.h"
#include "vec/exprs/vexpr_context.h"
#include "vec/exprs/vexpr_fwd.h"
#include "vec/exprs/virtual_slot_ref.h"
Expand All @@ -40,6 +43,81 @@

namespace doris::segment_v2 {
#include "common/compile_check_begin.h"

/**
 * @brief Materialize the constant query-vector argument of an ANN distance expression.
 *
 * Accepts a constant expression that is either an ArrayLiteral([..]) or a
 * CAST('..' AS Nullable(Array(Nullable(Float32)))) and returns the flat values
 * column of that single array, with the outer and the element-level Nullable
 * wrappers stripped.
 *
 * @param arg_expr Expression that supplies the query vector.
 * @return The Float32 values column on success. On failure a
 *         Status::InvalidArgument is returned when the expression is null or not
 *         constant, is neither an array literal nor a cast, cannot be
 *         materialized, is NULL, is not a single-row array, is empty, contains
 *         NULL elements, or does not hold Float32 elements.
 */
Result<vectorized::IColumn::Ptr> extract_query_vector(std::shared_ptr<vectorized::VExpr> arg_expr) {
    if (arg_expr == nullptr) {
        return ResultError(Status::InvalidArgument("Ann topn expr must not be null"));
    }

    if (!arg_expr->is_constant()) {
        return ResultError(Status::InvalidArgument("Ann topn expr must be constant, got\n{}",
                                                   arg_expr->debug_string()));
    }

    // Accept either ArrayLiteral([..]) or CAST('..' AS Nullable(Array(Nullable(Float32)))).
    // Only the dynamic type matters here, so test the raw pointer instead of
    // creating temporary shared_ptr owners.
    const bool is_array_literal =
            dynamic_cast<const vectorized::VArrayLiteral*>(arg_expr.get()) != nullptr;
    const bool is_cast_expr = dynamic_cast<const vectorized::VCastExpr*>(arg_expr.get()) != nullptr;
    if (!is_array_literal && !is_cast_expr) {
        return ResultError(
                Status::InvalidArgument("Constant must be ArrayLiteral or CAST to array, got\n{}",
                                        arg_expr->debug_string()));
    }

    // The shape is validated below by inspecting the materialized constant column.
    std::shared_ptr<ColumnPtrWrapper> column_wrapper;
    auto st = arg_expr->get_const_col(nullptr, &column_wrapper);
    if (!st.ok()) {
        return ResultError(Status::InvalidArgument("Failed to get constant column, error: {}",
                                                   st.to_string()));
    }
    // Guard against a successful status that did not actually produce a column.
    if (column_wrapper == nullptr || column_wrapper->column_ptr.get() == nullptr) {
        return ResultError(Status::InvalidArgument(
                "Failed to get constant column, error: constant column is empty"));
    }

    // Expand a possible const column so the array can be inspected directly.
    vectorized::IColumn::Ptr col_ptr =
            column_wrapper->column_ptr->convert_to_full_column_if_const();

    // The expected runtime column layout is:
    // Nullable(ColumnArray(Nullable(ColumnFloat32))) with exactly 1 row (one array).
    const vectorized::IColumn* top_col = col_ptr.get();
    const vectorized::IColumn* array_holder_col = top_col;
    // Strip the outer Nullable; a NULL query vector is rejected.
    if (const auto* nullable_col =
                vectorized::check_and_get_column<vectorized::ColumnNullable>(*top_col)) {
        if (nullable_col->has_null()) {
            return ResultError(Status::InvalidArgument("Ann query vector cannot be NULL"));
        }
        array_holder_col = &nullable_col->get_nested_column();
    }

    // Must be an array column with a single row.
    const auto* array_col =
            vectorized::check_and_get_column<vectorized::ColumnArray>(*array_holder_col);
    if (array_col == nullptr || array_col->size() != 1) {
        return ResultError(Status::InvalidArgument(
                "Ann topn expr constant should be an Array literal, got column: {}",
                array_holder_col->get_name()));
    }

    // Fetch nested data column: Nullable(ColumnFloat32) or ColumnFloat32.
    const vectorized::IColumn& nested_data_any = array_col->get_data();
    vectorized::IColumn::Ptr values_holder_col = array_col->get_data_ptr();
    // With a single row, the first offset is the number of elements in the array.
    const size_t value_count = array_col->get_offsets()[0];

    if (value_count == 0) {
        return ResultError(Status::InvalidArgument("Ann topn query vector cannot be empty"));
    }

    if (const auto* value_nullable_col =
                vectorized::check_and_get_column<vectorized::ColumnNullable>(nested_data_any)) {
        if (value_nullable_col->has_null(0, value_count)) {
            return ResultError(Status::InvalidArgument(
                    "Ann topn query vector elements cannot contain NULL values"));
        }
        values_holder_col = value_nullable_col->get_nested_column_ptr();
    }

    // Callers assert_cast the result to ColumnFloat32, so reject any other element
    // type here with a proper error instead of failing later in the cast.
    const vectorized::IColumn* values_col = values_holder_col.get();
    if (values_col == nullptr ||
        vectorized::check_and_get_column<vectorized::ColumnFloat32>(*values_col) == nullptr) {
        return ResultError(Status::InvalidArgument(
                "Ann topn query vector elements must be Float32, got column: {}",
                values_col == nullptr ? "null" : values_col->get_name()));
    }

    return values_holder_col;
}

Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_desc) {
RETURN_IF_ERROR(_order_by_expr_ctx->prepare(state, row_desc));
RETURN_IF_ERROR(_order_by_expr_ctx->open(state));
Expand All @@ -54,13 +132,13 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_des
|----------------
| |
| |
SlotRef ArrayLiteral
SlotRef CAST(String as Nullable<ArrayFloat>) OR ArrayLiteral
*/
std::shared_ptr<vectorized::VirtualSlotRef> vir_slot_ref =
std::dynamic_pointer_cast<vectorized::VirtualSlotRef>(_order_by_expr_ctx->root());
DCHECK(vir_slot_ref != nullptr);
if (vir_slot_ref == nullptr) {
return Status::InternalError(
return Status::InvalidArgument(
"root of order by expr of ann topn must be a vectorized::VirtualSlotRef, got\n{}",
_order_by_expr_ctx->root()->debug_string());
}
Expand All @@ -71,27 +149,35 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_des
std::dynamic_pointer_cast<vectorized::VectorizedFnCall>(vir_col_expr);

if (distance_fn_call == nullptr) {
return Status::InternalError("Ann topn expr expect FuncationCall, got\n{}",
vir_col_expr->debug_string());
return Status::InvalidArgument("Ann topn expr expect FuncationCall, got\n{}",
vir_col_expr->debug_string());
}

std::shared_ptr<vectorized::VSlotRef> slot_ref =
std::dynamic_pointer_cast<vectorized::VSlotRef>(distance_fn_call->children()[0]);
if (slot_ref == nullptr) {
return Status::InternalError("Ann topn expr expect SlotRef, got\n{}",
distance_fn_call->children()[0]->debug_string());
return Status::InvalidArgument("Ann topn expr expect SlotRef, got\n{}",
distance_fn_call->children()[0]->debug_string());
}

// slot_ref->column_id() is actually the column index in the block.
_src_column_idx = slot_ref->column_id();

std::shared_ptr<vectorized::VArrayLiteral> array_literal =
std::dynamic_pointer_cast<vectorized::VArrayLiteral>(distance_fn_call->children()[1]);
if (array_literal == nullptr) {
return Status::InternalError("Ann topn expr expect ArrayLiteral, got\n{}",
distance_fn_call->children()[1]->debug_string());
if (distance_fn_call->children()[1]->is_constant() == false) {
return Status::InvalidArgument("Ann topn expr expect constant ArrayLiteral, got\n{}",
distance_fn_call->children()[1]->debug_string());
}
_query_array = array_literal->get_column_ptr();

// Accept either ArrayLiteral([..]) or CAST('..' AS Nullable(Array(Nullable(Float32))))
// First, check the expr node type for clarity.
auto arg_expr = distance_fn_call->children()[1];

auto query_array_result = extract_query_vector(arg_expr);
if (!query_array_result.has_value()) {
return query_array_result.error();
}
_query_array = query_array_result.value();

_user_params = state->get_vector_search_params();

std::set<std::string> distance_func_names = {vectorized::L2DistanceApproximate::name,
Expand Down Expand Up @@ -121,27 +207,26 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator* ann
DCHECK(ann_index_iterator_casted != nullptr);
DCHECK(_order_by_expr_ctx != nullptr);
DCHECK(_order_by_expr_ctx->root() != nullptr);
size_t query_array_size = _query_array->size();
if (_query_array.get() == nullptr || query_array_size == 0) {
return Status::InternalError("Ann topn query vector is not initialized");
}

const vectorized::ColumnConst* const_column =
assert_cast<const vectorized::ColumnConst*>(_query_array.get());
const vectorized::ColumnArray* column_array =
assert_cast<const vectorized::ColumnArray*>(const_column->get_data_column_ptr().get());
const vectorized::ColumnNullable* column_nullable =
assert_cast<const vectorized::ColumnNullable*>(column_array->get_data_ptr().get());
const vectorized::ColumnFloat32* cf32 = assert_cast<const vectorized::ColumnFloat32*>(
column_nullable->get_nested_column_ptr().get());

const float* query_value = cf32->get_data().data();
const size_t query_value_size = cf32->get_data().size();
// TODO:(zhiqiang) Maybe we can move this dimension check to prepare phase.

std::unique_ptr<float[]> query_value_f32 = std::make_unique<float[]>(query_value_size);
for (size_t i = 0; i < query_value_size; ++i) {
query_value_f32[i] = static_cast<float>(query_value[i]);
auto index_reader = ann_index_iterator_casted->get_reader(AnnIndexReaderType::ANN);
auto ann_index_reader = std::dynamic_pointer_cast<AnnIndexReader>(index_reader);
DCHECK(ann_index_reader != nullptr);
if (ann_index_reader->get_dimension() != query_array_size) {
return Status::InvalidArgument(
"Ann topn query vector dimension {} does not match index dimension {}",
query_array_size, ann_index_reader->get_dimension());
}

const vectorized::ColumnFloat32* query =
assert_cast<const vectorized::ColumnFloat32*>(_query_array.get());
segment_v2::AnnTopNParam ann_query_params {
.query_value = query_value_f32.get(),
.query_value_size = query_value_size,
.query_value = query->get_data().data(),
.query_value_size = query_array_size,
.limit = _limit,
._user_params = _user_params,
.roaring = roaring,
Expand All @@ -157,11 +242,9 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator* ann

size_t num_results = ann_query_params.distance->size();
auto result_column_float = vectorized::ColumnFloat32::create(num_results);

for (size_t i = 0; i < num_results; ++i) {
result_column_float->get_data()[i] = (*ann_query_params.distance)[i];
}

result_column = std::move(result_column_float);
row_ids = std::move(ann_query_params.row_ids);
ann_index_stats = *ann_query_params.stats;
Expand Down
5 changes: 4 additions & 1 deletion be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

#pragma once

#include "runtime/primitive_type.h"
#include "runtime/runtime_state.h"
#include "vec/columns/column.h"
#include "vec/exprs/varray_literal.h"
Expand All @@ -49,6 +50,8 @@ namespace doris::segment_v2 {
#include "common/compile_check_begin.h"
struct AnnIndexStats;

/**
 * @brief Extract the query vector values from a constant ANN argument expression.
 *
 * The expression must be constant and be either an array literal or a cast
 * expression. Its materialized constant column is expected to be a single-row
 * array; the returned column is that array's nested values column with the
 * outer and element-level Nullable wrappers removed.
 *
 * @param arg_expr Expression that supplies the query vector.
 * @return The nested values column on success, or a Status::InvalidArgument when
 *         the expression is not constant, is neither an array literal nor a cast,
 *         cannot be materialized, is NULL, is not a single-row array, is empty,
 *         or contains NULL elements.
 */
Result<vectorized::IColumn::Ptr> extract_query_vector(std::shared_ptr<vectorized::VExpr> arg_expr);

/**
* @brief Runtime execution engine for ANN (Approximate Nearest Neighbor) Top-N queries.
*
Expand Down Expand Up @@ -161,7 +164,7 @@ class AnnTopNRuntime {
size_t _src_column_idx = -1; ///< Source vector column index
size_t _dest_column_idx = -1; ///< Destination distance column index
segment_v2::AnnIndexMetric _metric_type; ///< Distance metric type
vectorized::IColumn::Ptr _query_array; ///< Query vector data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change the struct

vectorized::IColumn::Ptr _query_array; ///< Query vector data (contiguous float buffer)
doris::VectorSearchUserParams _user_params; ///< User-defined search parameters
};
#include "common/compile_check_end.h"
Expand Down
1 change: 1 addition & 0 deletions be/src/vec/exec/scan/olap_scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ Status OlapScanner::prepare() {
_score_runtime = local_state->_score_runtime;

_score_runtime = local_state->_score_runtime;
// All scanners share the same ann_topn_runtime.
_ann_topn_runtime = local_state->_ann_topn_runtime;

// set limit to reduce end of rowset and segment mem use
Expand Down
Loading
Loading