diff --git a/python/python/lance/sampler.py b/python/python/lance/sampler.py
index fb3f670c86..2028609d1e 100644
--- a/python/python/lance/sampler.py
+++ b/python/python/lance/sampler.py
@@ -6,6 +6,7 @@
 import gc
 import logging
+import math
 import random
 import warnings
 from abc import ABC, abstractmethod
@@ -15,6 +16,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar, Union

 import pyarrow as pa
+import pyarrow.compute as pc

 import lance
 from lance.dependencies import numpy as np
@@ -105,12 +107,63 @@ def _efficient_sample(
         del tbl


+def _filtered_efficient_sample(
+    dataset: lance.LanceDataset,
+    n: int,
+    columns: Optional[Union[List[str], Dict[str, str]]],
+    batch_size: int,
+    target_takes: int,
+    filter: str,
+) -> Generator[pa.RecordBatch, None, None]:
+    total_records = len(dataset)
+    shard_size = math.ceil(n / target_takes)
+    num_shards = math.ceil(total_records / shard_size)
+
+    shards = list(range(num_shards))
+    random.shuffle(shards)
+
+    tables = []
+    remaining_rows = n
+    remaining_in_batch = min(batch_size, n)
+    for shard in shards:
+        start = shard * shard_size
+        end = min(start + shard_size, total_records)
+        table = dataset.to_table(
+            columns=columns,
+            offset=start,
+            limit=(end - start),
+            batch_size=shard_size,
+        )
+        if len(columns) == 1 and filter.lower() == f"{columns[0]} is not null":
+            table = pc.drop_null(table)
+        elif filter is not None:
+            raise NotImplementedError(f"Can't yet run filter <{filter}> in-memory")
+        if table.num_rows > 0:
+            tables.append(table)
+            remaining_rows -= table.num_rows
+            remaining_in_batch = remaining_in_batch - table.num_rows
+            if remaining_in_batch <= 0:
+                combined = pa.concat_tables(tables).combine_chunks()
+                batch = combined.slice(0, batch_size).to_batches()[0]
+                yield batch
+                remaining_in_batch = min(batch_size, remaining_rows)
+                if len(combined) > batch_size:
+                    leftover = combined.slice(batch_size)
+                    tables = [leftover]
+                    remaining_in_batch -= len(leftover)
+                else:
+                    tables = []
+        if remaining_rows <= 0:
+            break
+
+
 def maybe_sample(
     dataset: Union[str, Path, lance.LanceDataset],
     n: int,
     columns: Union[list[str], dict[str, str], str],
     batch_size: int = 10240,
     max_takes: int = 2048,
+    filt: Optional[str] = None,
 ) -> Generator[pa.RecordBatch, None, None]:
     """Sample n records from the dataset.

@@ -129,6 +182,10 @@ def maybe_sample(
         This is employed to minimize the number of random reads necessary for sampling.
         A sufficiently large value can provide an effective random sample without
         the need for excessive random reads.
+    filt : str, optional
+        The filter to apply to the dataset, by default None. If a filter is
+        provided, the dataset is divided into equal-size shards which are read
+        in random order, applying the filter, until enough matching rows are found.

     Returns
     -------
@@ -143,18 +200,23 @@ def maybe_sample(
     if n >= len(dataset):
         # Dont have enough data in the dataset. Just do a full scan
-        yield from dataset.to_batches(columns=columns, batch_size=batch_size)
+        yield from dataset.to_batches(
+            columns=columns, batch_size=batch_size, filter=filt
+        )
+    elif filt is not None:
+        yield from _filtered_efficient_sample(
+            dataset, n, columns, batch_size, max_takes, filt
+        )
+    elif n > max_takes:
+        yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
     else:
-        if n > max_takes:
-            yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
-        else:
-            choices = np.random.choice(len(dataset), n, replace=False)
-            idx = 0
-            while idx < len(choices):
-                end = min(idx + batch_size, len(choices))
-                tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
-                yield tbl.to_batches()[0]
-                idx += batch_size
+        choices = np.random.choice(len(dataset), n, replace=False)
+        idx = 0
+        while idx < len(choices):
+            end = min(idx + batch_size, len(choices))
+            tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
+            yield tbl.to_batches()[0]
+            idx += batch_size


 T = TypeVar("T")
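
A minimal usage sketch of the new `filt` argument (not part of the patch; the
dataset path and column name are hypothetical). Only a filter of the form
`<column> is not null` over a single requested column is evaluated in-memory;
any other filter raises NotImplementedError once iteration begins:

import lance
from lance.sampler import maybe_sample

ds = lance.dataset("vectors.lance")  # hypothetical dataset
# Draw ~10k non-null vectors without scanning the whole dataset.
for batch in maybe_sample(
    ds,
    n=10_240,
    columns=["vector"],
    batch_size=1_024,
    filt="vector is not null",
):
    print(batch.num_rows)
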
diff --git a/python/python/lance/torch/data.py b/python/python/lance/torch/data.py
index 30cb58c86f..3f066543dd 100644
--- a/python/python/lance/torch/data.py
+++ b/python/python/lance/torch/data.py
@@ -233,9 +233,6 @@ def __init__(
             warnings.warn("rank and world_size are deprecated", DeprecationWarning)
         self.sampler: Optional[Sampler] = sampler

-        if filter is not None and self.samples > 0 or self.samples is None:
-            raise ValueError("`filter` is not supported with `samples`")
-
         # Dataset with huggingface metadata
         if (
             dataset.schema.metadata is not None
@@ -284,6 +281,7 @@ def __iter__(self):
                 n=self.samples,
                 columns=self.columns,
                 batch_size=self.batch_size,
+                filt=self.filter,
             )
         else:
             raw_stream = sampler(
diff --git a/python/python/lance/vector.py b/python/python/lance/vector.py
index 48daa44148..e30911e1b9 100644
--- a/python/python/lance/vector.py
+++ b/python/python/lance/vector.py
@@ -154,10 +154,22 @@ def train_ivf_centroids_on_accelerator(

     k = int(k)

-    logging.info("Randomly select %s centroids from %s", k, dataset)
-    samples = dataset.sample(k, [column], sorted=True).combine_chunks()
-    fsl = samples.to_batches()[0][column]
-    init_centroids = torch.from_numpy(np.stack(fsl.to_numpy(zero_copy_only=False)))
+    if dataset.schema.field(column).nullable:
+        filt = f"{column} is not null"
+    else:
+        filt = None
+
+    logging.info("Randomly select %s centroids from %s (filt=%s)", k, dataset, filt)
+
+    ds = TorchDataset(
+        dataset,
+        batch_size=k,
+        columns=[column],
+        samples=sample_size,
+        filter=filt,
+    )
+
+    init_centroids = next(iter(ds))
     logging.info("Done sampling: centroids shape: %s", init_centroids.shape)

     ds = TorchDataset(
@@ -165,6 +177,7 @@ def train_ivf_centroids_on_accelerator(
         batch_size=20480,
         columns=[column],
         samples=sample_size,
+        filter=filt,
         cache=True,
     )

@@ -233,6 +246,7 @@ def compute_partitions(
         batch_size=batch_size,
         with_row_id=True,
         columns=[column],
+        filter=f"{column} is not null",
     )
     loader = torch.utils.data.DataLoader(
         torch_ds,
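
With the ValueError guard removed from torch/data.py, `samples` and `filter`
can now be combined on the torch LanceDataset, which is what the new
vector.py code relies on. A sketch under the same assumptions as above
(hypothetical path, column name, and sizes):

import lance
from lance.torch.data import LanceDataset as TorchDataset

dataset = lance.dataset("vectors.lance")  # hypothetical dataset

# Mirrors train_ivf_centroids_on_accelerator: sample `samples` rows, keep only
# non-null vectors, and yield tensors of up to `batch_size` rows each.
ds = TorchDataset(
    dataset,
    batch_size=4096,
    columns=["vector"],
    samples=65536,
    filter="vector is not null",
)
first_batch = next(iter(ds))
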
diff --git a/python/python/tests/test_indices.py b/python/python/tests/test_indices.py
index dc909a8925..015994c483 100644
--- a/python/python/tests/test_indices.py
+++ b/python/python/tests/test_indices.py
@@ -32,6 +32,21 @@ def rand_dataset(tmpdir, request):
     return ds


+@pytest.fixture
+def mostly_null_dataset(tmpdir, request):
+    vectors = np.random.randn(NUM_ROWS, DIMENSION).astype(np.float32)
+    vectors.shape = -1
+    vectors = pa.FixedSizeListArray.from_arrays(vectors, DIMENSION)
+    vectors = vectors.to_pylist()
+    vectors = [vec if i % 10 == 0 else None for i, vec in enumerate(vectors)]
+    vectors = pa.array(vectors, pa.list_(pa.float32(), DIMENSION))
+    table = pa.Table.from_arrays([vectors], names=["vectors"])
+
+    uri = str(tmpdir / "nulls_dataset")
+    ds = lance.write_dataset(table, uri, max_rows_per_file=NUM_ROWS_PER_FRAGMENT)
+    return ds
+
+
 def test_ivf_centroids(tmpdir, rand_dataset):
     ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(sample_rate=16)

@@ -44,6 +59,13 @@ def test_ivf_centroids(tmpdir, rand_dataset):
     assert ivf.centroids == reloaded.centroids


+def test_ivf_centroids_mostly_null(mostly_null_dataset):
+    ivf = IndicesBuilder(mostly_null_dataset, "vectors").train_ivf(sample_rate=16)
+
+    assert ivf.distance_type == "l2"
+    assert len(ivf.centroids) == NUM_PARTITIONS
+
+
 @pytest.mark.cuda
 def test_ivf_centroids_cuda(rand_dataset):
     ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(
@@ -54,6 +76,16 @@ def test_ivf_centroids_cuda(rand_dataset):
     assert len(ivf.centroids) == NUM_PARTITIONS


+@pytest.mark.cuda
+def test_ivf_centroids_mostly_null_cuda(mostly_null_dataset):
+    ivf = IndicesBuilder(mostly_null_dataset, "vectors").train_ivf(
+        sample_rate=16, accelerator="cuda"
+    )
+
+    assert ivf.distance_type == "l2"
+    assert len(ivf.centroids) == NUM_PARTITIONS
+
+
 def test_ivf_centroids_distance_type(tmpdir, rand_dataset):
     def check(distance_type):
         ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(
@@ -95,6 +127,16 @@ def test_gen_pq(tmpdir, rand_dataset, rand_ivf):
     assert pq.codebook == reloaded.codebook


+def test_gen_pq_mostly_null(mostly_null_dataset):
+    centroids = np.random.rand(DIMENSION * 100).astype(np.float32)
+    centroids = pa.FixedSizeListArray.from_arrays(centroids, DIMENSION)
+    ivf = IvfModel(centroids, "l2")
+
+    pq = IndicesBuilder(mostly_null_dataset, "vectors").train_pq(ivf, sample_rate=2)
+    assert pq.dimension == DIMENSION
+    assert pq.num_subvectors == NUM_SUBVECTORS
+
+
 @pytest.mark.cuda
 def test_assign_partitions(rand_dataset, rand_ivf):
     builder = IndicesBuilder(rand_dataset, "vectors")
@@ -113,6 +155,28 @@ def test_assign_partitions(rand_dataset, rand_ivf):
     assert len(found_row_ids) == rand_dataset.count_rows()


+@pytest.mark.cuda
+def test_assign_partitions_mostly_null(mostly_null_dataset):
+    centroids = np.random.rand(DIMENSION * 100).astype(np.float32)
+    centroids = pa.FixedSizeListArray.from_arrays(centroids, DIMENSION)
+    ivf = IvfModel(centroids, "l2")
+
+    builder = IndicesBuilder(mostly_null_dataset, "vectors")
+
+    partitions_uri = builder.assign_ivf_partitions(ivf, accelerator="cuda")
+
+    partitions = lance.dataset(partitions_uri)
+    found_row_ids = set()
+    for batch in partitions.to_batches():
+        row_ids = batch["row_id"]
+        for row_id in row_ids:
+            found_row_ids.add(row_id)
+        part_ids = batch["partition"]
+        for part_id in part_ids:
+            assert part_id.as_py() < 100
+    assert len(found_row_ids) == (mostly_null_dataset.count_rows() / 10)
+
+
 @pytest.fixture
 def rand_pq(rand_dataset, rand_ivf):
     dtype = rand_dataset.schema.field("vectors").type.value_type.to_pandas_dtype()
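
The fixture above nulls out nine of every ten vectors, which is why
`test_assign_partitions_mostly_null` expects exactly `count_rows() / 10` row
ids back. A standalone sanity check of that construction (a sketch; NUM_ROWS
and DIMENSION are placeholder values, and `count_rows` accepting a filter is
assumed):

import lance
import numpy as np
import pyarrow as pa

NUM_ROWS, DIMENSION = 1000, 128  # placeholder values

vectors = np.random.randn(NUM_ROWS, DIMENSION).astype(np.float32).ravel()
vectors = pa.FixedSizeListArray.from_arrays(vectors, DIMENSION).to_pylist()
# Keep every tenth vector; null out the rest.
vectors = [vec if i % 10 == 0 else None for i, vec in enumerate(vectors)]
vectors = pa.array(vectors, pa.list_(pa.float32(), DIMENSION))
table = pa.Table.from_arrays([vectors], names=["vectors"])

ds = lance.write_dataset(table, "nulls_dataset.lance")
assert ds.count_rows(filter="vectors is not null") == NUM_ROWS // 10
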
diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py
index 6fc54ee690..7673a64ee6 100644
--- a/python/python/tests/test_vector_index.py
+++ b/python/python/tests/test_vector_index.py
@@ -19,7 +19,7 @@ from lance.vector import vec_to_table  # noqa: E402


-def create_table(nvec=1000, ndim=128, nans=0):
+def create_table(nvec=1000, ndim=128, nans=0, nullify=False):
     mat = np.random.randn(nvec, ndim)
     if nans > 0:
         nans_mat = np.empty((nans, ndim))
@@ -37,6 +37,13 @@ def gen_str(n):
         .append_column("meta", pa.array(meta))
         .append_column("id", pa.array(range(nvec + nans)))
     )
+    if nullify:
+        idx = tbl.schema.get_field_index("vector")
+        vecs = tbl[idx].to_pylist()
+        nullified = [vec if i % 2 == 0 else None for i, vec in enumerate(vecs)]
+        field = tbl.schema.field(idx)
+        vecs = pa.array(nullified, field.type)
+        tbl = tbl.set_column(idx, field, vecs)
     return tbl


@@ -191,8 +198,9 @@ def test_index_with_pq_codebook(tmp_path):


 @pytest.mark.cuda
-def test_create_index_using_cuda(tmp_path):
-    tbl = create_table()
+@pytest.mark.parametrize("nullify", [False, True])
+def test_create_index_using_cuda(tmp_path, nullify):
+    tbl = create_table(nullify=nullify)
     dataset = lance.write_dataset(tbl, tmp_path)
     dataset = dataset.create_index(
         "vector",
diff --git a/python/python/tests/torch_tests/test_data.py b/python/python/tests/torch_tests/test_data.py
index 66a214001c..bb76b30771 100644
--- a/python/python/tests/torch_tests/test_data.py
+++ b/python/python/tests/torch_tests/test_data.py
@@ -134,14 +134,16 @@ def check(dataset):
             )
         )

-        # sampling fails
-        with pytest.raises(ValueError):
-            LanceDataset(
-                ds,
-                batch_size=10,
-                filter="ids >= 300",
-                samples=100,
-                columns=["ids"],
+        # sampling with filter
+        with pytest.raises(NotImplementedError):
+            check(
+                LanceDataset(
+                    ds,
+                    batch_size=10,
+                    filter="ids >= 300",
+                    samples=100,
+                    columns=["ids"],
+                )
             )
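
For completeness, a sketch of the behavior this updated test pins down (the
dataset path is illustrative): an arbitrary filter such as `ids >= 300`
combined with `samples` now fails with NotImplementedError during iteration,
rather than ValueError at construction time:

import lance
import pyarrow as pa
from lance.torch.data import LanceDataset

tbl = pa.table({"ids": pa.array(range(1000))})
ds = lance.write_dataset(tbl, "ids.lance")  # hypothetical path

torch_ds = LanceDataset(
    ds, batch_size=10, filter="ids >= 300", samples=100, columns=["ids"]
)
try:
    next(iter(torch_ds))  # the filter is only checked once sampling begins
except NotImplementedError as err:
    print(err)  # Can't yet run filter <ids >= 300> in-memory
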