Skip to content

Commit

Permalink
feat: support transforming selected fragments in vector transform stage for ivf_pq index (#2657)
Browse files Browse the repository at this point in the history

The `transform_vectors()` API transforms vectors for an ivf-pq index, producing an
output table with the row_id, pq_code, and partition_id. This PR
modifies the API so that the transformation is applied only to the
fragments the user specifies as a parameter. If `fragments=None`,
the default behavior uses all fragments of the dataset.
  • Loading branch information
raunaks13 authored Aug 1, 2024
1 parent 2c04d4e commit 7f956b5
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 8 deletions.
13 changes: 13 additions & 0 deletions python/python/lance/indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pyarrow as pa

from lance import LanceFragment
from lance.file import LanceFileReader, LanceFileWriter

from .lance import indices
Expand Down Expand Up @@ -379,6 +380,7 @@ def transform_vectors(
ivf: IvfModel,
pq: PqModel,
dest_uri: str,
fragments: Optional[list[LanceFragment]],
partition_ds_uri: Optional[str] = None,
):
"""
Expand All @@ -396,6 +398,9 @@ def transform_vectors(
dest_uri: str
The URI to save the transformed vectors to. The URI can be a local file
path or a cloud storage path.
fragments: list[LanceFragment]
The list of data fragments to use when computing the transformed vectors.
This is an optional parameter (the default uses all fragments).
partition_ds_uri: str
The URI of a precomputed partitions dataset. This allows the partition
transform to be skipped, using the precomputed value instead. This is
Expand All @@ -404,6 +409,13 @@ def transform_vectors(
dimension = self.dataset.schema.field(self.column[0]).type.list_size
num_subvectors = pq.num_subvectors
distance_type = ivf.distance_type
if fragments is None:
fragments = [f._fragment for f in self.dataset.get_fragments()]
elif len(fragments) == 0:
raise ValueError("fragments must be a non-empty list or None")
else:
fragments = [f._fragment for f in fragments]

indices.transform_vectors(
self.dataset._ds,
self.column[0],
Expand All @@ -413,6 +425,7 @@ def transform_vectors(
ivf.centroids,
pq.codebook,
dest_uri,
fragments,
)

def _determine_num_partitions(self, num_partitions: Optional[int], num_rows: int):
Expand Down
25 changes: 19 additions & 6 deletions python/python/tests/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from lance.file import LanceFileReader
from lance.indices import IndicesBuilder, IvfModel, PqModel

NUM_ROWS = 10000
NUM_ROWS_PER_FRAGMENT = 10000
DIMENSION = 128
NUM_SUBVECTORS = 8
NUM_PARTITIONS = 100
NUM_FRAGMENTS = 3
NUM_ROWS = NUM_ROWS_PER_FRAGMENT * NUM_FRAGMENTS
NUM_PARTITIONS = round(np.sqrt(NUM_ROWS))


@pytest.fixture(params=[np.float16, np.float32, np.float64], ids=["f16", "f32", "f64"])
Expand All @@ -19,7 +21,9 @@ def rand_dataset(tmpdir, request):
vectors.shape = -1
vectors = pa.FixedSizeListArray.from_arrays(vectors, DIMENSION)
table = pa.Table.from_arrays([vectors], names=["vectors"])
ds = lance.write_dataset(table, str(tmpdir / "dataset"))
uri = str(tmpdir / "dataset")

ds = lance.write_dataset(table, uri, max_rows_per_file=NUM_ROWS_PER_FRAGMENT)

return ds

Expand Down Expand Up @@ -115,12 +119,15 @@ def rand_pq(rand_dataset, rand_ivf):


def test_vector_transform(tmpdir, rand_dataset, rand_ivf, rand_pq):
fragments = list(rand_dataset.get_fragments())

builder = IndicesBuilder(rand_dataset, "vectors")
builder.transform_vectors(rand_ivf, rand_pq, str(tmpdir / "transformed"))
uri = str(tmpdir / "transformed")
builder.transform_vectors(rand_ivf, rand_pq, uri, fragments=fragments)

reader = LanceFileReader(str(tmpdir / "transformed"))
reader = LanceFileReader(uri)

assert reader.metadata().num_rows == 10000
assert reader.metadata().num_rows == (NUM_ROWS_PER_FRAGMENT * len(fragments))
data = next(reader.read_all(batch_size=10000).to_batches())

row_id = data.column("_rowid")
Expand All @@ -131,3 +138,9 @@ def test_vector_transform(tmpdir, rand_dataset, rand_ivf, rand_pq):

part_id = data.column("__ivf_part_id")
assert part_id.type == pa.uint32()

# test when fragments = None
builder.transform_vectors(rand_ivf, rand_pq, uri, fragments=None)
reader = LanceFileReader(uri)

assert reader.metadata().num_rows == (NUM_ROWS_PER_FRAGMENT * NUM_FRAGMENTS)
2 changes: 1 addition & 1 deletion python/src/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl FileFragment {
self.fragment.id()
}

fn metadata(&self) -> FragmentMetadata {
/// Returns this fragment's metadata, cloned into a `FragmentMetadata` wrapper.
///
/// Visibility widened from private to `pub` so that other modules in this
/// crate (e.g. `indices.rs`, which reads `fragment.metadata().inner` when
/// building the transform scan) can access the metadata of a `FileFragment`.
pub fn metadata(&self) -> FragmentMetadata {
FragmentMetadata::new(self.fragment.metadata().clone())
}

Expand Down
15 changes: 14 additions & 1 deletion python/src/indices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use lance_index::vector::{
use lance_linalg::distance::DistanceType;
use pyo3::{pyfunction, types::PyModule, wrap_pyfunction, PyObject, PyResult, Python};

use crate::fragment::FileFragment;
use crate::{dataset::Dataset, error::PythonErrorExt, file::object_store_from_uri_or_path, RT};

async fn do_train_ivf_model(
Expand Down Expand Up @@ -147,11 +148,14 @@ async fn do_transform_vectors(
ivf_centroids: FixedSizeListArray,
pq_model: ProductQuantizer,
dst_uri: &str,
fragments: Vec<FileFragment>,
) -> PyResult<()> {
let num_rows = dataset.ds.count_rows(None).await.infer_error()?;
let fragments = fragments.iter().map(|item| item.metadata().inner).collect();
let transform_input = dataset
.ds
.scan()
.with_fragments(fragments)
.project(&[column])
.infer_error()?
.with_row_id()
Expand Down Expand Up @@ -188,6 +192,7 @@ pub fn transform_vectors(
ivf_centroids: PyArrowType<ArrayData>,
pq_codebook: PyArrowType<ArrayData>,
dst_uri: &str,
fragments: Vec<FileFragment>,
) -> PyResult<()> {
let ivf_centroids = ivf_centroids.0;
let ivf_centroids = FixedSizeListArray::from(ivf_centroids);
Expand All @@ -203,7 +208,15 @@ pub fn transform_vectors(
);
RT.block_on(
Some(py),
do_transform_vectors(dataset, column, distance_type, ivf_centroids, pq, dst_uri),
do_transform_vectors(
dataset,
column,
distance_type,
ivf_centroids,
pq,
dst_uri,
fragments,
),
)?
}

Expand Down

0 comments on commit 7f956b5

Please sign in to comment.