feat: handle nulls when creating indices with cuda (#2910)
westonpace authored Sep 19, 2024
1 parent ef39518 commit 3f935d5
Showing 6 changed files with 177 additions and 29 deletions.
84 changes: 73 additions & 11 deletions python/python/lance/sampler.py
@@ -6,6 +6,7 @@

import gc
import logging
import math
import random
import warnings
from abc import ABC, abstractmethod
@@ -15,6 +16,7 @@
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar, Union

import pyarrow as pa
import pyarrow.compute as pc

import lance
from lance.dependencies import numpy as np
@@ -105,12 +107,63 @@ def _efficient_sample(
del tbl


def _filtered_efficient_sample(
dataset: lance.LanceDataset,
n: int,
columns: Optional[Union[List[str], Dict[str, str]]],
batch_size: int,
target_takes: int,
filter: str,
) -> Generator[pa.RecordBatch, None, None]:
total_records = len(dataset)
shard_size = math.ceil(n / target_takes)
num_shards = math.ceil(total_records / shard_size)

shards = list(range(num_shards))
random.shuffle(shards)

tables = []
remaining_rows = n
remaining_in_batch = min(batch_size, n)
for shard in shards:
start = shard * shard_size
end = min(start + shard_size, total_records)
table = dataset.to_table(
columns=columns,
offset=start,
limit=(end - start),
batch_size=shard_size,
)
if len(columns) == 1 and filter.lower() == f"{columns[0]} is not null":
table = pc.drop_null(table)
elif filter is not None:
raise NotImplementedError(f"Can't yet run filter <{filter}> in-memory")
if table.num_rows > 0:
tables.append(table)
remaining_rows -= table.num_rows
remaining_in_batch = remaining_in_batch - table.num_rows
if remaining_in_batch <= 0:
combined = pa.concat_tables(tables).combine_chunks()
batch = combined.slice(0, batch_size).to_batches()[0]
yield batch
remaining_in_batch = min(batch_size, remaining_rows)
if len(combined) > batch_size:
leftover = combined.slice(batch_size)
tables = [leftover]
remaining_in_batch -= len(leftover)
else:
tables = []
if remaining_rows <= 0:
break


def maybe_sample(
dataset: Union[str, Path, lance.LanceDataset],
n: int,
columns: Union[list[str], dict[str, str], str],
batch_size: int = 10240,
max_takes: int = 2048,
filt: Optional[str] = None,
) -> Generator[pa.RecordBatch, None, None]:
"""Sample n records from the dataset.
@@ -129,6 +182,10 @@ def maybe_sample(
This is employed to minimize the number of random reads necessary for sampling.
A sufficiently large value can provide an effective random sample without
the need for excessive random reads.
filt : str, optional
The filter to apply to the dataset, by default None. If a filter is provided,
the dataset is scanned in randomly ordered shards and the filter is applied
to each shard until enough matching rows have been found.
Returns
-------
@@ -143,18 +200,23 @@ def maybe_sample(

if n >= len(dataset):
# Don't have enough data in the dataset. Just do a full scan
yield from dataset.to_batches(columns=columns, batch_size=batch_size)
yield from dataset.to_batches(
columns=columns, batch_size=batch_size, filter=filt
)
elif filt is not None:
yield from _filtered_efficient_sample(
dataset, n, columns, batch_size, max_takes, filt
)
elif n > max_takes:
yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
else:
if n > max_takes:
yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
else:
choices = np.random.choice(len(dataset), n, replace=False)
idx = 0
while idx < len(choices):
end = min(idx + batch_size, len(choices))
tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
yield tbl.to_batches()[0]
idx += batch_size
choices = np.random.choice(len(dataset), n, replace=False)
idx = 0
while idx < len(choices):
end = min(idx + batch_size, len(choices))
tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
yield tbl.to_batches()[0]
idx += batch_size


T = TypeVar("T")
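To illustrate the new sampling path, here is a minimal sketch of calling `maybe_sample` with the new `filt` argument to pull only non-null rows. The dataset path and column name are assumptions, not part of this change:

import lance
from lance.sampler import maybe_sample

# Hypothetical dataset with a nullable "vectors" column.
ds = lance.dataset("/tmp/nulls_dataset")

# Stream up to 4096 matching rows. Only the "<column> is not null" form is
# handled in-memory by _filtered_efficient_sample; other filters raise
# NotImplementedError.
for batch in maybe_sample(
    ds,
    n=4096,
    columns=["vectors"],
    batch_size=1024,
    filt="vectors is not null",
):
    print(batch.num_rows)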
4 changes: 1 addition & 3 deletions python/python/lance/torch/data.py
@@ -233,9 +233,6 @@ def __init__(
warnings.warn("rank and world_size are deprecated", DeprecationWarning)
self.sampler: Optional[Sampler] = sampler

if filter is not None and self.samples > 0 or self.samples is None:
raise ValueError("`filter` is not supported with `samples`")

# Dataset with huggingface metadata
if (
dataset.schema.metadata is not None
@@ -284,6 +281,7 @@ def __iter__(self):
n=self.samples,
columns=self.columns,
batch_size=self.batch_size,
filt=self.filter,
)
else:
raw_stream = sampler(
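With the hard ValueError removed, `samples` and `filter` can now be combined on the torch dataset; a rough sketch of that usage (the path and column name are assumptions):

import lance
from lance.torch.data import LanceDataset

ds = lance.dataset("/tmp/nulls_dataset")

# The filter is now forwarded to maybe_sample via the new `filt` argument
# instead of raising ValueError at construction time.
torch_ds = LanceDataset(
    ds,
    batch_size=256,
    columns=["vectors"],
    samples=4096,
    filter="vectors is not null",
)

for batch in torch_ds:
    ...  # iterate sampled, filtered batches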
22 changes: 18 additions & 4 deletions python/python/lance/vector.py
@@ -154,17 +154,30 @@ def train_ivf_centroids_on_accelerator(

k = int(k)

logging.info("Randomly select %s centroids from %s", k, dataset)
samples = dataset.sample(k, [column], sorted=True).combine_chunks()
fsl = samples.to_batches()[0][column]
init_centroids = torch.from_numpy(np.stack(fsl.to_numpy(zero_copy_only=False)))
if dataset.schema.field(column).nullable:
filt = f"{column} is not null"
else:
filt = None

logging.info("Randomly select %s centroids from %s (filt=%s)", k, dataset, filt)

ds = TorchDataset(
dataset,
batch_size=k,
columns=[column],
samples=sample_size,
filter=filt,
)

init_centroids = next(iter(ds))
logging.info("Done sampling: centroids shape: %s", init_centroids.shape)

ds = TorchDataset(
dataset,
batch_size=20480,
columns=[column],
samples=sample_size,
filter=filt,
cache=True,
)

@@ -233,6 +246,7 @@ def compute_partitions(
batch_size=batch_size,
with_row_id=True,
columns=[column],
filter=f"{column} is not null",
)
loader = torch.utils.data.DataLoader(
torch_ds,
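The key idea in the vector.py change is to build the `is not null` filter only when the schema marks the column as nullable. A small standalone sketch of that check (dataset path and column name are assumptions):

import lance

ds = lance.dataset("/tmp/nulls_dataset")
column = "vectors"

# Mirrors the check added to train_ivf_centroids_on_accelerator: non-nullable
# columns keep filt=None so the sampling scan stays a plain scan.
filt = f"{column} is not null" if ds.schema.field(column).nullable else None
print(filt)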
64 changes: 64 additions & 0 deletions python/python/tests/test_indices.py
@@ -32,6 +32,21 @@ def rand_dataset(tmpdir, request):
return ds


@pytest.fixture
def mostly_null_dataset(tmpdir, request):
vectors = np.random.randn(NUM_ROWS, DIMENSION).astype(np.float32)
vectors.shape = -1
vectors = pa.FixedSizeListArray.from_arrays(vectors, DIMENSION)
vectors = vectors.to_pylist()
vectors = [vec if i % 10 == 0 else None for i, vec in enumerate(vectors)]
vectors = pa.array(vectors, pa.list_(pa.float32(), DIMENSION))
table = pa.Table.from_arrays([vectors], names=["vectors"])

uri = str(tmpdir / "nulls_dataset")
ds = lance.write_dataset(table, uri, max_rows_per_file=NUM_ROWS_PER_FRAGMENT)
return ds


def test_ivf_centroids(tmpdir, rand_dataset):
ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(sample_rate=16)

@@ -44,6 +59,13 @@ def test_ivf_centroids(tmpdir, rand_dataset):
assert ivf.centroids == reloaded.centroids


def test_ivf_centroids_mostly_null(mostly_null_dataset):
ivf = IndicesBuilder(mostly_null_dataset, "vectors").train_ivf(sample_rate=16)

assert ivf.distance_type == "l2"
assert len(ivf.centroids) == NUM_PARTITIONS


@pytest.mark.cuda
def test_ivf_centroids_cuda(rand_dataset):
ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(
@@ -54,6 +76,16 @@ def test_ivf_centroids_cuda(rand_dataset):
assert len(ivf.centroids) == NUM_PARTITIONS


@pytest.mark.cuda
def test_ivf_centroids_mostly_null_cuda(mostly_null_dataset):
ivf = IndicesBuilder(mostly_null_dataset, "vectors").train_ivf(
sample_rate=16, accelerator="cuda"
)

assert ivf.distance_type == "l2"
assert len(ivf.centroids) == NUM_PARTITIONS


def test_ivf_centroids_distance_type(tmpdir, rand_dataset):
def check(distance_type):
ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(
Expand Down Expand Up @@ -95,6 +127,16 @@ def test_gen_pq(tmpdir, rand_dataset, rand_ivf):
assert pq.codebook == reloaded.codebook


def test_gen_pq_mostly_null(mostly_null_dataset):
centroids = np.random.rand(DIMENSION * 100).astype(np.float32)
centroids = pa.FixedSizeListArray.from_arrays(centroids, DIMENSION)
ivf = IvfModel(centroids, "l2")

pq = IndicesBuilder(mostly_null_dataset, "vectors").train_pq(ivf, sample_rate=2)
assert pq.dimension == DIMENSION
assert pq.num_subvectors == NUM_SUBVECTORS


@pytest.mark.cuda
def test_assign_partitions(rand_dataset, rand_ivf):
builder = IndicesBuilder(rand_dataset, "vectors")
@@ -113,6 +155,28 @@ def test_assign_partitions(rand_dataset, rand_ivf):
assert len(found_row_ids) == rand_dataset.count_rows()


@pytest.mark.cuda
def test_assign_partitions_mostly_null(mostly_null_dataset):
centroids = np.random.rand(DIMENSION * 100).astype(np.float32)
centroids = pa.FixedSizeListArray.from_arrays(centroids, DIMENSION)
ivf = IvfModel(centroids, "l2")

builder = IndicesBuilder(mostly_null_dataset, "vectors")

partitions_uri = builder.assign_ivf_partitions(ivf, accelerator="cuda")

partitions = lance.dataset(partitions_uri)
found_row_ids = set()
for batch in partitions.to_batches():
row_ids = batch["row_id"]
for row_id in row_ids:
found_row_ids.add(row_id)
part_ids = batch["partition"]
for part_id in part_ids:
assert part_id.as_py() < 100
assert len(found_row_ids) == (mostly_null_dataset.count_rows() / 10)


@pytest.fixture
def rand_pq(rand_dataset, rand_ivf):
dtype = rand_dataset.schema.field("vectors").type.value_type.to_pandas_dtype()
14 changes: 11 additions & 3 deletions python/python/tests/test_vector_index.py
@@ -19,7 +19,7 @@
from lance.vector import vec_to_table # noqa: E402


def create_table(nvec=1000, ndim=128, nans=0):
def create_table(nvec=1000, ndim=128, nans=0, nullify=False):
mat = np.random.randn(nvec, ndim)
if nans > 0:
nans_mat = np.empty((nans, ndim))
@@ -37,6 +37,13 @@ def gen_str(n):
.append_column("meta", pa.array(meta))
.append_column("id", pa.array(range(nvec + nans)))
)
if nullify:
idx = tbl.schema.get_field_index("vector")
vecs = tbl[idx].to_pylist()
nullified = [vec if i % 2 == 0 else None for i, vec in enumerate(vecs)]
field = tbl.schema.field(idx)
vecs = pa.array(nullified, field.type)
tbl = tbl.set_column(idx, field, vecs)
return tbl


@@ -191,8 +198,9 @@ def test_index_with_pq_codebook(tmp_path):


@pytest.mark.cuda
def test_create_index_using_cuda(tmp_path):
tbl = create_table()
@pytest.mark.parametrize("nullify", [False, True])
def test_create_index_using_cuda(tmp_path, nullify):
tbl = create_table(nullify=nullify)
dataset = lance.write_dataset(tbl, tmp_path)
dataset = dataset.create_index(
"vector",
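Putting it together, the parametrized test above corresponds to roughly the following end-to-end usage: write a dataset whose vector column contains nulls, then build an IVF_PQ index with the CUDA accelerator. Sizes, the path, and the index parameters below are illustrative assumptions:

import numpy as np
import pyarrow as pa

import lance

dim, nvec = 8, 1000
# Null out every other vector, similar to create_table(nullify=True).
vecs = [
    list(np.random.randn(dim).astype(np.float32)) if i % 2 == 0 else None
    for i in range(nvec)
]
tbl = pa.table({"vector": pa.array(vecs, pa.list_(pa.float32(), dim))})

ds = lance.write_dataset(tbl, "/tmp/null_vectors_demo")
ds.create_index(
    "vector",
    index_type="IVF_PQ",
    num_partitions=4,
    num_sub_vectors=2,
    accelerator="cuda",  # requires torch and a CUDA-capable GPU
)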
18 changes: 10 additions & 8 deletions python/python/tests/torch_tests/test_data.py
@@ -134,14 +134,16 @@ def check(dataset):
)
)

# sampling fails
with pytest.raises(ValueError):
LanceDataset(
ds,
batch_size=10,
filter="ids >= 300",
samples=100,
columns=["ids"],
# sampling with filter
with pytest.raises(NotImplementedError):
check(
LanceDataset(
ds,
batch_size=10,
filter="ids >= 300",
samples=100,
columns=["ids"],
)
)


