diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp b/be/src/olap/rowset/segment_v2/segment_iterator.cpp index 09438e87c8184e..9b2d9b10b1a8c9 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -701,6 +702,8 @@ Status SegmentIterator::_apply_ann_topn_predicate() { "Ann topn can not be evaluated by ann index, has_ann_index: {}, " "has_common_expr_push_down: {}, has_column_predicate: {}", has_ann_index, has_common_expr_push_down, has_column_predicate); + // Disable index-only scan on ann indexed column. + _need_read_data_indices[src_cid] = true; return Status::OK(); } @@ -712,11 +715,15 @@ Status SegmentIterator::_apply_ann_topn_predicate() { if (_ann_topn_runtime->is_asc()) { VLOG_DEBUG << fmt::format( "Asc topn for inner product can not be evaluated by ann index"); + // Disable index-only scan on ann indexed column. + _need_read_data_indices[src_cid] = true; return Status::OK(); } } else { if (!_ann_topn_runtime->is_asc()) { VLOG_DEBUG << fmt::format("Desc topn for l2/cosine can not be evaluated by ann index"); + // Disable index-only scan on ann indexed column. + _need_read_data_indices[src_cid] = true; return Status::OK(); } } @@ -727,6 +734,8 @@ Status SegmentIterator::_apply_ann_topn_predicate() { "ann index", metric_to_string(_ann_topn_runtime->get_metric_type()), metric_to_string(ann_index_reader->get_metric_type())); + // Disable index-only scan on ann indexed column. + _need_read_data_indices[src_cid] = true; return Status::OK(); } @@ -738,6 +747,8 @@ Status SegmentIterator::_apply_ann_topn_predicate() { "to " "filter", pre_size, rows_of_segment); + // Disable index-only scan on ann indexed column. 
+ _need_read_data_indices[src_cid] = true; return Status::OK(); } vectorized::IColumn::MutablePtr result_column; @@ -772,6 +783,10 @@ Status SegmentIterator::_apply_ann_topn_predicate() { virtual_column_iter->prepare_materialization(std::move(result_column), std::move(result_row_ids)); + _need_read_data_indices[src_cid] = false; + VLOG_DEBUG << fmt::format( + "Enable ANN index-only scan for src column cid {} (skip reading data pages)", src_cid); + return Status::OK(); } @@ -1044,9 +1059,9 @@ Status SegmentIterator::_apply_index_expr() { segment_v2::AnnIndexStats ann_index_stats; for (const auto& expr_ctx : _common_expr_ctxs_push_down) { size_t origin_rows = _row_bitmap.cardinality(); - RETURN_IF_ERROR(expr_ctx->evaluate_ann_range_search(_index_iterators, _schema->column_ids(), - _column_iterators, _row_bitmap, - ann_index_stats)); + RETURN_IF_ERROR(expr_ctx->evaluate_ann_range_search( + _index_iterators, _schema->column_ids(), _column_iterators, + _common_expr_to_slotref_map, _row_bitmap, ann_index_stats)); _opts.stats->rows_ann_index_range_filtered += (origin_rows - _row_bitmap.cardinality()); _opts.stats->ann_index_load_ns += ann_index_stats.load_index_costs_ns.value(); _opts.stats->ann_index_range_search_ns += ann_index_stats.search_costs_ns.value(); @@ -1057,7 +1072,7 @@ Status SegmentIterator::_apply_index_expr() { } for (auto it = _common_expr_ctxs_push_down.begin(); it != _common_expr_ctxs_push_down.end();) { - if ((*it)->root()->has_been_executed()) { + if ((*it)->root()->ann_range_search_executedd()) { _opts.stats->ann_index_range_search_cnt++; it = _common_expr_ctxs_push_down.erase(it); } else { @@ -1808,14 +1823,6 @@ Status SegmentIterator::_vec_init_lazy_materialization() { if (pred_id_set.find(cid) != pred_id_set.end()) { _predicate_column_ids.push_back(cid); } - // In the past, if schema columns > pred columns, the _lazy_materialization_read maybe == false, but - // we make sure using _lazy_materialization_read= true now, so these logic may never 
happens. I comment - // these lines and we could delete them in the future to make the code more clear. - // else if (non_pred_set.find(cid) != non_pred_set.end()) { - // _predicate_column_ids.push_back(cid); - // // when _lazy_materialization_read = false, non-predicate column should also be filtered by sel idx, so we regard it as pred columns - // _is_pred_column[cid] = true; - // } } } else if (_is_need_expr_eval) { DCHECK(!_is_need_vec_eval && !_is_need_short_eval); @@ -2029,8 +2036,9 @@ Status SegmentIterator::_output_non_pred_columns(vectorized::Block* block) { if (column_in_block_is_nothing || column_is_normal) { block->replace_by_position(loc, std::move(_current_return_columns[cid])); VLOG_DEBUG << fmt::format( - "Output non-predicate column, cid: {}, loc: {}, col_name: {}", cid, loc, - _schema->column(cid)->name()); + "Output non-predicate column, cid: {}, loc: {}, col_name: {}, rows {}", cid, + loc, _schema->column(cid)->name(), + block->get_by_position(loc).column->size()); } // Means virtual column in block has been materialized(maybe by common expr). // so do nothing here. 
@@ -2073,6 +2081,8 @@ Status SegmentIterator::_read_columns_by_index(uint32_t nrows_read_limit, uint16 for (auto cid : _predicate_column_ids) { auto& column = _current_return_columns[cid]; + VLOG_DEBUG << fmt::format("Reading column {}, col_name {}", cid, + _schema->column(cid)->name()); if (!_virtual_column_exprs.contains(cid)) { if (_no_need_read_key_data(cid, column, nrows_read)) { VLOG_DEBUG << fmt::format("Column {} no need to read.", cid); @@ -2822,6 +2832,8 @@ void SegmentIterator::_calculate_expr_in_remaining_conjunct_root() { if (root_expr == nullptr) { continue; } + _common_expr_to_slotref_map[root_expr_ctx.get()] = + std::unordered_map(); std::stack stack; stack.emplace(root_expr); @@ -2831,10 +2843,53 @@ void SegmentIterator::_calculate_expr_in_remaining_conjunct_root() { stack.pop(); for (const auto& child : expr->children()) { + if (child->is_virtual_slot_ref()) { + // Expand virtual slot ref to its underlying expression tree and + // collect real slot refs used inside. We still associate those + // slot refs with the current parent expr node for inverted index + // tracking, just like normal slot refs. 
+ auto* vir_slot_ref = assert_cast(child.get()); + auto vir_expr = vir_slot_ref->get_virtual_column_expr(); + if (vir_expr) { + std::stack vir_stack; + vir_stack.emplace(vir_expr); + + while (!vir_stack.empty()) { + const auto vir_node = vir_stack.top(); /* copy: pop() destroys the element top() refers to */ + vir_stack.pop(); + + for (const auto& vir_child : vir_node->children()) { + if (vir_child->is_slot_ref()) { + auto* inner_slot_ref = + assert_cast(vir_child.get()); + _common_expr_inverted_index_status[_schema->column_id( + inner_slot_ref->column_id())][expr.get()] = false; + _common_expr_to_slotref_map[root_expr_ctx.get()] + [inner_slot_ref->column_id()] = + expr.get(); + // Print debug info for virtual slot expansion + VLOG_DEBUG << fmt::format( + "common_expr_ctx_ptr: {}, expr_ptr: {}, " + "virtual_slotref_ptr: {}, inner_slotref_ptr: {}, " + "column_id: {}", + fmt::ptr(root_expr_ctx.get()), fmt::ptr(expr.get()), + fmt::ptr(child.get()), fmt::ptr(vir_child.get()), + inner_slot_ref->column_id()); + } + + if (!vir_child->children().empty()) { + vir_stack.emplace(vir_child); + } + } + } + } + } if (child->is_slot_ref()) { auto* column_slot_ref = assert_cast(child.get()); _common_expr_inverted_index_status[_schema->column_id( column_slot_ref->column_id())][expr.get()] = false; + _common_expr_to_slotref_map[root_expr_ctx.get()][column_slot_ref->column_id()] = + expr.get(); } } diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.h b/be/src/olap/rowset/segment_v2/segment_iterator.h index 95b85649efa3df..b6d49c053a140f 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.h +++ b/be/src/olap/rowset/segment_v2/segment_iterator.h @@ -506,6 +506,9 @@ class SegmentIterator : public RowwiseIterator { // key is column uid, value is the sparse column cache std::unordered_map _variant_sparse_column_cache; + + std::unordered_map> + _common_expr_to_slotref_map; }; } // namespace segment_v2 diff --git a/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp 
b/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp index 9a316be9e7853c..5bcb4b0185e79a 100644 --- a/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp @@ -21,6 +21,7 @@ #include #include +#include "common/logging.h" #include "vec/columns/column.h" #include "vec/columns/column_nothing.h" @@ -79,12 +80,15 @@ void VirtualColumnIterator::prepare_materialization(vectorized::IColumn::Ptr col _size = n; - std::string msg; - for (const auto& pair : _row_id_to_idx) { - msg += fmt::format("{}: {}, ", pair.first, pair.second); + if (VLOG_DEBUG_IS_ON) { + std::string msg; + for (const auto& pair : _row_id_to_idx) { + msg += fmt::format("{}: {}, ", pair.first, pair.second); + } + + VLOG_DEBUG << fmt::format("virtual column iterator, row_idx_to_idx:\n{}", msg); } - VLOG_DEBUG << fmt::format("virtual column iterator, row_idx_to_idx:\n{}", msg); _filter = doris::vectorized::IColumn::Filter(_size, 0); } diff --git a/be/src/pipeline/exec/olap_scan_operator.cpp b/be/src/pipeline/exec/olap_scan_operator.cpp index b8805897f54af4..4041db0f0221ea 100644 --- a/be/src/pipeline/exec/olap_scan_operator.cpp +++ b/be/src/pipeline/exec/olap_scan_operator.cpp @@ -444,7 +444,7 @@ Status OlapScanLocalState::_init_scanners(std::list* sc auto& p = _parent->cast(); for (auto uid : p._olap_scan_node.output_column_unique_ids) { - _maybe_read_column_ids.emplace(uid); + _output_column_ids.emplace(uid); } // ranges constructed from scan keys diff --git a/be/src/pipeline/exec/olap_scan_operator.h b/be/src/pipeline/exec/olap_scan_operator.h index 2db857b6e15d40..84953b4f22052e 100644 --- a/be/src/pipeline/exec/olap_scan_operator.h +++ b/be/src/pipeline/exec/olap_scan_operator.h @@ -111,7 +111,7 @@ class OlapScanLocalState final : public ScanLocalState { OlapScanKeys _scan_keys; std::vector> _olap_filters; // If column id in this set, indicate that we need to read data after index filtering - std::set _maybe_read_column_ids; + 
std::set _output_column_ids; std::unique_ptr _segment_profile; std::unique_ptr _index_filter_profile; diff --git a/be/src/vec/exec/scan/olap_scanner.cpp b/be/src/vec/exec/scan/olap_scanner.cpp index 61d099c971c945..3f56d9e5be8c38 100644 --- a/be/src/vec/exec/scan/olap_scanner.cpp +++ b/be/src/vec/exec/scan/olap_scanner.cpp @@ -339,7 +339,7 @@ Status OlapScanner::_init_tablet_reader_params( _tablet_reader_params.vir_col_idx_to_type = _vir_col_idx_to_type; _tablet_reader_params.score_runtime = _score_runtime; _tablet_reader_params.output_columns = - ((pipeline::OlapScanLocalState*)_local_state)->_maybe_read_column_ids; + ((pipeline::OlapScanLocalState*)_local_state)->_output_column_ids; _tablet_reader_params.ann_topn_runtime = _ann_topn_runtime; for (const auto& ele : ((pipeline::OlapScanLocalState*)_local_state)->_cast_types_for_variants) { diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp index cc61a6976ed8c1..dfaa1e8a212094 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -605,15 +605,31 @@ Status VectorizedFnCall::evaluate_ann_range_search( } virtual_column_iterator->prepare_materialization(std::move(distance_col), std::move(result.row_ids)); + _virtual_column_is_fulfilled = true; } else { - DCHECK(this->op() != TExprOpcode::LE && this->op() != TExprOpcode::LT) - << "Should not have distance"; + // Whether the ANN index should have produced distance depends on metric and operator: + // - L2: distance is produced for LE/LT; not produced for GE/GT + // - IP: distance is produced for GE/GT; not produced for LE/LT +#ifndef NDEBUG + const bool should_have_distance = + (range_search_runtime.is_le_or_lt && + range_search_runtime.metric_type == AnnIndexMetric::L2) || + (!range_search_runtime.is_le_or_lt && + range_search_runtime.metric_type == AnnIndexMetric::IP); + // If we expected distance but didn't get it, assert in debug to catch logic errors. 
+ DCHECK(!should_have_distance) << "Expected distance from ANN index but got none"; +#endif + _virtual_column_is_fulfilled = false; } + } else { + // Dest is not virtual column. + _virtual_column_is_fulfilled = true; } _has_been_executed = true; - VLOG_DEBUG << fmt::format("Ann range search filtered {} rows, origin {} rows", - origin_num - row_bitmap.cardinality(), origin_num); + VLOG_DEBUG << fmt::format( + "Ann range search filtered {} rows, origin {} rows, virtual column is full-filled: {}", + origin_num - row_bitmap.cardinality(), origin_num, _virtual_column_is_fulfilled); ann_index_stats = *stats; return Status::OK(); diff --git a/be/src/vec/exprs/vexpr.cpp b/be/src/vec/exprs/vexpr.cpp index 798bacfdf9df4b..bb609134d7809a 100644 --- a/be/src/vec/exprs/vexpr.cpp +++ b/be/src/vec/exprs/vexpr.cpp @@ -1002,9 +1002,13 @@ void VExpr::prepare_ann_range_search(const doris::VectorSearchUserParams& params } } -bool VExpr::has_been_executed() { +bool VExpr::ann_range_search_executedd() { return _has_been_executed; } +bool VExpr::ann_dist_is_fulfilled() const { + return _virtual_column_is_fulfilled; +} + #include "common/compile_check_end.h" } // namespace doris::vectorized diff --git a/be/src/vec/exprs/vexpr.h b/be/src/vec/exprs/vexpr.h index 4ef6a4661f3d84..4e355da25fa3f9 100644 --- a/be/src/vec/exprs/vexpr.h +++ b/be/src/vec/exprs/vexpr.h @@ -167,6 +167,8 @@ class VExpr { bool is_slot_ref() const { return _node_type == TExprNodeType::SLOT_REF; } + bool is_virtual_slot_ref() const { return _node_type == TExprNodeType::VIRTUAL_SLOT_REF; } + bool is_column_ref() const { return _node_type == TExprNodeType::COLUMN_REF; } virtual bool is_literal() const { return false; } @@ -308,7 +310,9 @@ class VExpr { segment_v2::AnnRangeSearchRuntime& range_search_runtime, bool& suitable_for_ann_index); - bool has_been_executed(); + bool ann_range_search_executedd(); + + bool ann_dist_is_fulfilled() const; protected: /// Simple debug string that provides no expr subclass-specific 
information @@ -392,7 +396,12 @@ class VExpr { uint32_t _index_unique_id = 0; bool _enable_inverted_index_query = true; + // Indicates whether the expr row_bitmap has been updated. bool _has_been_executed = false; + // Indicates whether the virtual column is fulfilled. + // NOTE, if there is no virtual column in the expr tree, and expr + // is evaluated by ann index, this flag is still true. + bool _virtual_column_is_fulfilled = false; }; } // namespace vectorized diff --git a/be/src/vec/exprs/vexpr_context.cpp b/be/src/vec/exprs/vexpr_context.cpp index 36fd30121336d4..3d3afa9feff7c7 100644 --- a/be/src/vec/exprs/vexpr_context.cpp +++ b/be/src/vec/exprs/vexpr_context.cpp @@ -25,6 +25,8 @@ #include "common/cast_set.h" #include "common/compiler_util.h" // IWYU pragma: keep #include "common/exception.h" +#include "common/status.h" +#include "olap/olap_common.h" #include "runtime/runtime_state.h" #include "runtime/thread_context.h" #include "udf/udf.h" @@ -463,12 +465,45 @@ Status VExprContext::evaluate_ann_range_search( const std::vector>& cid_to_index_iterators, const std::vector& idx_to_cid, const std::vector>& column_iterators, + const std::unordered_map>& + common_expr_to_slotref_map, roaring::Roaring& row_bitmap, segment_v2::AnnIndexStats& ann_index_stats) { - if (_root != nullptr) { - return _root->evaluate_ann_range_search(_ann_range_search_runtime, cid_to_index_iterators, - idx_to_cid, column_iterators, row_bitmap, - ann_index_stats); + if (_root == nullptr) { + return Status::OK(); + } + + RETURN_IF_ERROR(_root->evaluate_ann_range_search( + _ann_range_search_runtime, cid_to_index_iterators, idx_to_cid, column_iterators, + row_bitmap, ann_index_stats)); + + if (!_root->ann_range_search_executedd()) { + return Status::OK(); + } + + if (!_root->ann_dist_is_fulfilled()) { + // Do not perform index scan in this case. 
+ return Status::OK(); + } + + auto src_col_idx = _ann_range_search_runtime.src_col_idx; + auto slot_ref_map_it = common_expr_to_slotref_map.find(this); + if (slot_ref_map_it == common_expr_to_slotref_map.end()) { + return Status::OK(); + } + auto& slot_ref_map = slot_ref_map_it->second; + ColumnId cid = idx_to_cid[src_col_idx]; + if (slot_ref_map.find(cid) == slot_ref_map.end()) { + return Status::OK(); } + const VExpr* slot_ref_expr_addr = slot_ref_map.find(cid)->second; + _inverted_index_context->set_true_for_inverted_index_status(slot_ref_expr_addr, + idx_to_cid[cid]); + + VLOG_DEBUG << fmt::format( + "Evaluate ann range search for expr {}, src_col_idx {}, cid {}, row_bitmap " + "cardinality {}", + _root->debug_string(), src_col_idx, cid, row_bitmap.cardinality()); return Status::OK(); } diff --git a/be/src/vec/exprs/vexpr_context.h b/be/src/vec/exprs/vexpr_context.h index 144206fc430741..1331e0526f6883 100644 --- a/be/src/vec/exprs/vexpr_context.h +++ b/be/src/vec/exprs/vexpr_context.h @@ -302,6 +302,9 @@ class VExprContext { const std::vector>& cid_to_index_iterators, const std::vector& idx_to_cid, const std::vector>& column_iterators, + const std::unordered_map>& + common_expr_to_slotref_map, roaring::Roaring& row_bitmap, segment_v2::AnnIndexStats& ann_index_stats); private: diff --git a/be/src/vec/olap/vcollect_iterator.cpp b/be/src/vec/olap/vcollect_iterator.cpp index 6ad8202ad4ae21..0158cc2f9d86ce 100644 --- a/be/src/vec/olap/vcollect_iterator.cpp +++ b/be/src/vec/olap/vcollect_iterator.cpp @@ -262,6 +262,26 @@ Status VCollectIterator::_topn_next(Block* block) { // clear TEMP columns to avoid column align problem block->erase_tmp_columns(); auto clone_block = block->clone_empty(); + /* + select id, "${tR2}", + l2_distance_approximate + from ann_index_only_scan + where l2_distance_approximate < 10 + order by id + limit 20; + where id is the orderby key column. 
+ */ + // Initialize virtual slot columns by schema (avoid runtime type checks): + // use _reader_context.vir_col_idx_to_type to construct real columns for those positions. + if (!_reader->_reader_context.vir_col_idx_to_type.empty()) { + const auto& idx_to_type = _reader->_reader_context.vir_col_idx_to_type; + for (const auto& kv : idx_to_type) { + size_t idx = kv.first; + if (idx < clone_block.columns()) { + clone_block.get_by_position(idx).column = kv.second->create_column(); + } + } + } MutableBlock mutable_block = vectorized::MutableBlock::build_mutable_block(&clone_block); if (!_reader->_reader_context.read_orderby_key_columns) { diff --git a/be/test/olap/vector_search/ann_range_search_test.cpp b/be/test/olap/vector_search/ann_range_search_test.cpp index d4839db58c0a30..715f368e3f1bb7 100644 --- a/be/test/olap/vector_search/ann_range_search_test.cpp +++ b/be/test/olap/vector_search/ann_range_search_test.cpp @@ -164,9 +164,12 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) { })); segment_v2::AnnIndexStats stats; + std::unordered_map> + common_expr_to_slotref_map; ASSERT_TRUE(range_search_ctx ->evaluate_ann_range_search(cid_to_index_iterators, idx_to_cid, - column_iterators, row_bitmap, stats) + column_iterators, common_expr_to_slotref_map, + row_bitmap, stats) .ok()); doris::segment_v2::VirtualColumnIterator* virtual_column_iter = @@ -260,9 +263,12 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) { })); segment_v2::AnnIndexStats stats; + std::unordered_map> + common_expr_to_slotref_map; ASSERT_TRUE(range_search_ctx ->evaluate_ann_range_search(cid_to_index_iterators, idx_to_cid, - column_iterators, row_bitmap, stats) + column_iterators, common_expr_to_slotref_map, + row_bitmap, stats) .ok()); doris::segment_v2::VirtualColumnIterator* virtual_column_iter = @@ -738,9 +744,12 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch_DimensionMismatch) { roaring::Roaring row_bitmap; segment_v2::AnnIndexStats stats; + std::unordered_map> + 
common_expr_to_slotref_map; - auto st = range_search_ctx->evaluate_ann_range_search(cid_to_index_iterators, idx_to_cid, - column_iterators, row_bitmap, stats); + auto st = range_search_ctx->evaluate_ann_range_search( + cid_to_index_iterators, idx_to_cid, column_iterators, common_expr_to_slotref_map, + row_bitmap, stats); EXPECT_FALSE(st.ok()); EXPECT_TRUE(st.is()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java index 670af2e28a3ef3..b7b7bb4d6c8751 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java @@ -1403,24 +1403,11 @@ public void updateRequiredSlots(PlanTranslatorContext context, outputColumnUniqueIds.add(slot.getColumn().getUniqueId()); } } - for (SlotDescriptor virtualSlot : context.getTupleDesc(this.getTupleId()).getSlots()) { - Expr virtualColumn = virtualSlot.getVirtualColumn(); - if (virtualColumn == null) { - continue; - } - Set slotRefs = Sets.newHashSet(); - virtualColumn.collect(e -> e instanceof SlotRef, slotRefs); - Set virtualColumnInputSlotIds = slotRefs.stream() - .filter(s -> s instanceof SlotRef) - .map(s -> (SlotRef) s) - .map(SlotRef::getSlotId) - .collect(Collectors.toSet()); - for (SlotDescriptor slot : context.getTupleDesc(this.getTupleId()).getSlots()) { - if (virtualColumnInputSlotIds.contains(slot.getId()) && slot.getColumn() != null) { - outputColumnUniqueIds.add(slot.getColumn().getUniqueId()); - } - } - } + // Do not add input slots of virtual columns into outputColumnUniqueIds. + // Backend can decide whether the underlying source columns are truly needed + // (e.g., ANN distance index-only scan can produce the virtual distance without + // reading the source vector column). Keeping only the real projected slots here + // avoids forcing unnecessary reads in BE. 
} @Override diff --git a/regression-test/data/ann_index_p0/ann_index_only_scan.out b/regression-test/data/ann_index_p0/ann_index_only_scan.out new file mode 100644 index 00000000000000..cb1c0d98a599fc --- /dev/null +++ b/regression-test/data/ann_index_p0/ann_index_only_scan.out @@ -0,0 +1,19 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !q1 -- +0 ann_index_only_scan_q1 +5 ann_index_only_scan_q1 +6 ann_index_only_scan_q1 +2 ann_index_only_scan_q1 +9 ann_index_only_scan_q1 +8 ann_index_only_scan_q1 +4 ann_index_only_scan_q1 + +-- !q2 -- +0 ann_index_only_scan_q2 81.69191 +5 ann_index_only_scan_q2 90.8576 +6 ann_index_only_scan_q2 111.234 +2 ann_index_only_scan_q2 116.7573 +9 ann_index_only_scan_q2 122.1707 +8 ann_index_only_scan_q2 130.5337 +4 ann_index_only_scan_q2 136.0021 + diff --git a/regression-test/suites/ann_index_p0/ann_index_only_scan.groovy b/regression-test/suites/ann_index_p0/ann_index_only_scan.groovy new file mode 100644 index 00000000000000..756070f9f5fc93 --- /dev/null +++ b/regression-test/suites/ann_index_p0/ann_index_only_scan.groovy @@ -0,0 +1,450 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +import groovy.json.JsonSlurper + + +def getProfileList = { + def dst = 'http://' + context.config.feHttpAddress + def conn = new URL(dst + "/rest/v1/query_profile").openConnection() + conn.setRequestMethod("GET") + def encoding = Base64.getEncoder().encodeToString((context.config.feHttpUser + ":" + + (context.config.feHttpPassword == null ? "" : context.config.feHttpPassword)).getBytes("UTF-8")) + conn.setRequestProperty("Authorization", "Basic ${encoding}") + return conn.getInputStream().getText() +} + +def getProfile = { id -> + def dst = 'http://' + context.config.feHttpAddress + def conn = new URL(dst + "/api/profile/text/?query_id=$id").openConnection() + conn.setRequestMethod("GET") + def encoding = Base64.getEncoder().encodeToString((context.config.feHttpUser + ":" + + (context.config.feHttpPassword == null ? "" : context.config.feHttpPassword)).getBytes("UTF-8")) + conn.setRequestProperty("Authorization", "Basic ${encoding}") + return conn.getInputStream().getText() +} + +suite("ann_index_only_scan") { + sql "drop table if exists ann_index_only_scan" + sql "unset variable all;" + sql "set profile_level=2;" + sql "set enable_profile=true;" + sql "set experimental_enable_virtual_slot_for_cse=true;" + // disable lazy materialization since it will break index-only scan. 
+ sql "set experimental_topn_lazy_materialization_threshold=0;" + sql "set parallel_pipeline_task_num=1;" + sql "set enable_sql_cache=false;" + + sql """ + create table ann_index_only_scan ( + id int not null, + embedding array not null, + comment String not null, + value int null, + INDEX idx_comment(`comment`) USING INVERTED PROPERTIES("parser" = "english") COMMENT 'inverted index for comment', + INDEX ann_embedding(`embedding`) USING ANN PROPERTIES("index_type"="hnsw","metric_type"="l2_distance","dim"="8") + ) duplicate key (`id`) + distributed by hash(`id`) buckets 1 + properties("replication_num"="1"); + """ + + sql """ + INSERT INTO ann_index_only_scan (id, embedding, comment, value) VALUES + (0, [39.906116, 10.495334, 54.08394, 88.67262, 55.243687, 10.162686, 36.335983, 38.684258], "This example illustrates how subtle differences can influence perception. It's more about interpretation than right or wrong.", 100), + (1, [62.759315, 97.15586, 25.832521, 39.604908, 88.76715, 72.64085, 9.688437, 17.721428], "Thanks for all the comments, good and bad. They help us refine our test. Keep in mind that we're attempting to figure you out in 40 pairs of pictures. We did this so that lots of people could take it, just to introduce the idea.

A real test would have more like 200 pairs, which is what the YC founders took when we assessed their attributes in the first place.", 101), + (2, [15.447449, 59.7771, 65.54516, 12.973712, 99.685135, 72.080734, 85.71118, 99.35976], "At a glance, these might seem obvious, but there’s nuance in every choice. Don’t rush.", 102), + (3, [72.26747, 46.42257, 32.368374, 80.50209, 5.777631, 98.803314, 7.0915947, 68.62693], "We're testing how consistent your judgments are over a range of visual impressions. There's no single 'correct' answer.", 103), + (4, [22.098177, 74.10027, 63.634556, 4.710955, 12.405106, 79.39356, 63.014366, 68.67834], "Some pairs are meant to be tricky. Your intuition is part of what we're analyzing.", 104), + (5, [27.53003, 72.1106, 50.891026, 38.459953, 68.30715, 20.610682, 94.806274, 45.181377], "This data will help us identify patterns in how people perceive attributes such as trustworthiness or confidence.", 105), + (6, [77.73215, 64.42907, 71.50025, 43.85641, 94.42648, 50.04773, 65.12575, 68.58207], "Sometimes people see entirely different things in the same image. That's part of the exploration.", 106), + (7, [2.1537063, 82.667885, 16.171143, 71.126656, 5.335274, 40.286068, 11.943586, 3.69409], "Don't worry if you’re unsure. 
The ambiguity is intentional — that’s what makes this interesting.", 107), + (8, [54.435013, 56.800594, 59.335514, 55.829235, 85.46627, 33.388138, 11.076194, 20.480877], "Your reactions help us understand which features people subconsciously favor or avoid.", 108), + (9, [76.197945, 60.623528, 84.229805, 31.652937, 71.82595, 48.04684, 71.29212, 30.282396], "This task isn’t about right answers, but about consistency in your judgments over time.", 109); + """ + + // Fetch profile text by token with small retries for robustness + def getProfileWithToken = { token -> + String profileId = "" + int attempts = 0 + while (attempts < 10 && (profileId == null || profileId == "")) { + List profileData = new JsonSlurper().parseText(getProfileList()).data.rows + for (def profileItem in profileData) { + if (profileItem["Sql Statement"].toString().contains(token)) { + profileId = profileItem["Profile ID"].toString() + break + } + } + if (profileId == null || profileId == "") { + Thread.sleep(300) + } + attempts++ + } + assertTrue(profileId != null && profileId != "") + // ensure profile text is fully ready + Thread.sleep(800) + return getProfile(profileId).toString() + } + + def extractScanBytesValue = { String profileText -> + // Example line: "- ScanBytes: 80.00 B" + def lines = profileText.split("\n") + for (def line : lines) { + if (line.contains("ScanBytes:")) { + def m = (line =~ /ScanBytes:\s*([0-9]+(?:\.[0-9]+)?)\s*[A-Za-z]+/) + if (m.find()) { + return m.group(1) + } + } + } + return null + } + + // Helper to execute two query shapes (one plain, one with embedding) and return their ScanBytes values + def runAndGetScanBytesPair = { + def t1 = UUID.randomUUID().toString() + sql """ + select id, "${t1}", + l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) as dist + from ann_index_only_scan + order by dist + limit 7; + """ + def t2 = 
UUID.randomUUID().toString() + sql """ + select id, "${t2}", embedding, + l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) as dist + from ann_index_only_scan + order by dist + limit 7; + """ + def pA = getProfileWithToken(t1) + def pB = getProfileWithToken(t2) + def sA = extractScanBytesValue(pA) + def sB = extractScanBytesValue(pB) + assertTrue(sA != null && sB != null) + return [sA, sB] + } + + // enable index-only read path + sql "set enable_no_need_read_data_opt=true;" + def pair1 = runAndGetScanBytesPair() + logger.info("ScanBytes enabled: q1=${pair1[0]}, q2=${pair1[1]}") + // ScanBytes of q1 and q2 should differ: q2 selects the embedding column, while q1 should not need to read it. + assertTrue(pair1[0] != pair1[1]) + + // disable index-only read path, expect q1 and q2 to report the same ScanBytes + sql "set enable_no_need_read_data_opt=false;" + def pair2 = runAndGetScanBytesPair() + logger.info("ScanBytes disabled: q1=${pair2[0]}, q2=${pair2[1]}") + assertTrue(pair2[0] == pair2[1]) + + // 1) ANN range search: compare with/without selecting distance + sql "set enable_no_need_read_data_opt=true;" + sql "set experimental_enable_virtual_slot_for_cse=true;" + def tR1 = UUID.randomUUID().toString() + sql """ + select id, "${tR1}" from ann_index_only_scan + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 105.66439056396484 + order by id + limit 20; + """ + def tR2 = UUID.randomUUID().toString() + sql """ + select id, "${tR2}", + l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) as dist + from ann_index_only_scan + where l2_distance_approximate(embedding, 
[26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 105.66439056396484 + order by id + limit 20; + """ + def pR1 = getProfileWithToken(tR1) + def pR2 = getProfileWithToken(tR2) + def sR1 = extractScanBytesValue(pR1) + def sR2 = extractScanBytesValue(pR2) + logger.info("ScanBytes range enabled: q1=${sR1}, q2=${sR2}") + assertTrue(sR1 == sR2) + + tR1 = UUID.randomUUID().toString() + tR2 = UUID.randomUUID().toString() + // No virtual slot. So result distance is not needed by any one, even if it is calculated by index. + // So we do not need to read embedding column. + sql """ + select id, "${tR1}" from ann_index_only_scan + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) > 105.66439056396484 + order by id + limit 20; + """ + // if condition is not lt_or_le, index will only return rowid without distance value + // so we still need to read embedding column. 
+ sql """ + select id, "${tR2}", + l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) as dist + from ann_index_only_scan + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) > 105.66439056396484 + order by id + limit 20; + """ + pR1 = getProfileWithToken(tR1) + pR2 = getProfileWithToken(tR2) + sR1 = extractScanBytesValue(pR1) + sR2 = extractScanBytesValue(pR2) + logger.info("ScanBytes range enabled (neg): q1=${sR1}, q2=${sR2}") + assertTrue(sR1 != sR2) + + // 2) ANN with inverted index together: add comment MATCH_ANY filter + def tRI1 = UUID.randomUUID().toString() + sql """ + select id, "${tRI1}" from ann_index_only_scan + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 150.0 + and comment match_any 'people' + order by id + limit 20; + """ + def tRI2 = UUID.randomUUID().toString() + sql """ + select id, "${tRI2}", + l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) as dist + from ann_index_only_scan + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 150.0 + and comment match_any 'people' + order by id + limit 20; + """ + def pRI1 = getProfileWithToken(tRI1) + def pRI2 = getProfileWithToken(tRI2) + def sRI1 = extractScanBytesValue(pRI1) + def sRI2 = extractScanBytesValue(pRI2) + logger.info("ScanBytes range+inverted enabled: q1=${sRI1}, q2=${sRI2}") + assertTrue(sRI1 == sRI2) + // 
Negative: project non-index column to force base read + def tRIN1 = UUID.randomUUID().toString() + sql """ + select id, "${tRIN1}" from ann_index_only_scan + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 150.0 + and comment match_any 'people' + order by id + limit 20; + """ + def tRIN2 = UUID.randomUUID().toString() + sql """ + select id, "${tRIN2}", comment + from ann_index_only_scan + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 150.0 + and comment match_any 'people' + order by id + limit 20; + """ + def pRIN1 = getProfileWithToken(tRIN1) + def pRIN2 = getProfileWithToken(tRIN2) + def sRIN1 = extractScanBytesValue(pRIN1) + def sRIN2 = extractScanBytesValue(pRIN2) + logger.info("ScanBytes range+inverted neg: q1=${sRIN1}, q2=${sRIN2}") + assertTrue(sRIN1 != sRIN2) + + // 3) Range + TopN simultaneously + def tRT1 = UUID.randomUUID().toString() + sql """ + select id, "${tRT1}" from ann_index_only_scan + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 200.0 + order by l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + limit 5; + """ + def tRT2 = UUID.randomUUID().toString() + sql """ + select id, "${tRT2}", + l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) as dist + from ann_index_only_scan + where l2_distance_approximate(embedding, 
[26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 200.0 + order by dist + limit 5; + """ + def pRT1 = getProfileWithToken(tRT1) + def pRT2 = getProfileWithToken(tRT2) + def sRT1 = extractScanBytesValue(pRT1) + def sRT2 = extractScanBytesValue(pRT2) + logger.info("ScanBytes range+topn enabled: q1=${sRT1}, q2=${sRT2}") + assertTrue(sRT1 == sRT2) + + // 4) Ensure no index: same queries should not error on a table without ANN index + sql """ + DROP TABLE IF EXISTS ann_index_only_scan_no_ann; + """ + sql """ + CREATE TABLE ann_index_only_scan_no_ann ( + id int not null, + embedding array not null, + comment String not null, + value int null + ) duplicate key (`id`) + distributed by hash(`id`) buckets 1 + properties("replication_num"="1"); + """ + sql """ + INSERT INTO ann_index_only_scan_no_ann (id, embedding, comment, value) VALUES + (0, [39.906116, 10.495334, 54.08394, 88.67262, 55.243687, 10.162686, 36.335983, 38.684258], "A", 100), + (1, [62.759315, 97.15586, 25.832521, 39.604908, 88.76715, 72.64085, 9.688437, 17.721428], "B", 101), + (2, [15.447449, 59.7771, 65.54516, 12.973712, 99.685135, 72.080734, 85.71118, 99.35976], "C", 102); + """ + // Just execute; if there is no error, it's fine + sql """ + select id from ann_index_only_scan_no_ann + order by l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) limit 2; + """ + sql """ + select id from ann_index_only_scan_no_ann + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 999.0 order by id; + """ + + // 5) TopN + IndexFilter (inverted): index-only should apply when not projecting non-index columns + sql "set enable_no_need_read_data_opt=true;" + 
def tTI1 = UUID.randomUUID().toString() + sql """ + select id, "${tTI1}" + from ann_index_only_scan + where comment match_any 'people' + order by l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + limit 5; + """ + def tTI2 = UUID.randomUUID().toString() + sql """ + select id, "${tTI2}", comment + from ann_index_only_scan + where comment match_any 'people' + order by l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + limit 5; + """ + def pTI1 = getProfileWithToken(tTI1) + def pTI2 = getProfileWithToken(tTI2) + def sTI1 = extractScanBytesValue(pTI1) + def sTI2 = extractScanBytesValue(pTI2) + logger.info("ScanBytes topn+inverted: q1=${sTI1}, q2=${sTI2}") + assertTrue(sTI1 != sTI2) + + // 6) TopN + Range + IndexFilter + def tTRI1 = UUID.randomUUID().toString() + sql """ + select id, "${tTRI1}" + from ann_index_only_scan + where comment match_any 'people' + and l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 200.0 + order by l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + limit 5; + """ + def tTRI2 = UUID.randomUUID().toString() + sql """ + select id, "${tTRI2}", comment + from ann_index_only_scan + where comment match_any 'people' + and l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) < 200.0 + order by l2_distance_approximate(embedding, 
[26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + limit 5; + """ + def pTRI1 = getProfileWithToken(tTRI1) + def pTRI2 = getProfileWithToken(tTRI2) + def sTRI1 = extractScanBytesValue(pTRI1) + def sTRI2 = extractScanBytesValue(pTRI2) + logger.info("ScanBytes topn+range+inverted: q1=${sTRI1}, q2=${sTRI2}") + assertTrue(sTRI1 != sTRI2) + + // 7) Range + proj + no-dist-from-index (gt/ ge): toggling the opt should have no effect + sql "set enable_no_need_read_data_opt=true;" + def tRN1 = UUID.randomUUID().toString() + sql """ + select id, "${tRN1}", + l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) as dist + from ann_index_only_scan + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) > 100.0 + order by id + limit 20; + """ + def pRN1 = getProfileWithToken(tRN1) + def sRN1 = extractScanBytesValue(pRN1) + sql "set enable_no_need_read_data_opt=false;" + def tRN2 = UUID.randomUUID().toString() + sql """ + select id, "${tRN2}", + l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) as dist + from ann_index_only_scan + where l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) > 100.0 + order by id + limit 20; + """ + def pRN2 = getProfileWithToken(tRN2) + def sRN2 = extractScanBytesValue(pRN2) + logger.info("ScanBytes range(gt)+proj opt-toggle: on=${sRN1}, off=${sRN2}") + assertTrue(sRN1 == sRN2) + + // 8) TopN + Range + CommonFilter (array_size on 
embedding): opt toggle should have no effect + sql "set enable_no_need_read_data_opt=true;" + def tCF1 = UUID.randomUUID().toString() + sql """ + select id, "${tCF1}" + from ann_index_only_scan + where array_size(embedding) > 5 + and l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) > 100.0 + order by l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + limit 5; + """ + def pCF1 = getProfileWithToken(tCF1) + def sCF1 = extractScanBytesValue(pCF1) + sql "set enable_no_need_read_data_opt=false;" + def tCF2 = UUID.randomUUID().toString() + sql """ + select id, "${tCF2}" + from ann_index_only_scan + where array_size(embedding) > 5 + and l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) > 100.0 + order by l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + limit 5; + """ + def pCF2 = getProfileWithToken(tCF2) + def sCF2 = extractScanBytesValue(pCF2) + logger.info("ScanBytes topn+range+common-filter opt-toggle: on=${sCF1}, off=${sCF2}") + assertTrue(sCF1 == sCF2) + + // 9) CSE: multiple uses of distance in predicates should still allow index-only + sql "set enable_no_need_read_data_opt=true;" + sql "set experimental_enable_virtual_slot_for_cse=true;" + def tCSE1 = UUID.randomUUID().toString() + sql """ + select id, "${tCSE1}" + from ann_index_only_scan + where abs(l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + 
10) > 10 + and l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) <= 150 + order by l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + limit 5; + """ + def tCSE2 = UUID.randomUUID().toString() + sql """ + select id, "${tCSE2}", embedding + from ann_index_only_scan + where abs(l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + 10) > 10 + and l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) <= 150 + order by l2_distance_approximate(embedding, [26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]) + limit 5; + """ + def pCSE1 = getProfileWithToken(tCSE1) + def pCSE2 = getProfileWithToken(tCSE2) + def sCSE1 = extractScanBytesValue(pCSE1) + def sCSE2 = extractScanBytesValue(pCSE2) + logger.info("ScanBytes CSE (no-embed vs embed): q1=${sCSE1}, q2=${sCSE2}") + // NOTE: currently, CSE with virtual slot still needs to read embedding column when not projecting it. + // Since we do not check if src column of some expr has been materialized. + // For example, although dist column has been calculated by l2_distance_approximate < 150 in the form of virtual slot, + // but when evaluating abs(dist + 10) > 10, we still need to read embedding column, even though dist will not be calculated again. 
+ assertTrue(sCSE1 == sCSE2) +} \ No newline at end of file diff --git a/regression-test/suites/ann_index_p0/ann_index_only_scan_distance_expr.groovy b/regression-test/suites/ann_index_p0/ann_index_only_scan_distance_expr.groovy new file mode 100644 index 00000000000000..d87ce69a45d99a --- /dev/null +++ b/regression-test/suites/ann_index_p0/ann_index_only_scan_distance_expr.groovy @@ -0,0 +1,206 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import groovy.json.JsonSlurper + +// Focus: whether distance is projected or used in expressions/predicates and its impact on index-only scan. +// Strategy: Compare ScanBytes with and without projecting/using distance under different predicate directions. + +def getProfileList = { + def dst = 'http://' + context.config.feHttpAddress + def conn = new URL(dst + "/rest/v1/query_profile").openConnection() + conn.setRequestMethod("GET") + def encoding = Base64.getEncoder().encodeToString((context.config.feHttpUser + ":" + + (context.config.feHttpPassword == null ? 
"" : context.config.feHttpPassword)).getBytes("UTF-8")) + conn.setRequestProperty("Authorization", "Basic ${encoding}") + return conn.getInputStream().getText() +} + +def getProfile = { id -> + def dst = 'http://' + context.config.feHttpAddress + def conn = new URL(dst + "/api/profile/text/?query_id=$id").openConnection() + conn.setRequestMethod("GET") + def encoding = Base64.getEncoder().encodeToString((context.config.feHttpUser + ":" + + (context.config.feHttpPassword == null ? "" : context.config.feHttpPassword)).getBytes("UTF-8")) + conn.setRequestProperty("Authorization", "Basic ${encoding}") + return conn.getInputStream().getText() +} + +// Note: define getProfileWithToken inside suite to use suite-level assertTrue + +def extractScanBytesValue = { String profileText -> + def lines = profileText.split("\n") + for (def line : lines) { + if (line.contains("ScanBytes:")) { + def m = (line =~ /ScanBytes:\s*([0-9]+(?:\.[0-9]+)?)\s*[A-Za-z]+/) + if (m.find()) { + return m.group(1) + } + } + } + return null +} + +suite("ann_index_only_scan_distance_expr") { + def getProfileWithToken = { token -> + String profileId = "" + int attempts = 0 + while (attempts < 10 && (profileId == null || profileId == "")) { + List profileData = new JsonSlurper().parseText(getProfileList()).data.rows + for (def profileItem in profileData) { + if (profileItem["Sql Statement"].toString().contains(token)) { + profileId = profileItem["Profile ID"].toString() + break + } + } + if (profileId == null || profileId == "") { + Thread.sleep(300) + } + attempts++ + } + assertTrue(profileId != null && profileId != "") + Thread.sleep(800) + return getProfile(profileId).toString() + } + // session vars + sql "unset variable all;" + sql "set profile_level=2;" + sql "set enable_profile=true;" + sql "set experimental_topn_lazy_materialization_threshold=0;" + sql "set experimental_enable_virtual_slot_for_cse=true;" + sql "set enable_no_need_read_data_opt=true;" + sql "set parallel_pipeline_task_num=1;" // 
make execution more deterministic for test + + sql "drop table if exists ann_expr_l2" + sql """ + create table ann_expr_l2 ( + id int not null, + embedding array not null, + txt string not null, + index ann_embedding(`embedding`) using ann properties( + "index_type"="hnsw", + "metric_type"="l2_distance", + "dim"="8" + ) + ) duplicate key(id) + distributed by hash(id) buckets 1 + properties("replication_num"="1"); + """ + + sql """ + insert into ann_expr_l2 values + (0, [39.906116, 10.495334, 54.08394, 88.67262, 55.243687, 10.162686, 36.335983, 38.684258], 'A'), + (1, [62.759315, 97.15586, 25.832521, 39.604908, 88.76715, 72.64085, 9.688437, 17.721428], 'B'), + (2, [15.447449, 59.7771, 65.54516, 12.973712, 99.685135, 72.080734, 85.71118, 99.35976], 'C'), + (3, [72.26747, 46.42257, 32.368374, 80.50209, 5.777631, 98.803314, 7.0915947, 68.62693], 'D'), + (4, [22.098177, 74.10027, 63.634556, 4.710955, 12.405106, 79.39356, 63.014366, 68.67834], 'E'), + (5, [27.53003, 72.1106, 50.891026, 38.459953, 68.30715, 20.610682, 94.806274, 45.181377], 'F'), + (6, [77.73215, 64.42907, 71.50025, 43.85641, 94.42648, 50.04773, 65.12575, 68.58207], 'G'), + (7, [2.1537063, 82.667885, 16.171143, 71.126656, 5.335274, 40.286068, 11.943586, 3.69409], 'H'), + (8, [54.435013, 56.800594, 59.335514, 55.829235, 85.46627, 33.388138, 11.076194, 20.480877], 'I'), + (9, [76.197945, 60.623528, 84.229805, 31.652937, 71.82595, 48.04684, 71.29212, 30.282396], 'J'); + """ + + def v = "[26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]" + + // Case 1: Range < with distance used in projection arithmetic -> still index-only + def t1 = UUID.randomUUID().toString() + sql """ + select id, "${t1}" from ann_expr_l2 + where l2_distance_approximate(embedding, ${v}) < 170.0 + order by id limit 20; + """ + def t2 = UUID.randomUUID().toString() + sql """ + select id, "${t2}", (l2_distance_approximate(embedding, ${v}) * 2.0) 
as d2 + from ann_expr_l2 + where l2_distance_approximate(embedding, ${v}) < 170.0 + order by id limit 20; + """ + def p1 = getProfileWithToken(t1) + def p2 = getProfileWithToken(t2) + def s1 = extractScanBytesValue(p1) + def s2 = extractScanBytesValue(p2) + logger.info("Expr L2 < threshold: no-proj=${s1}, proj(d*2)=${s2}") + assertTrue(s1 == s2) + + // Case 2: Range > with distance used in projection arithmetic -> not index-only (needs base read) + def t3 = UUID.randomUUID().toString() + sql """ + select id, "${t3}" from ann_expr_l2 + where l2_distance_approximate(embedding, ${v}) > 120.0 + order by id limit 20; + """ + def t4 = UUID.randomUUID().toString() + sql """ + select id, "${t4}", (l2_distance_approximate(embedding, ${v}) + 1.0) as d2 + from ann_expr_l2 + where l2_distance_approximate(embedding, ${v}) > 120.0 + order by id limit 20; + """ + def p3 = getProfileWithToken(t3) + def p4 = getProfileWithToken(t4) + def s3 = extractScanBytesValue(p3) + def s4 = extractScanBytesValue(p4) + logger.info("Expr L2 > threshold: no-proj=${s3}, proj(d+1)=${s4}") + assertTrue(s3 != s4) + + // Case 3: Distance value reused in another predicate expression; still index-only for < + def t5 = UUID.randomUUID().toString() + sql """ + select id, "${t5}" from ann_expr_l2 + where l2_distance_approximate(embedding, ${v}) < 170.0 + and (l2_distance_approximate(embedding, ${v}) + 0.5) < 200.0 + order by id limit 20; + """ + def t6 = UUID.randomUUID().toString() + sql """ + select id, "${t6}", l2_distance_approximate(embedding, ${v}) as dist + from ann_expr_l2 + where l2_distance_approximate(embedding, ${v}) < 170.0 + and (l2_distance_approximate(embedding, ${v}) + 0.5) < 200.0 + order by id limit 20; + """ + def p5 = getProfileWithToken(t5) + def p6 = getProfileWithToken(t6) + def s5 = extractScanBytesValue(p5) + def s6 = extractScanBytesValue(p6) + logger.info("Expr L2 < threshold with extra predicate: no-proj=${s5}, with-dist=${s6}") + assertTrue(s5 == s6) + + // Case 4: TopN by 
distance with distance used in projection -> index-only + def t7 = UUID.randomUUID().toString() + sql """ + select id, "${t7}" + from ann_expr_l2 + order by l2_distance_approximate(embedding, ${v}) + limit 5; + """ + def t8 = UUID.randomUUID().toString() + sql """ + select id, "${t8}", (l2_distance_approximate(embedding, ${v}) / 2.0) as d2 + from ann_expr_l2 + order by l2_distance_approximate(embedding, ${v}) + limit 5; + """ + def p7 = getProfileWithToken(t7) + def p8 = getProfileWithToken(t8) + def s7 = extractScanBytesValue(p7) + def s8 = extractScanBytesValue(p8) + logger.info("TopN L2 asc: no-proj=${s7}, proj(d/2)=${s8}") + assertTrue(s7 == s8) +} diff --git a/regression-test/suites/ann_index_p0/ann_index_only_scan_metric_direction.groovy b/regression-test/suites/ann_index_p0/ann_index_only_scan_metric_direction.groovy new file mode 100644 index 00000000000000..02ebd673115a1f --- /dev/null +++ b/regression-test/suites/ann_index_p0/ann_index_only_scan_metric_direction.groovy @@ -0,0 +1,228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +import groovy.json.JsonSlurper + +// Focus: different metrics (l2 vs inner_product) and predicate directions (< vs >) +// Expectation: +// - l2_distance: index returns distance on < (or <=) range; not on > (or >=) +// - inner_product: index returns distance on > (or >=) range; not on < (or <=) +// We infer index-only read by comparing ScanBytes with and without selecting/using distance. + +def getProfileList = { + def dst = 'http://' + context.config.feHttpAddress + def conn = new URL(dst + "/rest/v1/query_profile").openConnection() + conn.setRequestMethod("GET") + def encoding = Base64.getEncoder().encodeToString((context.config.feHttpUser + ":" + + (context.config.feHttpPassword == null ? "" : context.config.feHttpPassword)).getBytes("UTF-8")) + conn.setRequestProperty("Authorization", "Basic ${encoding}") + return conn.getInputStream().getText() +} + +def getProfile = { id -> + def dst = 'http://' + context.config.feHttpAddress + def conn = new URL(dst + "/api/profile/text/?query_id=$id").openConnection() + conn.setRequestMethod("GET") + def encoding = Base64.getEncoder().encodeToString((context.config.feHttpUser + ":" + + (context.config.feHttpPassword == null ? 
"" : context.config.feHttpPassword)).getBytes("UTF-8")) + conn.setRequestProperty("Authorization", "Basic ${encoding}") + return conn.getInputStream().getText() +} + +// Note: define getProfileWithToken inside suite to use suite-level assertTrue + +def extractScanBytesValue = { String profileText -> + def lines = profileText.split("\n") + for (def line : lines) { + if (line.contains("ScanBytes:")) { + def m = (line =~ /ScanBytes:\s*([0-9]+(?:\.[0-9]+)?)\s*[A-Za-z]+/) + if (m.find()) { + return m.group(1) + } + } + } + return null +} + +suite("ann_index_only_scan_metric_direction") { + def getProfileWithToken = { token -> + String profileId = "" + int attempts = 0 + while (attempts < 10 && (profileId == null || profileId == "")) { + List profileData = new JsonSlurper().parseText(getProfileList()).data.rows + for (def profileItem in profileData) { + if (profileItem["Sql Statement"].toString().contains(token)) { + profileId = profileItem["Profile ID"].toString() + break + } + } + if (profileId == null || profileId == "") { + Thread.sleep(300) + } + attempts++ + } + assertTrue(profileId != null && profileId != "") + Thread.sleep(800) + return getProfile(profileId).toString() + } + // session vars + sql "unset variable all;" + sql "set profile_level=2;" + sql "set enable_profile=true;" + sql "set experimental_topn_lazy_materialization_threshold=0;" + sql "set experimental_enable_virtual_slot_for_cse=true;" + sql "set enable_no_need_read_data_opt=true;" + sql "set parallel_pipeline_task_num=1;" // make execution more deterministic for test + + // l2 table + sql "drop table if exists ann_md_l2" + sql """ + create table ann_md_l2 ( + id int not null, + embedding array not null, + comment string not null, + value int null, + index ann_embedding(`embedding`) using ann properties( + "index_type"="hnsw", + "metric_type"="l2_distance", + "dim"="8" + ) + ) duplicate key(id) + distributed by hash(id) buckets 1 + properties("replication_num"="1"); + """ + + // inner product table 
+ sql "drop table if exists ann_md_ip" + sql """ + create table ann_md_ip ( + id int not null, + embedding array not null, + comment string not null, + value int null, + index ann_embedding(`embedding`) using ann properties( + "index_type"="hnsw", + "metric_type"="inner_product", + "dim"="8" + ) + ) duplicate key(id) + distributed by hash(id) buckets 1 + properties("replication_num"="1"); + """ + + def rows = """ + (0, [39.906116, 10.495334, 54.08394, 88.67262, 55.243687, 10.162686, 36.335983, 38.684258], "A", 100), + (1, [62.759315, 97.15586, 25.832521, 39.604908, 88.76715, 72.64085, 9.688437, 17.721428], "B", 101), + (2, [15.447449, 59.7771, 65.54516, 12.973712, 99.685135, 72.080734, 85.71118, 99.35976], "C", 102), + (3, [72.26747, 46.42257, 32.368374, 80.50209, 5.777631, 98.803314, 7.0915947, 68.62693], "D", 103), + (4, [22.098177, 74.10027, 63.634556, 4.710955, 12.405106, 79.39356, 63.014366, 68.67834], "E", 104), + (5, [27.53003, 72.1106, 50.891026, 38.459953, 68.30715, 20.610682, 94.806274, 45.181377], "F", 105), + (6, [77.73215, 64.42907, 71.50025, 43.85641, 94.42648, 50.04773, 65.12575, 68.58207], "G", 106), + (7, [2.1537063, 82.667885, 16.171143, 71.126656, 5.335274, 40.286068, 11.943586, 3.69409], "H", 107), + (8, [54.435013, 56.800594, 59.335514, 55.829235, 85.46627, 33.388138, 11.076194, 20.480877], "I", 108), + (9, [76.197945, 60.623528, 84.229805, 31.652937, 71.82595, 48.04684, 71.29212, 30.282396], "J", 109) + """ + sql "insert into ann_md_l2 values ${rows};" + sql "insert into ann_md_ip values ${rows};" + + // Common probe vector + def v = "[26.360261917114258,7.05784273147583,32.361351013183594,86.39714050292969,58.79527282714844,27.189321517944336,99.38946533203125,80.19270324707031]" + + // L2: < threshold -> expect index returns distance; projecting distance should NOT increase ScanBytes + def t1 = UUID.randomUUID().toString() + sql """ + select id, "${t1}" from ann_md_l2 + where l2_distance_approximate(embedding, ${v}) < 160.0 + order by id 
limit 20; + """ + def t2 = UUID.randomUUID().toString() + sql """ + select id, "${t2}", l2_distance_approximate(embedding, ${v}) as dist + from ann_md_l2 + where l2_distance_approximate(embedding, ${v}) < 160.0 + order by id limit 20; + """ + def p1 = getProfileWithToken(t1) + def p2 = getProfileWithToken(t2) + def s1 = extractScanBytesValue(p1) + def s2 = extractScanBytesValue(p2) + logger.info("L2 < threshold ScanBytes: no-proj=${s1}, with-dist=${s2}") + assertTrue(s1 == s2) + + // L2: > threshold -> index doesn't return distance; projecting distance SHOULD increase ScanBytes + def t3 = UUID.randomUUID().toString() + sql """ + select id, "${t3}" from ann_md_l2 + where l2_distance_approximate(embedding, ${v}) > 120.0 + order by id limit 20; + """ + def t4 = UUID.randomUUID().toString() + sql """ + select id, "${t4}", l2_distance_approximate(embedding, ${v}) as dist + from ann_md_l2 + where l2_distance_approximate(embedding, ${v}) > 120.0 + order by id limit 20; + """ + def p3 = getProfileWithToken(t3) + def p4 = getProfileWithToken(t4) + def s3 = extractScanBytesValue(p3) + def s4 = extractScanBytesValue(p4) + logger.info("L2 > threshold ScanBytes: no-proj=${s3}, with-dist=${s4}") + assertTrue(s3 != s4) + + // Inner Product: > threshold -> expect index returns distance + def t5 = UUID.randomUUID().toString() + sql """ + select id, "${t5}" from ann_md_ip + where inner_product_approximate(embedding, ${v}) > 1000.0 + order by id limit 20; + """ + def t6 = UUID.randomUUID().toString() + sql """ + select id, "${t6}", inner_product_approximate(embedding, ${v}) as score + from ann_md_ip + where inner_product_approximate(embedding, ${v}) > 1000.0 + order by id limit 20; + """ + def p5 = getProfileWithToken(t5) + def p6 = getProfileWithToken(t6) + def s5 = extractScanBytesValue(p5) + def s6 = extractScanBytesValue(p6) + logger.info("IP > threshold ScanBytes: no-proj=${s5}, with-score=${s6}") + assertTrue(s5 == s6) + + // Inner Product: < threshold -> expect index doesn't 
return distance; projecting distance increases ScanBytes + def t7 = UUID.randomUUID().toString() + sql """ + select id, "${t7}" from ann_md_ip + where inner_product_approximate(embedding, ${v}) < 16175.99 + order by id limit 20; + """ + def t8 = UUID.randomUUID().toString() + sql """ + select id, "${t8}", inner_product_approximate(embedding, ${v}) as score + from ann_md_ip + where inner_product_approximate(embedding, ${v}) < 16175.99 + order by id limit 20; + """ + def p7 = getProfileWithToken(t7) + def p8 = getProfileWithToken(t8) + def s7 = extractScanBytesValue(p7) + def s8 = extractScanBytesValue(p8) + logger.info("IP < threshold ScanBytes: no-proj=${s7}, with-score=${s8}") + assertTrue(s7 != s8) +}