Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 10 additions & 23 deletions be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_des
|----------------
| |
| |
CastToArray ArrayLiteral
|
|
SlotRef
SlotRef ArrayLiteral
*/
std::shared_ptr<vectorized::VirtualSlotRef> vir_slot_ref =
std::dynamic_pointer_cast<vectorized::VirtualSlotRef>(_order_by_expr_ctx->root());
Expand All @@ -78,19 +75,11 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_des
vir_col_expr->debug_string());
}

std::shared_ptr<vectorized::VCastExpr> cast_to_array_expr =
std::dynamic_pointer_cast<vectorized::VCastExpr>(distance_fn_call->children()[0]);

if (cast_to_array_expr == nullptr) {
return Status::InternalError("Ann topn expr expect cast_to_array_expr, got\n{}",
distance_fn_call->children()[0]->debug_string());
}

std::shared_ptr<vectorized::VSlotRef> slot_ref =
std::dynamic_pointer_cast<vectorized::VSlotRef>(cast_to_array_expr->children()[0]);
std::dynamic_pointer_cast<vectorized::VSlotRef>(distance_fn_call->children()[0]);
if (slot_ref == nullptr) {
return Status::InternalError("Ann topn expr expect SlotRef, got\n{}",
cast_to_array_expr->children()[0]->debug_string());
distance_fn_call->children()[0]->debug_string());
}

// slot_ref->column_id() is acutually the columnd idx in block.
Expand Down Expand Up @@ -139,11 +128,11 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator* ann
assert_cast<const vectorized::ColumnArray*>(const_column->get_data_column_ptr().get());
const vectorized::ColumnNullable* column_nullable =
assert_cast<const vectorized::ColumnNullable*>(column_array->get_data_ptr().get());
const vectorized::ColumnFloat64* cf64 = assert_cast<const vectorized::ColumnFloat64*>(
const vectorized::ColumnFloat32* cf32 = assert_cast<const vectorized::ColumnFloat32*>(
column_nullable->get_nested_column_ptr().get());

const double* query_value = cf64->get_data().data();
const size_t query_value_size = cf64->get_data().size();
const float* query_value = cf32->get_data().data();
const size_t query_value_size = cf32->get_data().size();

std::unique_ptr<float[]> query_value_f32 = std::make_unique<float[]>(query_value_size);
for (size_t i = 0; i < query_value_size; ++i) {
Expand All @@ -167,15 +156,13 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator* ann
DCHECK(ann_query_params.row_ids != nullptr);

size_t num_results = ann_query_params.distance->size();
auto result_column_double = vectorized::ColumnFloat64::create(num_results);
auto result_null_map = vectorized::ColumnUInt8::create(num_results, 0);
auto result_column_float = vectorized::ColumnFloat32::create(num_results);

for (size_t i = 0; i < num_results; ++i) {
result_column_double->get_data()[i] = (*ann_query_params.distance)[i];
result_column_float->get_data()[i] = (*ann_query_params.distance)[i];
}

result_column = vectorized::ColumnNullable::create(std::move(result_column_double),
std::move(result_null_map));
result_column = std::move(result_column_float);
row_ids = std::move(ann_query_params.row_ids);
ann_index_stats = *ann_query_params.stats;
return Status::OK();
Expand All @@ -188,4 +175,4 @@ std::string AnnTopNRuntime::debug_string() const {
_limit, _src_column_idx, _dest_column_idx, _asc, _user_params.to_string(),
segment_v2::metric_to_string(_metric_type), _order_by_expr_ctx->root()->debug_string());
}
} // namespace doris::segment_v2
} // namespace doris::segment_v2
2 changes: 1 addition & 1 deletion be/src/vec/exprs/vdirect_in_predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,4 @@ class VDirectInPredicate final : public VExpr {
};

#include "common/compile_check_end.h"
} // namespace doris::vectorized
} // namespace doris::vectorized
69 changes: 33 additions & 36 deletions be/src/vec/exprs/vectorized_fn_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,17 +335,17 @@ bool VectorizedFnCall::equals(const VExpr& other) {
|----------------
| |
| |
VirtualSlotRef Float64Literal
CastToDouble Float64Literal
|
|
VirtualSlotRef
|
|
FuncationCall
|----------------
| |
| |
CastToArray ArrayLiteral
|
|
SlotRef
SlotRef ArrayLiteral
*/

void VectorizedFnCall::prepare_ann_range_search(
Expand All @@ -369,7 +369,7 @@ void VectorizedFnCall::prepare_ann_range_search(
auto left_child = get_child(0);
auto right_child = get_child(1);

// Return type of L2Distance is always double.
// right side
auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child);
if (right_literal == nullptr) {
suitable_for_ann_index = false;
Expand All @@ -388,14 +388,24 @@ void VectorizedFnCall::prepare_ann_range_search(
const ColumnFloat64* cf64_right = assert_cast<const ColumnFloat64*>(right_col.get());
range_search_runtime.radius = cf64_right->get_data()[0];

// left side
auto cast_to_double_expr = std::dynamic_pointer_cast<VCastExpr>(left_child);
if (cast_to_double_expr == nullptr) {
suitable_for_ann_index = false;
return;
}

std::shared_ptr<VectorizedFnCall> function_call;
auto vir_slot_ref = std::dynamic_pointer_cast<VirtualSlotRef>(left_child);
auto vir_slot_ref =
std::dynamic_pointer_cast<VirtualSlotRef>(cast_to_double_expr->children()[0]);
// Return type of L2Distance is always float.
if (vir_slot_ref != nullptr) {
DCHECK(vir_slot_ref->get_virtual_column_expr() != nullptr);
function_call = std::dynamic_pointer_cast<VectorizedFnCall>(
vir_slot_ref->get_virtual_column_expr());
} else {
function_call = std::dynamic_pointer_cast<VectorizedFnCall>(left_child);
function_call =
std::dynamic_pointer_cast<VectorizedFnCall>(cast_to_double_expr->children()[0]);
}

if (function_call == nullptr) {
Expand All @@ -415,34 +425,25 @@ void VectorizedFnCall::prepare_ann_range_search(
range_search_runtime.metric_type = segment_v2::string_to_metric(metric_name);
}

UInt16 idx_of_cast_to_array = 0;
UInt16 idx_of_slot_ref = 0;
UInt16 idx_of_array_literal = 0;
for (UInt16 i = 0; i < function_call->get_num_children(); ++i) {
auto child = function_call->get_child(i);
if (std::dynamic_pointer_cast<VCastExpr>(child) != nullptr) {
idx_of_cast_to_array = i;
if (std::dynamic_pointer_cast<VSlotRef>(child) != nullptr) {
idx_of_slot_ref = i;
} else if (std::dynamic_pointer_cast<VArrayLiteral>(child) != nullptr) {
idx_of_array_literal = i;
}
}

std::shared_ptr<VCastExpr> cast_to_array_expr =
std::dynamic_pointer_cast<VCastExpr>(function_call->get_child(idx_of_cast_to_array));
std::shared_ptr<VSlotRef> slot_ref =
std::dynamic_pointer_cast<VSlotRef>(function_call->get_child(idx_of_slot_ref));
std::shared_ptr<VArrayLiteral> array_literal = std::dynamic_pointer_cast<VArrayLiteral>(
function_call->get_child(idx_of_array_literal));

if (cast_to_array_expr == nullptr || array_literal == nullptr) {
suitable_for_ann_index = false;
// Cast to array expr or array literal is null.
return;
}

// One of the children is a slot ref, and the other is an array literal, now begin to create search params.
std::shared_ptr<VSlotRef> slot_ref =
std::dynamic_pointer_cast<VSlotRef>(cast_to_array_expr->get_child(0));
if (slot_ref == nullptr) {
if (slot_ref == nullptr || array_literal == nullptr) {
suitable_for_ann_index = false;
// Cast to array expr's child is not a slot ref.
// slot ref or array literal is null.
return;
}

Expand All @@ -457,10 +458,10 @@ void VectorizedFnCall::prepare_ann_range_search(
range_search_runtime.query_value = std::make_unique<float[]>(dim);

const ColumnNullable* cn = assert_cast<const ColumnNullable*>(array_col->get_data_ptr().get());
const ColumnFloat64* cf64 =
assert_cast<const ColumnFloat64*>(cn->get_nested_column_ptr().get());
const ColumnFloat32* cf32 =
assert_cast<const ColumnFloat32*>(cn->get_nested_column_ptr().get());
for (size_t i = 0; i < dim; ++i) {
range_search_runtime.query_value[i] = static_cast<Float32>(cf64->get_data()[i]);
range_search_runtime.query_value[i] = cf32->get_data()[i];
}
range_search_runtime.is_ann_range_search = true;
range_search_runtime.user_params = user_params;
Expand Down Expand Up @@ -551,17 +552,13 @@ Status VectorizedFnCall::evaluate_ann_range_search(
DCHECK(virtual_column_iterator != nullptr);
// Now convert distance to column
size_t size = result.roaring->cardinality();
auto distance_col = ColumnFloat64::create(size);
auto null_map = ColumnUInt8::create(size, 0);
// TODO: Return type of L2DistanceApproximate/InnerProductApproximate should be changed to float.
const float* src = reinterpret_cast<const float*>(result.distance.get());
double* dst = distance_col->get_data().data();
auto distance_col = ColumnFloat32::create(size);
const float* src = result.distance.get();
float* dst = distance_col->get_data().data();
for (size_t i = 0; i < size; ++i) {
dst[i] = static_cast<double>(src[i]);
dst[i] = src[i];
}
auto nullable_distance_col =
ColumnNullable::create(std::move(distance_col), std::move(null_map));
virtual_column_iterator->prepare_materialization(std::move(nullable_distance_col),
virtual_column_iterator->prepare_materialization(std::move(distance_col),
std::move(result.row_ids));
} else {
DCHECK(this->op() != TExprOpcode::LE && this->op() != TExprOpcode::LT)
Expand Down
1 change: 1 addition & 0 deletions be/src/vec/exprs/vtopn_pred.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class VTopNPred : public VExpr {
node.__set_is_nullable(target_ctx->root()->is_nullable());
expr = vectorized::VTopNPred::create_shared(node, source_node_id, target_ctx);

DCHECK(target_ctx->root() != nullptr);
expr->add_child(target_ctx->root());

return Status::OK();
Expand Down
12 changes: 6 additions & 6 deletions be/src/vec/functions/array/function_array_distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
namespace doris::vectorized {

void register_function_array_distance(SimpleFunctionFactory& factory) {
factory.register_function<FunctionArrayDistance<L1Distance> >();
factory.register_function<FunctionArrayDistance<L2Distance> >();
factory.register_function<FunctionArrayDistance<CosineDistance> >();
factory.register_function<FunctionArrayDistance<InnerProduct> >();
factory.register_function<FunctionArrayDistance<L2DistanceApproximate> >();
factory.register_function<FunctionArrayDistance<InnerProductApproximate> >();
factory.register_function<FunctionArrayDistance<L1Distance>>();
factory.register_function<FunctionArrayDistance<L2Distance>>();
factory.register_function<FunctionArrayDistance<CosineDistance>>();
factory.register_function<FunctionArrayDistance<InnerProduct>>();
factory.register_function<FunctionArrayDistance<L2DistanceApproximate>>();
factory.register_function<FunctionArrayDistance<InnerProductApproximate>>();
}

} // namespace doris::vectorized
Loading
Loading