diff --git a/python/python/lance/indices.py b/python/python/lance/indices.py
index 5772533c3b..4113dc09cb 100644
--- a/python/python/lance/indices.py
+++ b/python/python/lance/indices.py
@@ -7,6 +7,7 @@
 import numpy as np
 import pyarrow as pa
 
+from lance import LanceFragment
 from lance.file import LanceFileReader, LanceFileWriter
 
 from .lance import indices
@@ -379,6 +380,7 @@ def transform_vectors(
         ivf: IvfModel,
         pq: PqModel,
         dest_uri: str,
+        fragments: Optional[list[LanceFragment]] = None,
         partition_ds_uri: Optional[str] = None,
     ):
         """
@@ -396,6 +398,9 @@
         dest_uri: str
            The URI to save the transformed vectors to. The URI can be a local
            file path or a cloud storage path.
+        fragments: list[LanceFragment]
+            The list of data fragments to use when computing the transformed vectors.
+            This is an optional parameter (the default uses all fragments).
         partition_ds_uri: str
            The URI of a precomputed partitions dataset. This allows the partition
            transform to be skipped, using the precomputed value instead. This is
@@ -404,6 +409,13 @@
         dimension = self.dataset.schema.field(self.column[0]).type.list_size
         num_subvectors = pq.num_subvectors
         distance_type = ivf.distance_type
+        if fragments is None:
+            fragments = [f._fragment for f in self.dataset.get_fragments()]
+        elif len(fragments) == 0:
+            raise ValueError("fragments must be a non-empty list or None")
+        else:
+            fragments = [f._fragment for f in fragments]
+
         indices.transform_vectors(
             self.dataset._ds,
             self.column[0],
@@ -413,6 +425,7 @@
             ivf.centroids,
             pq.codebook,
             dest_uri,
+            fragments,
         )
 
     def _determine_num_partitions(self, num_partitions: Optional[int], num_rows: int):
diff --git a/python/python/tests/test_indices.py b/python/python/tests/test_indices.py
index 8264605d11..8c244caf34 100644
--- a/python/python/tests/test_indices.py
+++ b/python/python/tests/test_indices.py
@@ -7,10 +7,12 @@
 from lance.file import LanceFileReader
 from lance.indices import IndicesBuilder, IvfModel, PqModel
 
-NUM_ROWS = 10000
+NUM_ROWS_PER_FRAGMENT = 10000
 DIMENSION = 128
 NUM_SUBVECTORS = 8
-NUM_PARTITIONS = 100
+NUM_FRAGMENTS = 3
+NUM_ROWS = NUM_ROWS_PER_FRAGMENT * NUM_FRAGMENTS
+NUM_PARTITIONS = round(np.sqrt(NUM_ROWS))
 
 
 @pytest.fixture(params=[np.float16, np.float32, np.float64], ids=["f16", "f32", "f64"])
@@ -19,7 +21,9 @@ def rand_dataset(tmpdir, request):
     vectors.shape = -1
     vectors = pa.FixedSizeListArray.from_arrays(vectors, DIMENSION)
     table = pa.Table.from_arrays([vectors], names=["vectors"])
-    ds = lance.write_dataset(table, str(tmpdir / "dataset"))
+    uri = str(tmpdir / "dataset")
+
+    ds = lance.write_dataset(table, uri, max_rows_per_file=NUM_ROWS_PER_FRAGMENT)
 
     return ds
 
@@ -115,12 +119,15 @@ def rand_pq(rand_dataset, rand_ivf):
 
 
 def test_vector_transform(tmpdir, rand_dataset, rand_ivf, rand_pq):
+    fragments = list(rand_dataset.get_fragments())
+
     builder = IndicesBuilder(rand_dataset, "vectors")
-    builder.transform_vectors(rand_ivf, rand_pq, str(tmpdir / "transformed"))
+    uri = str(tmpdir / "transformed")
+    builder.transform_vectors(rand_ivf, rand_pq, uri, fragments=fragments)
 
-    reader = LanceFileReader(str(tmpdir / "transformed"))
+    reader = LanceFileReader(uri)
 
-    assert reader.metadata().num_rows == 10000
+    assert reader.metadata().num_rows == (NUM_ROWS_PER_FRAGMENT * len(fragments))
     data = next(reader.read_all(batch_size=10000).to_batches())
 
     row_id = data.column("_rowid")
@@ -131,3 +138,9 @@
 
     part_id = data.column("__ivf_part_id")
     assert part_id.type == pa.uint32()
+
+    # test when fragments = None
+    builder.transform_vectors(rand_ivf, rand_pq, uri, fragments=None)
+    reader = LanceFileReader(uri)
+
+    assert reader.metadata().num_rows == (NUM_ROWS_PER_FRAGMENT * NUM_FRAGMENTS)
diff --git a/python/src/fragment.rs b/python/src/fragment.rs
index 4f5c55e5f1..1278724d64 100644
--- a/python/src/fragment.rs
+++ b/python/src/fragment.rs
@@ -124,7 +124,7 @@ impl FileFragment {
         self.fragment.id()
     }
 
-    fn metadata(&self) -> FragmentMetadata {
+    pub fn metadata(&self) -> FragmentMetadata {
         FragmentMetadata::new(self.fragment.metadata().clone())
     }
 
diff --git a/python/src/indices.rs b/python/src/indices.rs
index 8dfbeae1a1..53cdcec21e 100644
--- a/python/src/indices.rs
+++ b/python/src/indices.rs
@@ -12,6 +12,7 @@ use lance_index::vector::{
 use lance_linalg::distance::DistanceType;
 use pyo3::{pyfunction, types::PyModule, wrap_pyfunction, PyObject, PyResult, Python};
 
+use crate::fragment::FileFragment;
 use crate::{dataset::Dataset, error::PythonErrorExt, file::object_store_from_uri_or_path, RT};
 
 async fn do_train_ivf_model(
@@ -147,11 +148,14 @@ async fn do_transform_vectors(
     ivf_centroids: FixedSizeListArray,
     pq_model: ProductQuantizer,
     dst_uri: &str,
+    fragments: Vec<FileFragment>,
 ) -> PyResult<()> {
     let num_rows = dataset.ds.count_rows(None).await.infer_error()?;
+    let fragments = fragments.iter().map(|item| item.metadata().inner).collect();
     let transform_input = dataset
         .ds
         .scan()
+        .with_fragments(fragments)
         .project(&[column])
         .infer_error()?
         .with_row_id()
@@ -188,6 +192,7 @@ pub fn transform_vectors(
     ivf_centroids: PyArrowType<ArrayData>,
     pq_codebook: PyArrowType<ArrayData>,
     dst_uri: &str,
+    fragments: Vec<FileFragment>,
 ) -> PyResult<()> {
     let ivf_centroids = ivf_centroids.0;
     let ivf_centroids = FixedSizeListArray::from(ivf_centroids);
@@ -203,7 +208,15 @@ pub fn transform_vectors(
     );
     RT.block_on(
         Some(py),
-        do_transform_vectors(dataset, column, distance_type, ivf_centroids, pq, dst_uri),
+        do_transform_vectors(
+            dataset,
+            column,
+            distance_type,
+            ivf_centroids,
+            pq,
+            dst_uri,
+            fragments,
+        ),
     )?
 }
 
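Usage note (not part of the patch): a minimal sketch of how the new `fragments` parameter can split the transform step across workers. It assumes the `train_ivf`/`train_pq` helpers on `IndicesBuilder` from this version of `lance.indices`; the dataset path and output URIs are hypothetical.

```python
import lance
from lance.indices import IndicesBuilder

# Hypothetical dataset with a fixed-size-list "vectors" column.
ds = lance.dataset("./vectors.lance")
builder = IndicesBuilder(ds, "vectors")

# Train the IVF and PQ models once (assumed builder helpers).
ivf = builder.train_ivf()
pq = builder.train_pq(ivf)

# Each worker can transform a disjoint subset of fragments into its own
# file, instead of one process scanning the entire dataset.
for i, frag in enumerate(ds.get_fragments()):
    builder.transform_vectors(ivf, pq, f"./transformed_{i}.lance", fragments=[frag])

# Passing fragments=None keeps the old behavior: every fragment is scanned.
builder.transform_vectors(ivf, pq, "./transformed_all.lance", fragments=None)
```

This mirrors the test above, where the fixture writes three fragments via `max_rows_per_file` and the transformed row count is checked per fragment subset.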