Skip to content

Commit

Permalink
feat: cuvs acceleration for gpu k-means (#2816)
Browse files Browse the repository at this point in the history
We currently have a pytorch-based k-means implementation for computing
IVF centroids. This PR accelerates it with cuVS.
The tradeoff is that each iteration is faster but yields less score
improvement per iteration.

By default, this is off, since it's primarily useful for very large
datasets where large centroid counts are applicable.

Benchmarking (classic k-means scoring):
- k=16384 clusters
- text2image-10M base set (10M image embeddings, 200 dimensions,
float32, cosine distance)

Results: Slightly better score @ ~1.5x faster. Speedup gets better with
more centroids.

<details>

<summary>Easy test script & outputs</summary>

```py
import numpy as np
from lance.cuvs.kmeans import KMeans as KMeansVS
from lance.torch.kmeans import KMeans
import lance
import time

# Number of clusters; matches the k=16384 benchmark setting described above.
# (Previously referenced but never defined, which raised NameError.)
CLUSTERS = 16384

# Note: This kind of approach performs quite poorly on random data (see https://arxiv.org/abs/2405.18680), so it's only worth testing on a real dataset
ds = lance.dataset("path/to/text2image-dataset") # can also use other medium~large datasets
data = np.stack(ds.to_table()["vector"].to_numpy())
max_iters_base = 10
max_iters_cuvs = 12 # iters using cuvs are much faster, but slightly less precise
metric = "cosine"

# Time the cuVS-accelerated k-means (construction + fit).
cuvs_start_time = time.time()
kmeans_cuvs = KMeansVS(
    CLUSTERS,
    metric=metric,
    max_iters=max_iters_cuvs,
    seed=0,
)
kmeans_cuvs.fit(data)
cuvs_end_time = time.time()

# Time the baseline PyTorch k-means with the same k, metric and seed.
base_start_time = time.time()
kmeans = KMeans(
    CLUSTERS,
    metric=metric,
    max_iters=max_iters_base,
    seed=0,
)
kmeans.fit(data)
base_end_time = time.time()
# A positive difference means the cuVS run ended with a lower total distance.
print(f"score after {max_iters_cuvs} iters of kmeans_cuvs better than {max_iters_base} iters of kmeans by {kmeans.total_distance - kmeans_cuvs.total_distance}")
base_time = base_end_time-base_start_time
cuvs_time = cuvs_end_time-cuvs_start_time
print(f"time to run kmeans: {base_time}s. time to run kmeans_cuvs: {cuvs_time} (speedup: {base_time/cuvs_time}x)")
```
Output:
```
score after 12 iters of kmeans_cuvs better than 10 iters of kmeans by 5905.7138671875
time to run kmeans: 86.69116258621216s. time to run kmeans_cuvs: 56.66267776489258 (speedup: 1.5299517425899842x)
```
</details>

Additionally, a new "accelerator" choice has been added: "cuvs". This
requires one of the added optional dependencies (cuvs-py3X, X in
{9,10,11}). This can replace the two routines for which we already have
cuda acceleration: IVF model training (Lloyd's algorithm) and IVF
assignments. At sufficiently large centroid counts, this can
significantly accelerate these steps, resulting in better e2e time. See
below:


![results_static_20240906_132311_plot_dataset_sift1m_k_10](https://github.com/user-attachments/assets/579016ce-83c7-4fb0-a3da-a12609c76f02)
Although these plots are near-identical, the "cuvs" accelerated
variation took ~18.1s to build e2e, while the "cuda" accelerated
variation took ~24.4s.

This speedup persists on larger datasets. However, I was mistaken about
where the time goes: PQ assignments become the bigger bottleneck as the
dataset gets larger (thanks to some improvements I did not see), so the
step accelerated here is not the bottleneck step. The next step after
this PR will be to accelerate PQ with both cuda and cuvs.
  • Loading branch information
jacketsj authored Sep 23, 2024
1 parent 5606b17 commit ea78168
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 12 deletions.
14 changes: 14 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ tests = [
dev = ["ruff==0.4.1"]
benchmarks = ["pytest-benchmark"]
torch = ["torch"]
cuvs-py39 = [
"cuvs-cu12 @ https://pypi.nvidia.com/cuvs-cu12/cuvs_cu12-24.8.0-cp39-cp39-manylinux_2_28_x86_64.whl",
"pylibraft-cu12 @ https://pypi.nvidia.com/pylibraft-cu12/pylibraft_cu12-24.8.1-cp39-cp39-manylinux_2_28_x86_64.whl"
]

cuvs-py310 = [
"cuvs-cu12 @ https://pypi.nvidia.com/cuvs-cu12/cuvs_cu12-24.8.0-cp310-cp310-manylinux_2_28_x86_64.whl",
"pylibraft-cu12 @ https://pypi.nvidia.com/pylibraft-cu12/pylibraft_cu12-24.8.1-cp310-cp310-manylinux_2_28_x86_64.whl"
]

cuvs-py311 = [
"cuvs-cu12 @ https://pypi.nvidia.com/cuvs-cu12/cuvs_cu12-24.8.0-cp311-cp311-manylinux_2_28_x86_64.whl",
"pylibraft-cu12 @ https://pypi.nvidia.com/pylibraft-cu12/pylibraft_cu12-24.8.1-cp311-cp311-manylinux_2_28_x86_64.whl"
]
ray = ["ray[data]; python_version<'3.12'"]

[tool.ruff]
Expand Down
2 changes: 2 additions & 0 deletions python/python/lance/cuvs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
143 changes: 143 additions & 0 deletions python/python/lance/cuvs/kmeans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors


import logging
import time
from typing import Literal, Optional, Tuple, Union

import pyarrow as pa

from lance.dependencies import cagra, raft_common, torch
from lance.dependencies import numpy as np
from lance.torch.kmeans import KMeans as KMeansTorch

__all__ = ["KMeans"]


class KMeans(KMeansTorch):
    """K-Means trains over vectors and divides them into K clusters,
    using cuVS as the accelerator.

    This implementation is built on PyTorch+cuVS, supporting Nvidia GPU only.

    Parameters
    ----------
    k: int
        The number of clusters
    metric : str
        Metric type. "l2", "euclidean" and "cosine" are accepted; "dot" is
        rejected with a ``ValueError`` because it is incompatible with cuVS.
    init: str
        Initialization method. Only support "random" now.
    max_iters: int
        Max number of iterations to train the kmean model.
    tolerance: float
        Relative tolerance in regard to Frobenius norm of the difference in
        the cluster centers of two consecutive iterations to declare convergence.
    centroids : torch.Tensor, optional.
        Provide existing centroids.
    seed: int, optional
        Random seed
    device: str, optional
        The device to run the PyTorch algorithms. Default we will pick
        the most performant device on the host. See `lance.torch.preferred_device()`
        For the cuVS implementation, it will be verified this is a cuda device.
    itopk_size: int, default 10
        Passed through as ``itopk_size`` to ``cagra.SearchParams`` when
        searching the centroid index for each vector's nearest centroid.

    Raises
    ------
    ValueError
        If ``metric == "dot"``, or if the resolved device is not a CUDA device
        or CUDA is not available.
    """

    def __init__(
        self,
        k: int,
        *,
        metric: Literal["l2", "euclidean", "cosine", "dot"] = "l2",
        init: Literal["random"] = "random",
        max_iters: int = 50,
        tolerance: float = 1e-4,
        centroids: Optional[torch.Tensor] = None,
        seed: Optional[int] = None,
        device: Optional[str] = None,
        itopk_size: int = 10,
    ):
        # Reject the unsupported metric before any base-class setup happens.
        if metric == "dot":
            raise ValueError(
                'Kmeans::__init__: metric == "dot" is incompatible' " with cuVS"
            )
        super().__init__(
            k,
            metric=metric,
            init=init,
            max_iters=max_iters,
            tolerance=tolerance,
            centroids=centroids,
            seed=seed,
            device=device,
        )

        # The base class resolves ``self.device``; cuVS needs it to be CUDA.
        if self.device.type != "cuda" or not torch.cuda.is_available():
            raise ValueError("KMeans::__init__: cuda is not enabled/available")

        self.itopk_size = itopk_size
        # Accumulated seconds spent building the index / searching it.
        self.time_rebuild = 0.0
        self.time_search = 0.0

    def fit(
        self,
        data: Union[
            torch.utils.data.IterableDataset,
            np.ndarray,
            torch.Tensor,
            pa.FixedSizeListArray,
        ],
    ) -> None:
        """Fit the model via the base-class ``fit`` and log time spent in cuVS.

        The rebuild/search timers are reset first so the logged totals cover
        only this call.
        """
        self.time_rebuild = 0.0
        self.time_search = 0.0
        super().fit(data)
        logging.info("Total search time: %s", self.time_search)
        logging.info("Total rebuild time: %s", self.time_rebuild)

    def rebuild_index(self):
        """Build a CAGRA index over the current centroids.

        Overrides the base-class hook; the result is stored in ``self.index``
        and the elapsed time is added to ``self.time_rebuild``.
        """
        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments.
        rebuild_time_start = time.perf_counter()
        cagra_metric = "sqeuclidean"
        dim = self.centroids.shape[1]
        graph_degree = max(dim // 4, 32)
        nn_descent_degree = graph_degree * 2
        index_params = cagra.IndexParams(
            metric=cagra_metric,
            intermediate_graph_degree=nn_descent_degree,
            graph_degree=graph_degree,
            build_algo="nn_descent",
            compression=None,
        )
        self.index = cagra.build(index_params, self.centroids)
        self.time_rebuild += time.perf_counter() - rebuild_time_start

        # This class's ``_transform`` never reads ``y2``, so clear it.
        self.y2 = None

    def _transform(
        self,
        data: torch.Tensor,
        y2: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Assign each row of ``data`` to its nearest centroid via CAGRA search.

        Parameters
        ----------
        data: torch.Tensor
            Vectors to assign; L2-normalized first when the metric is "cosine".
        y2: torch.Tensor, optional
            Unused; kept for signature compatibility with the base class.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            The nearest-centroid ids (the uint32 search output with its
            size-1 top-k axis squeezed out, reinterpreted as int32) and the
            search distances (still with a trailing size-1 axis), both on
            ``self.device``.
        """
        if self.metric == "cosine":
            data = torch.nn.functional.normalize(data)

        search_time_start = time.perf_counter()
        n_rows = data.shape[0]
        # Output buffers for a top-1 search: one neighbor per input row.
        out_idx = raft_common.device_ndarray.empty((n_rows, 1), dtype="uint32")
        out_dist = raft_common.device_ndarray.empty((n_rows, 1), dtype="float32")
        search_params = cagra.SearchParams(itopk_size=self.itopk_size)
        cagra.search(
            search_params,
            self.index,
            data,
            1,
            neighbors=out_idx,
            distances=out_dist,
        )
        # Use the model's own (CUDA-verified) device rather than a bare
        # "cuda", so results stay on the same GPU as the base class's tensors
        # even when an explicit index such as "cuda:1" was requested.
        ret = (
            torch.as_tensor(out_idx, device=self.device)
            .squeeze(dim=1)
            .view(torch.int32),
            torch.as_tensor(out_dist, device=self.device),
        )
        self.time_search += time.perf_counter() - search_time_start
        return ret
10 changes: 10 additions & 0 deletions python/python/lance/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
_PANDAS_AVAILABLE = True
_POLARS_AVAILABLE = True
_TORCH_AVAILABLE = True
_CAGRA_AVAILABLE = True
_RAFT_COMMON_AVAILABLE = True
_HUGGING_FACE_AVAILABLE = True
_TENSORFLOW_AVAILABLE = True
_RAY_AVAILABLE = True
Expand All @@ -48,6 +50,8 @@ class _LazyModule(ModuleType):
"pandas": "pd.",
"polars": "pl.",
"torch": "torch.",
"cagra": "cagra.",
"common": "raft_common.",
"tensorflow": "tf.",
"ray": "ray.",
}
Expand Down Expand Up @@ -172,6 +176,8 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
pandas, _PANDAS_AVAILABLE = _lazy_import("pandas")
polars, _POLARS_AVAILABLE = _lazy_import("polars")
torch, _TORCH_AVAILABLE = _lazy_import("torch")
cagra, _CAGRA_AVAILABLE = _lazy_import("cuvs.neighbors.cagra")
raft_common, _RAFT_COMMON_AVAILABLE = _lazy_import("pylibraft.common")
datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets")
tensorflow, _TENSORFLOW_AVAILABLE = _lazy_import("tensorflow")
ray, _RAY_AVAILABLE = _lazy_import("ray")
Expand Down Expand Up @@ -238,6 +244,8 @@ def _check_for_ray(obj: Any, *, check_type: bool = True) -> bool:
"ray",
"tensorflow",
"torch",
"cagra",
"raft_common",
# lazy utilities
"_check_for_hugging_face",
"_check_for_numpy",
Expand All @@ -252,6 +260,8 @@ def _check_for_ray(obj: Any, *, check_type: bool = True) -> bool:
"_PANDAS_AVAILABLE",
"_POLARS_AVAILABLE",
"_TORCH_AVAILABLE",
"_CAGRA_AVAILABLE",
"_RAFT_COMMON_AVAILABLE",
"_HUGGING_FACE_AVAILABLE",
"_TENSORFLOW_AVAILABLE",
"_RAY_AVAILABLE",
Expand Down
15 changes: 12 additions & 3 deletions python/python/lance/torch/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import pyarrow as pa
from tqdm import tqdm

from lance.dependencies import _check_for_numpy, _check_for_torch, torch
from lance.dependencies import (
_check_for_numpy,
_check_for_torch,
torch,
)
from lance.dependencies import numpy as np

from . import preferred_device
Expand Down Expand Up @@ -79,6 +83,8 @@ def __init__(
self.tolerance = tolerance
self.seed = seed

self.y2 = None

def __repr__(self):
return f"KMeans(k={self.k}, metric={self.metric}, device={self.device})"

Expand Down Expand Up @@ -220,14 +226,14 @@ def _fit_once(
)
counts_per_part = torch.zeros(self.centroids.shape[0], device=self.device)
ones = torch.ones(1024 * 16, device=self.device)
y2 = (self.centroids * self.centroids).sum(dim=1)
self.rebuild_index()
for idx, chunk in enumerate(data):
if idx % 50 == 0:
logging.info("Kmeans::train: epoch %s, chunk %s", epoch, idx)
chunk: torch.Tensor = chunk
dtype = chunk.dtype
chunk = chunk.to(self.device)
ids, dists = self._transform(chunk, y2=y2)
ids, dists = self._transform(chunk, y2=self.y2)

valid_mask = ids >= 0
if torch.any(~valid_mask):
Expand Down Expand Up @@ -263,6 +269,9 @@ def _fit_once(
)
return total_dist

def rebuild_index(self):
    """Refresh ``self.y2`` with the row-wise sum of squared centroid values."""
    centroids = self.centroids
    self.y2 = (centroids * centroids).sum(dim=1)

def _transform(
self,
data: torch.Tensor,
Expand Down
35 changes: 26 additions & 9 deletions python/python/lance/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ def train_ivf_centroids_on_accelerator(
) -> (np.ndarray, str):
"""Use accelerator (GPU or MPS) to train kmeans."""
if isinstance(accelerator, str) and (
not (CUDA_REGEX.match(accelerator) or accelerator == "mps")
not (
CUDA_REGEX.match(accelerator)
or accelerator == "mps"
or accelerator == "cuvs"
)
):
raise ValueError(
"Train ivf centroids on accelerator: "
Expand Down Expand Up @@ -181,14 +185,27 @@ def train_ivf_centroids_on_accelerator(
cache=True,
)

logging.info("Training IVF partitions using GPU(%s)", accelerator)
kmeans = KMeans(
k,
max_iters=max_iters,
metric=metric_type,
device=accelerator,
centroids=init_centroids,
)
if accelerator == "cuvs":
logging.info("Training IVF partitions using cuVS+GPU")
print("Training IVF partitions using cuVS+GPU")
from lance.cuvs.kmeans import KMeans as KMeansCuVS

kmeans = KMeansCuVS(
k,
max_iters=max_iters,
metric=metric_type,
device="cuda",
centroids=init_centroids,
)
else:
logging.info("Training IVF partitions using GPU(%s)", accelerator)
kmeans = KMeans(
k,
max_iters=max_iters,
metric=metric_type,
device=accelerator,
centroids=init_centroids,
)
kmeans.fit(ds)

centroids = kmeans.centroids.cpu().numpy()
Expand Down

0 comments on commit ea78168

Please sign in to comment.