Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support transforming selected fragments in vector transform stage for ivf_pq index #2657

Merged
merged 9 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/python/lance/indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pyarrow as pa

from lance import LanceFragment
from lance.file import LanceFileReader, LanceFileWriter

from .lance import indices
Expand Down Expand Up @@ -379,6 +380,7 @@ def transform_vectors(
ivf: IvfModel,
pq: PqModel,
dest_uri: str,
fragments: Optional[list[LanceFragment]],
partition_ds_uri: Optional[str] = None,
):
"""
Expand All @@ -396,6 +398,9 @@ def transform_vectors(
dest_uri: str
The URI to save the transformed vectors to. The URI can be a local file
path or a cloud storage path.
fragments: Optional[list[LanceFragment]]
The list of data fragments to use when computing the transformed vectors.
Pass None to use every fragment in the dataset. An empty list is rejected
with a ValueError.
partition_ds_uri: str
The URI of a precomputed partitions dataset. This allows the partition
transform to be skipped, using the precomputed value instead. This is
Expand All @@ -404,6 +409,13 @@ def transform_vectors(
dimension = self.dataset.schema.field(self.column[0]).type.list_size
num_subvectors = pq.num_subvectors
distance_type = ivf.distance_type
if fragments is None:
fragments = [f._fragment for f in self.dataset.get_fragments()]
elif len(fragments) == 0:
raise ValueError("fragments must be a non-empty list or None")
else:
fragments = [f._fragment for f in fragments]

indices.transform_vectors(
self.dataset._ds,
self.column[0],
Expand All @@ -413,6 +425,7 @@ def transform_vectors(
ivf.centroids,
pq.codebook,
dest_uri,
fragments,
)

def _determine_num_partitions(self, num_partitions: Optional[int], num_rows: int):
Expand Down
25 changes: 19 additions & 6 deletions python/python/tests/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from lance.file import LanceFileReader
from lance.indices import IndicesBuilder, IvfModel, PqModel

NUM_ROWS = 10000
NUM_ROWS_PER_FRAGMENT = 10000
DIMENSION = 128
NUM_SUBVECTORS = 8
NUM_PARTITIONS = 100
NUM_FRAGMENTS = 3
NUM_ROWS = NUM_ROWS_PER_FRAGMENT * NUM_FRAGMENTS
NUM_PARTITIONS = round(np.sqrt(NUM_ROWS))


@pytest.fixture(params=[np.float16, np.float32, np.float64], ids=["f16", "f32", "f64"])
Expand All @@ -19,7 +21,9 @@ def rand_dataset(tmpdir, request):
vectors.shape = -1
vectors = pa.FixedSizeListArray.from_arrays(vectors, DIMENSION)
table = pa.Table.from_arrays([vectors], names=["vectors"])
ds = lance.write_dataset(table, str(tmpdir / "dataset"))
uri = str(tmpdir / "dataset")

ds = lance.write_dataset(table, uri, max_rows_per_file=NUM_ROWS_PER_FRAGMENT)

return ds

Expand Down Expand Up @@ -115,12 +119,15 @@ def rand_pq(rand_dataset, rand_ivf):


def test_vector_transform(tmpdir, rand_dataset, rand_ivf, rand_pq):
fragments = list(rand_dataset.get_fragments())

builder = IndicesBuilder(rand_dataset, "vectors")
builder.transform_vectors(rand_ivf, rand_pq, str(tmpdir / "transformed"))
uri = str(tmpdir / "transformed")
builder.transform_vectors(rand_ivf, rand_pq, uri, fragments=fragments)

reader = LanceFileReader(str(tmpdir / "transformed"))
reader = LanceFileReader(uri)

assert reader.metadata().num_rows == 10000
assert reader.metadata().num_rows == (NUM_ROWS_PER_FRAGMENT * len(fragments))
data = next(reader.read_all(batch_size=10000).to_batches())

row_id = data.column("_rowid")
Expand All @@ -131,3 +138,9 @@ def test_vector_transform(tmpdir, rand_dataset, rand_ivf, rand_pq):

part_id = data.column("__ivf_part_id")
assert part_id.type == pa.uint32()

# test when fragments = None
builder.transform_vectors(rand_ivf, rand_pq, uri, fragments=None)
reader = LanceFileReader(uri)

assert reader.metadata().num_rows == (NUM_ROWS_PER_FRAGMENT * NUM_FRAGMENTS)
2 changes: 1 addition & 1 deletion python/src/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl FileFragment {
self.fragment.id()
}

fn metadata(&self) -> FragmentMetadata {
/// Returns a new `FragmentMetadata` wrapper built from a clone of the
/// underlying fragment's metadata.
///
/// Declared `pub` so that code outside this `impl` (for example the
/// vector-transform path in `indices.rs`, which calls `item.metadata()`)
/// can read the metadata of a `FileFragment`.
pub fn metadata(&self) -> FragmentMetadata {
FragmentMetadata::new(self.fragment.metadata().clone())
}

Expand Down
15 changes: 14 additions & 1 deletion python/src/indices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use lance_index::vector::{
use lance_linalg::distance::DistanceType;
use pyo3::{pyfunction, types::PyModule, wrap_pyfunction, PyObject, PyResult, Python};

use crate::fragment::FileFragment;
use crate::{dataset::Dataset, error::PythonErrorExt, file::object_store_from_uri_or_path, RT};

async fn do_train_ivf_model(
Expand Down Expand Up @@ -147,11 +148,14 @@ async fn do_transform_vectors(
ivf_centroids: FixedSizeListArray,
pq_model: ProductQuantizer,
dst_uri: &str,
fragments: Vec<FileFragment>,
) -> PyResult<()> {
let num_rows = dataset.ds.count_rows(None).await.infer_error()?;
let fragments = fragments.iter().map(|item| item.metadata().inner).collect();
let transform_input = dataset
.ds
.scan()
.with_fragments(fragments)
.project(&[column])
.infer_error()?
.with_row_id()
Expand Down Expand Up @@ -188,6 +192,7 @@ pub fn transform_vectors(
ivf_centroids: PyArrowType<ArrayData>,
pq_codebook: PyArrowType<ArrayData>,
dst_uri: &str,
fragments: Vec<FileFragment>,
) -> PyResult<()> {
let ivf_centroids = ivf_centroids.0;
let ivf_centroids = FixedSizeListArray::from(ivf_centroids);
Expand All @@ -203,7 +208,15 @@ pub fn transform_vectors(
);
RT.block_on(
Some(py),
do_transform_vectors(dataset, column, distance_type, ivf_centroids, pq, dst_uri),
do_transform_vectors(
dataset,
column,
distance_type,
ivf_centroids,
pq,
dst_uri,
fragments,
),
)?
}

Expand Down
Loading