refactor: flat search to use datafusion top k (#2535)
Closes #1324
eddyxu authored Jun 27, 2024
1 parent b8a92eb commit f8c5f4d
Showing 6 changed files with 252 additions and 377 deletions.
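In short: `flat.rs` previously implemented the whole flat (brute-force) search itself — compute a distance for every row of each batch, sort and truncate to k per batch, then concatenate and sort again across batches. After this commit it only exposes `compute_distance`, which appends a `DIST_COL` distance column, and the k-nearest selection is handed to DataFusion's top-k operator (a sort with a fetch limit) in the surrounding query plan; that is why the Python tests below now look for a `KNNVectorDistance` plan node instead of `KNNFlat`. Here is a minimal, self-contained sketch of that top-k step using only the arrow crates; `top_k_by_distance` and the small driver are hypothetical names for illustration, not code from this commit.

```rust
use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::{ArrayRef, Float32Array, Int32Array, RecordBatch, StructArray};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::{ArrowError, DataType, Field, Schema, SortOptions};
use arrow_select::take::take;

/// Keep the `k` rows of `batch` with the smallest values in `distances`.
fn top_k_by_distance(
    batch: &RecordBatch,
    distances: &ArrayRef,
    k: usize,
) -> Result<RecordBatch, ArrowError> {
    // Ascending sort, nulls last, fetch only k indices -- the same shape of
    // work DataFusion's top-k operator performs on the distance column.
    let options = SortOptions {
        nulls_first: false,
        ..Default::default()
    };
    let indices = sort_to_indices(distances, Some(options), Some(k))?;
    let selected = take(&StructArray::from(batch.clone()), &indices, None)?;
    Ok(selected.as_struct().into())
}

fn main() -> Result<(), ArrowError> {
    // Three rows with precomputed distances; keep the two nearest.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("_distance", DataType::Float32, false),
    ]));
    let ids: ArrayRef = Arc::new(Int32Array::from(vec![0, 1, 2]));
    let dist: ArrayRef = Arc::new(Float32Array::from(vec![5.0_f32, 1.0, 3.0]));
    let batch = RecordBatch::try_new(schema, vec![ids, dist.clone()])?;

    let top2 = top_k_by_distance(&batch, &dist, 2)?;
    assert_eq!(top2.num_rows(), 2); // ids 1 and 2, distances 1.0 and 3.0
    Ok(())
}
```

Because DataFusion's top-k can keep only the current best k rows as batches stream through, the selection no longer waits for every batch to be concatenated and re-sorted at the end — the latency concern flagged in the removed TODO for issue #1324.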
2 changes: 1 addition & 1 deletion python/python/tests/test_optimize.py
@@ -116,7 +116,7 @@ def check_index(has_knn_combined):
         results = dataset.to_table(nearest=query).column("vector")
         assert has_target(query["q"], results)
         plan = dataset.scanner(nearest=query).explain_plan()
-        assert ("KNNFlat" in plan) == has_knn_combined
+        assert ("KNNVectorDistance" in plan) == has_knn_combined
 
     # Original state is 2 indexed fragments of size 150. This should not require
     # a combined scan
4 changes: 2 additions & 2 deletions python/python/tests/test_vector_index.py
@@ -507,8 +507,8 @@ def check_index(has_knn_combined, delete_has_happened):
     for query in sample_queries:
         results = dataset.to_table(nearest=query).column("vector")
         assert has_target(query["q"], results)
-        plan = dataset.scanner(nearest=query).explain_plan()
-        assert ("KNNFlat" in plan) == has_knn_combined
+        plan = dataset.scanner(nearest=query).explain_plan(verbose=True)
+        assert ("KNNVectorDistance" in plan) == has_knn_combined
     for query in sample_delete_queries:
         results = dataset.to_table(nearest=query).column("vector")
         assert delete_has_happened != has_target(query["q"], results)
133 changes: 19 additions & 114 deletions rust/lance-index/src/vector/flat.rs
@@ -4,26 +4,15 @@
 //! Flat Vector Index.
 //!
 
-use std::sync::Arc;
-
-use arrow_array::{
-    cast::AsArray, make_array, Array, ArrayRef, FixedSizeListArray, RecordBatch, StructArray,
-};
-use arrow_ord::sort::sort_to_indices;
-use arrow_schema::{DataType, Field as ArrowField, SchemaRef, SortOptions};
-use arrow_select::{concat::concat, take::take};
-use futures::{
-    future,
-    stream::{repeat_with, StreamExt, TryStreamExt},
-};
+use arrow_array::{make_array, Array, ArrayRef, RecordBatch};
+use arrow_schema::{DataType, Field as ArrowField};
 use lance_arrow::*;
 use lance_core::{Error, Result, ROW_ID};
-use lance_io::stream::RecordBatchStream;
 use lance_linalg::distance::DistanceType;
 use snafu::{location, Location};
 use tracing::instrument;
 
-use super::{Query, DIST_COL};
+use super::DIST_COL;
 
 pub mod index;
 pub mod storage;
@@ -33,57 +22,20 @@ fn distance_field() -> ArrowField {
 }
 
 #[instrument(level = "debug", skip_all)]
-pub async fn flat_search(
-    stream: impl RecordBatchStream + 'static,
-    query: &Query,
-) -> Result<RecordBatch> {
-    let input_schema = stream.schema();
-    let batches = stream
-        .try_filter(|batch| future::ready(batch.num_rows() > 0))
-        .zip(repeat_with(|| query.metric_type))
-        .map(|(batch, mt)| async move { flat_search_batch(query, mt, batch?).await })
-        .buffer_unordered(16)
-        .try_collect::<Vec<_>>()
-        .await?;
-
-    if batches.is_empty() {
-        if input_schema.column_with_name(DIST_COL).is_none() {
-            let schema_with_distance = input_schema.try_with_column(distance_field())?;
-            return Ok(RecordBatch::new_empty(schema_with_distance.into()));
-        } else {
-            return Ok(RecordBatch::new_empty(input_schema));
-        }
-    }
-
-    // TODO: waiting to do this until the end adds quite a bit of latency. We should
-    // do this in a streaming fashion. See also: https://github.com/lancedb/lance/issues/1324
-    let batch = concat_batches(&batches[0].schema(), &batches)?;
-    let distances = batch.column_by_name(DIST_COL).unwrap();
-    let indices = sort_to_indices(distances, None, Some(query.k))?;
-
-    let struct_arr = StructArray::from(batch);
-    let selected_arr = take(&struct_arr, &indices, None)?;
-    Ok(selected_arr.as_struct().into())
-}
-
-#[instrument(level = "debug", skip(query, batch))]
-async fn flat_search_batch(
-    query: &Query,
-    mt: DistanceType,
+pub async fn compute_distance(
+    key: ArrayRef,
+    dt: DistanceType,
+    column: &str,
     mut batch: RecordBatch,
 ) -> Result<RecordBatch> {
-    let key = query.key.clone();
-    let k = query.k;
     if batch.column_by_name(DIST_COL).is_some() {
         // Ignore the distance calculated from inner vector index.
        batch = batch.drop_column(DIST_COL)?;
     }
-    let vectors = batch
-        .column_by_name(&query.column)
-        .ok_or_else(|| Error::Schema {
-            message: format!("column {} does not exist in dataset", query.column),
-            location: location!(),
-        })?;
+    let vectors = batch.column_by_name(column).ok_or_else(|| Error::Schema {
+        message: format!("column {} does not exist in dataset", column),
+        location: location!(),
+    })?;
 
     // A selection vector may have been applied to _rowid column, so we need to
     // push that onto vectors if possible.
@@ -103,62 +55,15 @@
     let vectors = as_fixed_size_list_array(vectors.as_ref()).clone();
 
     tokio::task::spawn_blocking(move || {
-        let distances = mt.arrow_batch_func()(key.as_ref(), &vectors)? as ArrayRef;
-
-        // We don't want any nulls in result, so limit to k or the number of valid values.
-        let k = std::cmp::min(k, distances.len() - distances.null_count());
-
-        let sort_options = SortOptions {
-            nulls_first: false,
-            ..Default::default()
-        };
-        let indices = sort_to_indices(&distances, Some(sort_options), Some(k))?;
-
-        let batch_with_distance = batch.try_with_column(distance_field(), distances)?;
-        let struct_arr = StructArray::from(batch_with_distance);
-        let selected_arr = take(&struct_arr, &indices, None)?;
-        Ok::<RecordBatch, Error>(selected_arr.as_struct().into())
+        let distances = dt.arrow_batch_func()(key.as_ref(), &vectors)? as ArrayRef;
+
+        batch
+            .try_with_column(distance_field(), distances)
+            .map_err(|e| Error::Execution {
+                message: format!("Failed to add distance column: {}", e),
+                location: location!(),
+            })
     })
     .await
     .unwrap()
 }
-
-fn concat_batches<'a>(
-    schema: &SchemaRef,
-    input_batches: impl IntoIterator<Item = &'a RecordBatch>,
-) -> Result<RecordBatch> {
-    let batches: Vec<&RecordBatch> = input_batches.into_iter().collect();
-    if batches.is_empty() {
-        return Ok(RecordBatch::new_empty(schema.clone()));
-    }
-    let field_num = schema.fields().len();
-    let mut arrays = Vec::with_capacity(field_num);
-    for i in 0..field_num {
-        let in_arrays = &batches
-            .iter()
-            .map(|batch| batch.column(i).as_ref())
-            .collect::<Vec<_>>();
-        let data_type = in_arrays[0].data_type();
-        let array = match data_type {
-            DataType::FixedSizeList(f, size) if f.data_type().is_floating() => {
-                concat_fsl(in_arrays, *size)?
-            }
-            _ => concat(in_arrays)?,
-        };
-        arrays.push(array);
-    }
-    Ok(RecordBatch::try_new(schema.clone(), arrays)?)
-}
-
-/// Optimized version of concatenating fixed-size list. Upstream does not yet
-/// pre-size FSL arrays correctly. We can remove this once they do.
-fn concat_fsl(arrays: &[&dyn Array], size: i32) -> Result<ArrayRef> {
-    let values_arrays: Vec<_> = arrays
-        .iter()
-        .map(|&arr| as_fixed_size_list_array(arr).values().as_ref())
-        .collect();
-    let values = concat(&values_arrays)?;
-    Ok(Arc::new(FixedSizeListArray::try_new_from_values(
-        values, size,
-    )?))
-}
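And a sketch of how the new `compute_distance` might be called directly — assuming it is reachable at `lance_index::vector::flat::compute_distance` (matching this file's path), that `DIST_COL` is `_distance`, and that the `lance-index`, `lance-arrow`, `lance-linalg`, `arrow` and `tokio` crates are available as dependencies. In the engine itself the call sits inside a plan node, with DataFusion's sort-with-fetch performing the k-nearest selection that `flat_search_batch` used to do inline:

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, FixedSizeListArray, Float32Array, RecordBatch};
use arrow_schema::{Field, Schema};
use lance_arrow::FixedSizeListArrayExt;
use lance_index::vector::flat::compute_distance;
use lance_linalg::distance::DistanceType;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // A batch with a single FixedSizeList<Float32, 2> column holding two vectors.
    let values = Float32Array::from(vec![0.0_f32, 0.0, 3.0, 4.0]);
    let vectors = FixedSizeListArray::try_new_from_values(values, 2)?;
    let field = Field::new("vector", vectors.data_type().clone(), true);
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![field])),
        vec![Arc::new(vectors) as ArrayRef],
    )?;

    // Distances from the query vector (0, 0) to each row. The result keeps the
    // original columns and appends the distance column (DIST_COL, "_distance").
    let key: ArrayRef = Arc::new(Float32Array::from(vec![0.0_f32, 0.0]));
    let with_dist = compute_distance(key, DistanceType::L2, "vector", batch).await?;
    println!("{:?}", with_dist.column_by_name("_distance"));
    Ok(())
}
```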