Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support multivector type #3190

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,8 +1843,12 @@ def create_index(
if c not in self.schema.names:
raise KeyError(f"{c} not found in schema")
field = self.schema.field(c)
is_multivec = False
if pa.types.is_fixed_size_list(field.type):
dimension = field.type.list_size
elif pa.types.is_list(field.type):
dimension = field.type.value_type.list_size
is_multivec = True
elif (
isinstance(field.type, pa.FixedShapeTensorType)
and len(field.type.shape) == 1
Expand All @@ -1862,7 +1866,12 @@ def create_index(
f" ({num_sub_vectors})"
)

if not pa.types.is_floating(field.type.value_type):
if (
not is_multivec and not pa.types.is_floating(field.type.value_type)
) or (
is_multivec
and not pa.types.is_floating(field.type.value_type.value_type)
):
raise TypeError(
f"Vector column {c} must have floating value type, "
f"got {field.type.value_type}"
Expand Down Expand Up @@ -3043,14 +3052,19 @@ def nearest(
column_type = column_field.type
if hasattr(column_type, "storage_type"):
column_type = column_type.storage_type
if not pa.types.is_fixed_size_list(column_type):
if pa.types.is_fixed_size_list(column_type):
dim = column_type.list_size
elif pa.types.is_list(column_type) and pa.types.is_fixed_size_list(
column_type.value_type
):
dim = column_type.value_type.list_size
else:
raise TypeError(
f"Query column {column} must be a vector. Got {column_field.type}."
)
if len(q) != column_type.list_size:
if len(q) % dim != 0:
raise ValueError(
f"Query vector size {len(q)} does not match index column size"
f" {column_type.list_size}"
f"Query vector size {len(q)} does not match index column size" f" {dim}"
)

if k is not None and int(k) <= 0:
Expand Down
62 changes: 62 additions & 0 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,48 @@ def gen_str(n):
return tbl


def create_multivec_table(
nvec=1000, nvec_per_row=5, ndim=128, nans=0, nullify=False, dtype=np.float32
):
mat = np.random.randn(nvec, nvec_per_row, ndim)
if nans > 0:
nans_mat = np.empty((nans, ndim))
nans_mat[:] = np.nan
mat = np.concatenate((mat, nans_mat), axis=0)
mat = mat.astype(dtype)
price = np.random.rand(nvec + nans) * 100

def gen_str(n):
return "".join(random.choices(string.ascii_letters + string.digits, k=n))

meta = np.array([gen_str(100) for _ in range(nvec + nans)])

multi_vec_type = pa.list_(pa.list_(pa.float32(), ndim))
tbl = pa.Table.from_arrays(
[
pa.array((mat[i].tolist() for i in range(nvec)), type=multi_vec_type),
],
schema=pa.schema(
[
pa.field("vector", pa.list_(pa.list_(pa.float32(), ndim))),
]
),
)
tbl = (
tbl.append_column("price", pa.array(price))
.append_column("meta", pa.array(meta))
.append_column("id", pa.array(range(nvec + nans)))
)
if nullify:
idx = tbl.schema.get_field_index("vector")
vecs = tbl[idx].to_pylist()
nullified = [vec if i % 2 == 0 else None for i, vec in enumerate(vecs)]
field = tbl.schema.field(idx)
vecs = pa.array(nullified, field.type)
tbl = tbl.set_column(idx, field, vecs)
return tbl


@pytest.fixture()
def dataset(tmp_path):
tbl = create_table()
Expand All @@ -63,6 +105,21 @@ def indexed_dataset(tmp_path):
)


@pytest.fixture()
def multivec_dataset(tmp_path):
tbl = create_multivec_table()
yield lance.write_dataset(tbl, tmp_path)


@pytest.fixture()
def indexed_multivec_dataset(tmp_path):
tbl = create_multivec_table()
dataset = lance.write_dataset(tbl, tmp_path)
yield dataset.create_index(
"vector", index_type="IVF_PQ", num_partitions=4, num_sub_vectors=16
)


def run(ds, q=None, assert_func=None):
if q is None:
q = np.random.randn(128)
Expand Down Expand Up @@ -456,6 +513,11 @@ def test_create_ivf_hnsw_sq_index(dataset, tmp_path):
assert ann_ds.list_indices()[0]["fields"] == ["vector"]


def test_multivec_ann(indexed_multivec_dataset):
query = np.random.randn(5 * 128)
indexed_multivec_dataset.scanner(nearest={"column": "vector", "q": query, "k": 100})


def test_pre_populated_ivf_centroids(dataset, tmp_path: Path):
centroids = np.random.randn(5, 128).astype(np.float32) # IVF5
dataset_with_index = dataset.create_index(
Expand Down
53 changes: 35 additions & 18 deletions rust/lance-index/src/vector/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
//! Flat Vector Index.
//!

use arrow_array::{make_array, Array, ArrayRef, RecordBatch};
use std::sync::Arc;

use arrow::array::AsArray;
use arrow_array::{make_array, Array, ArrayRef, Float32Array, RecordBatch};
use arrow_schema::{DataType, Field as ArrowField};
use lance_arrow::*;
use lance_core::{Error, Result, ROW_ID};
use lance_linalg::distance::DistanceType;
use lance_linalg::distance::{multivec_distance, DistanceType};
use snafu::{location, Location};
use tracing::instrument;

Expand All @@ -32,30 +35,44 @@ pub async fn compute_distance(
// Ignore the distance calculated from inner vector index.
batch = batch.drop_column(DIST_COL)?;
}
let vectors = batch.column_by_name(column).ok_or_else(|| Error::Schema {
message: format!("column {} does not exist in dataset", column),
location: location!(),
})?;
let vectors = batch
.column_by_name(column)
.ok_or_else(|| Error::Schema {
message: format!("column {} does not exist in dataset", column),
location: location!(),
})?
.clone();

// A selection vector may have been applied to _rowid column, so we need to
// push that onto vectors if possible.
let vectors = as_fixed_size_list_array(vectors.as_ref()).clone();
let validity_buffer = if let Some(rowids) = batch.column_by_name(ROW_ID) {
rowids.nulls().map(|nulls| nulls.buffer().clone())
} else {
None
};

let vectors = vectors
.into_data()
.into_builder()
.null_bit_buffer(validity_buffer)
.build()
.map(make_array)?;
let vectors = as_fixed_size_list_array(vectors.as_ref()).clone();

tokio::task::spawn_blocking(move || {
let distances = dt.arrow_batch_func()(key.as_ref(), &vectors)? as ArrayRef;
// A selection vector may have been applied to _rowid column, so we need to
// push that onto vectors if possible.

let vectors = vectors
.into_data()
.into_builder()
.null_bit_buffer(validity_buffer)
.build()
.map(make_array)?;
let distances = match vectors.data_type() {
DataType::FixedSizeList(_, _) => {
let vectors = vectors.as_fixed_size_list();
dt.arrow_batch_func()(key.as_ref(), vectors)? as ArrayRef
}
DataType::List(_) => {
let vectors = vectors.as_list();
let dists = multivec_distance(key.as_ref(), vectors, dt)?;
Arc::new(Float32Array::from(dists))
}
_ => {
unreachable!()
}
};

batch
.try_with_column(distance_field(), distances)
Expand Down
2 changes: 2 additions & 0 deletions rust/lance-index/src/vector/flat/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ impl IvfSubIndex for FlatIndex {
dist: OrderedFloat(dist),
})
.sorted_unstable()
.unique_by(|r| r.id)
.take(k)
.map(
|OrderedNode {
Expand All @@ -105,6 +106,7 @@ impl IvfSubIndex for FlatIndex {
dist: OrderedFloat(dist_calc.distance(id as u32)),
})
.sorted_unstable()
.unique_by(|r| r.id)
.take(k)
.map(
|OrderedNode {
Expand Down
19 changes: 10 additions & 9 deletions rust/lance-index/src/vector/hnsw/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,15 +706,16 @@ impl IvfSubIndex for HNSW {
// if the queue is full, we just don't push it back, so ignore the error here
let _ = self.inner.visited_generator_queue.push(prefilter_generator);

let row_ids = UInt64Array::from_iter_values(results.iter().map(|x| storage.row_id(x.id)));
let distances = Arc::new(Float32Array::from_iter_values(
results.iter().map(|x| x.dist.0),
));

Ok(RecordBatch::try_new(
schema,
vec![distances, Arc::new(row_ids)],
)?)
// need to unique by row ids in case of searching multivector
let (row_ids, dists): (Vec<_>, Vec<_>) = results
.into_iter()
.map(|r| (storage.row_id(r.id), r.dist.0))
.unique_by(|r| r.0)
.unzip();
let row_ids = Arc::new(UInt64Array::from(row_ids));
let distances = Arc::new(Float32Array::from(dists));

Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
}

/// Given a vector storage, containing all the data for the IVF partition, build the sub index.
Expand Down
9 changes: 6 additions & 3 deletions rust/lance-index/src/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ impl IvfTransformer {
vector_column: &str,
range: Option<Range<u32>>,
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> = vec![];
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];

let dt = if distance_type == DistanceType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
Expand Down Expand Up @@ -154,7 +155,8 @@ impl IvfTransformer {
range: Option<Range<u32>>,
with_pq_code: bool, // Pass true for v1 index format, otherwise false.
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> = vec![];
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];

let mt = if distance_type == MetricType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
Expand Down Expand Up @@ -206,7 +208,8 @@ impl IvfTransformer {
vector_column: &str,
range: Option<Range<u32>>,
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> = vec![];
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];

let mt = if metric_type == MetricType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
Expand Down
1 change: 1 addition & 0 deletions rust/lance-index/src/vector/ivf/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ impl Transformer for PartitionTransformer {
),
location: location!(),
})?;

let fsl = arr
.as_fixed_size_list_opt()
.ok_or_else(|| lance_core::Error::Index {
Expand Down
Loading
Loading