Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
  • Loading branch information
BubbleCal committed Dec 18, 2024
1 parent 90b2080 commit 97b8d8a
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,7 @@ impl Scanner {
_ => unreachable!(),
};

// refine is always required for multivec
let mut knn_node = if q.refine_factor.is_some() || is_multivec {
let with_vector = self.dataset.schema().project(&[&q.column])?;
let knn_node_with_vector =
Expand Down Expand Up @@ -2010,14 +2011,18 @@ impl Scanner {
.map(|query_vec| {
let mut new_query = q.clone();
new_query.key = query_vec;
if num_queries > 1 {
new_query.k = q.k / 2;
}
new_query
});
let mut ann_nodes = Vec::with_capacity(new_queries.len());
let prefilter_source = self.prefilter_source(filter_plan).await?;
for query in new_queries {
ann_nodes.push(self.ann(&query, index, filter_plan).await?);
let ann_node = new_knn_exec(
self.dataset.clone(),
index,
&query,
prefilter_source.clone(),
)?;
ann_nodes.push(ann_node);
}
let ann_node = Arc::new(UnionExec::new(ann_nodes));
let ann_node = Arc::new(RepartitionExec::try_new(
Expand All @@ -2030,12 +2035,22 @@ impl Scanner {
expressions::col(ROW_ID, schema.as_ref())?,
ROW_ID.to_string(),
)];
let ann_node = Arc::new(AggregateExec::try_new(
// for now multivector is always with cosine distance so here convert the distance to `1 - distance`,
let ann_node: Arc<dyn ExecutionPlan> = Arc::new(AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new_single(group_expr),
vec![AggregateExprBuilder::new(
functions_aggregate::min_max::min_udaf(),
vec![expressions::col(DIST_COL, &schema)?],
functions_aggregate::sum::sum_udaf(),
vec![expressions::binary(
expressions::lit(1.0),
datafusion_expr::Operator::Minus,
expressions::cast(
expressions::col(DIST_COL, &schema)?,
&schema,
DataType::Float64,
)?,
&schema,
)?],
)
.schema(schema.clone())
.alias(DIST_COL)
Expand All @@ -2044,6 +2059,19 @@ impl Scanner {
ann_node,
schema,
)?);

let sort_expr = PhysicalSortExpr {
expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?,
options: SortOptions {
descending: true,
nulls_first: false,
},
};
let ann_node = Arc::new(
SortExec::new(vec![sort_expr], ann_node)
.with_fetch(Some(q.k * q.refine_factor.unwrap_or(1) as usize)),
);

Ok(ann_node)
}

Expand Down

0 comments on commit 97b8d8a

Please sign in to comment.