torch.distributed kmeans #3876

Closed
wants to merge 3 commits
63 changes: 46 additions & 17 deletions contrib/clustering.py
@@ -151,14 +151,12 @@ def assign_to(self, centroids, weights=None):

         I = I.ravel()
         D = D.ravel()
-        n = len(self.x)
+        nc, d = centroids.shape
+        sum_per_centroid = np.zeros((nc, d), dtype='float32')
         if weights is None:
-            weights = np.ones(n, dtype='float32')
-        nc = len(centroids)
-        m = scipy.sparse.csc_matrix(
-            (weights, I, np.arange(n + 1)),
-            shape=(nc, n))
-        sum_per_centroid = m * self.x
+            np.add.at(sum_per_centroid, I, self.x)
+        else:
+            np.add.at(sum_per_centroid, I, weights[:, np.newaxis] * self.x)

         return I, D, sum_per_centroid
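The rewritten `assign_to` computes per-centroid sums with an unbuffered scatter-add instead of a sparse matrix product. A standalone sketch of the equivalence, with illustrative shapes and data; `np.add.at` is used rather than `sum_per_centroid[I] += x` because fancy-indexed in-place addition drops repeated indices, while the unbuffered variant accumulates every row:

```python
# Illustrative equivalence check (not part of the PR): the scatter-add
# form computes the same per-centroid sums as the old scipy.sparse form.
import numpy as np
import scipy.sparse

n, nc, d = 1000, 10, 16
rng = np.random.default_rng(0)
x = rng.standard_normal((n, d)).astype('float32')
I = rng.integers(0, nc, size=n)   # centroid assignment of each point

# old formulation: one-hot (nc, n) sparse matrix times the data
m = scipy.sparse.csc_matrix(
    (np.ones(n, dtype='float32'), I, np.arange(n + 1)), shape=(nc, n))
ref = m * x

# new formulation: unbuffered scatter-add, no scipy dependency
new = np.zeros((nc, d), dtype='float32')
np.add.at(new, I, x)

assert np.allclose(ref, new, atol=1e-4)
```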

@@ -185,7 +183,8 @@ def perform_search(self, centroids):

 def sparse_assign_to_dense(xq, xb, xq_norms=None, xb_norms=None):
     """ assignment function for xq is sparse, xb is dense
-    uses a matrix multiplication. The squared norms can be provided if available.
+    uses a matrix multiplication. The squared norms can be provided if
+    available.
     """
     nq = xq.shape[0]
     nb = xb.shape[0]
@@ -272,6 +271,7 @@ def assign_to(self, centroids, weights=None):
         if weights is None:
             weights = np.ones(n, dtype='float32')
         nc = len(centroids)
+
         m = scipy.sparse.csc_matrix(
             (weights, I, np.arange(n + 1)),
             shape=(nc, n))
@@ -285,25 +285,40 @@ def imbalance_factor(k, assign):
     return faiss.imbalance_factor(len(assign), k, faiss.swig_ptr(assign))


+def check_if_torch(x):
+    if x.__class__ == np.ndarray:
+        return False
+    import torch
+    if isinstance(x, torch.Tensor):
+        return True
+    raise NotImplementedError(f"Unknown tensor type {type(x)}")
+
+
 def reassign_centroids(hassign, centroids, rs=None):
     """ reassign centroids when some of them collapse """
     if rs is None:
         rs = np.random
     k, d = centroids.shape
     nsplit = 0
+    is_torch = check_if_torch(centroids)

     empty_cents = np.where(hassign == 0)[0]

-    if empty_cents.size == 0:
+    if len(empty_cents) == 0:
         return 0

-    fac = np.ones(d)
+    if is_torch:
+        import torch
+        fac = torch.ones_like(centroids[0])
+    else:
+        fac = np.ones_like(centroids[0])
     fac[::2] += 1 / 1024.
     fac[1::2] -= 1 / 1024.

     # this is a single pass unless there are more than k/2
     # empty centroids
-    while empty_cents.size > 0:
-        # choose which centroids to split
+    while len(empty_cents) > 0:
+        # choose which centroids to split (numpy)
         probas = hassign.astype('float') - 1
         probas[probas < 0] = 0
         probas /= probas.sum()
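The `fac` vector above nudges alternate dimensions up and down by 1/1024 so a split centroid yields two nearby but distinct copies. A minimal sketch of the idea; the complementary `2 - fac` pairing is an assumption for illustration, since the splitting loop itself is collapsed in this diff:

```python
# Sketch (assumption: the collapsed loop applies fac and 2 - fac to the
# two copies of a split centroid).
import numpy as np

c = np.random.rand(8).astype('float32') + 0.5   # some centroid
fac = np.ones_like(c)
fac[::2] += 1 / 1024.
fac[1::2] -= 1 / 1024.
c1 = c * fac           # copy nudged up on even dims, down on odd dims
c2 = c * (2 - fac)     # complementary copy
assert not np.allclose(c1, c2)
assert np.allclose((c1 + c2) / 2, c)  # the pair averages back to c
```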
@@ -327,13 +342,17 @@ def reassign_centroids(hassign, centroids, rs=None):
     return nsplit


-
 def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
            return_stats=False):
     """Pure python kmeans implementation. Follows the Faiss C++ version
     quite closely, but takes a DatasetAssign instead of a training data
-    matrix. Also redo is not implemented. """
-    n, d = data.count(), data.dim()
+    matrix. Also redo is not implemented.
+
+    For the torch implementation, the centroids are tensors (possibly on GPU),
+    but the indices remain numpy on CPU.
+    """
+    n, d = data.count(), data.dim()
     log = print if verbose else print_nop

     log(("Clustering %d points in %dD to %d clusters, " +
@@ -345,6 +364,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,

     # initialization
     perm = rs.choice(n, size=k, replace=False)
     centroids = data.get_subset(perm)
+    is_torch = check_if_torch(centroids)

     iteration_stats = []
@@ -362,12 +382,17 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
         t_search_tot += time.time() - t0s;

         err = D.sum()
+        if is_torch:
+            err = err.item()
         obj.append(err)

         hassign = np.bincount(assign, minlength=k)

         fac = hassign.reshape(-1, 1).astype('float32')
-        fac[fac == 0] = 1 # quiet warning
+        fac[fac == 0] = 1  # quiet warning
+        if is_torch:
+            import torch
+            fac = torch.from_numpy(fac).to(sums.device)

         centroids = sums / fac

@@ -377,7 +402,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
             "obj": err,
             "time": (time.time() - t0),
             "time_search": t_search_tot,
-            "imbalance_factor": imbalance_factor (k, assign),
+            "imbalance_factor": imbalance_factor(k, assign),
             "nsplit": nsplit
         }

@@ -391,7 +416,11 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,

         if checkpoint is not None:
             log('storing centroids in', checkpoint)
-            np.save(checkpoint, centroids)
+            if is_torch:
+                import torch
+                torch.save(centroids, checkpoint)
+            else:
+                np.save(checkpoint, centroids)

     if return_stats:
         return centroids, iteration_stats
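A hypothetical end-to-end call of this kmeans with the numpy `DatasetAssign`; shapes and parameters are illustrative:

```python
# Illustrative usage of the pure-python kmeans (numpy path).
import numpy as np
from faiss.contrib.clustering import DatasetAssign, kmeans

xt = np.random.rand(10000, 32).astype('float32')
centroids, stats = kmeans(256, DatasetAssign(xt), niter=20,
                          return_stats=True)
print(centroids.shape)     # (256, 32)
print(stats[-1]["obj"])    # final k-means objective
```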
6 changes: 6 additions & 0 deletions contrib/torch/README.md
@@ -0,0 +1,6 @@
# The Torch contrib

This contrib directory contains a few PyTorch routines that
are useful for similarity search. They do not necessarily depend on Faiss.

The code is designed to work with CPU and GPU tensors.
Empty file added contrib/torch/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions contrib/torch/clustering.py
@@ -0,0 +1,59 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This contrib module contains Pytorch code for k-means clustering
"""
import faiss
import faiss.contrib.torch_utils
import torch

# the kmeans function can produce either torch or numpy centroids
from faiss.contrib.clustering import kmeans

class DatasetAssign:
"""Wrapper for a tensor that offers a function to assign the vectors
to centroids. All other implementations offer the same interface"""

def __init__(self, x):
self.x = x

def count(self):
return self.x.shape[0]

def dim(self):
return self.x.shape[1]

def get_subset(self, indices):
return self.x[indices]

def perform_search(self, centroids):
return faiss.knn(self.x, centroids, 1)

def assign_to(self, centroids, weights=None):
D, I = self.perform_search(centroids)

I = I.ravel()
D = D.ravel()
nc, d = centroids.shape

sum_per_centroid = torch.zeros_like(centroids)
if weights is None:
sum_per_centroid.index_add_(0, I, self.x)
else:
sum_per_centroid.index_add_(0, I, self.x * weights[:, None])

# the indices are still in numpy.
return I.cpu().numpy(), D, sum_per_centroid


class DatasetAssignGPU(DatasetAssign):

def __init__(self, res, x):
DatasetAssign.__init__(self, x)
self.res = res

def perform_search(self, centroids):
return faiss.knn_gpu(self.res, self.x, centroids, 1)
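A hypothetical usage sketch, assuming this new file is importable as `faiss.contrib.torch.clustering`; the generic `kmeans` driver then returns torch centroids while keeping the assignment indices in numpy:

```python
# Illustrative usage of the torch DatasetAssign with the generic kmeans.
import torch
from faiss.contrib.clustering import kmeans
from faiss.contrib.torch.clustering import DatasetAssign

xt = torch.rand(10000, 32)
centroids, stats = kmeans(256, DatasetAssign(xt), niter=20,
                          return_stats=True)
print(type(centroids))     # <class 'torch.Tensor'>
# For GPU data: DatasetAssignGPU(faiss.StandardGpuResources(), xt.cuda())
```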
53 changes: 53 additions & 0 deletions contrib/torch/quantization.py
@@ -0,0 +1,53 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This contrib module contains Pytorch code for quantization.
"""

import numpy as np
import torch
import faiss

from faiss.contrib import torch_utils


class Quantizer:

def __init__(self, d, code_size):
self.d = d
self.code_size = code_size

def train(self, x):
pass

def encode(self, x):
pass

def decode(self, x):
pass


class VectorQuantizer(Quantizer):

def __init__(self, d, k):
        # np.ceil/np.log2 rather than torch.ceil: k is a plain int here,
        # which torch.log2 does not accept
        code_size = int(np.ceil(np.log2(k) / 8))
        Quantizer.__init__(self, d, code_size)
self.k = k

def train(self, x):
pass


class ProductQuantizer(Quantizer):

def __init__(self, d, M, nbits):
        code_size = int(np.ceil(M * nbits / 8))
        Quantizer.__init__(self, d, code_size)
self.M = M
self.nbits = nbits

def train(self, x):
pass
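A quick sanity check of the `code_size` arithmetic above, with illustrative values:

```python
# Sanity check of the code_size formulas (illustrative values).
import numpy as np

k = 1024                 # vector quantizer: log2(1024) = 10 bits -> 2 bytes
assert int(np.ceil(np.log2(k) / 8)) == 2

M, nbits = 8, 8          # product quantizer: 8 * 8 = 64 bits -> 8 bytes
assert int(np.ceil(M * nbits / 8)) == 8
```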
62 changes: 62 additions & 0 deletions contrib/torch_utils.py
@@ -28,13 +28,18 @@
 import sys
 import numpy as np

+##################################################################
+# Equivalent of swig_ptr for Torch tensors
+##################################################################
+
 def swig_ptr_from_UInt8Tensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
     assert x.dtype == torch.uint8
     return faiss.cast_integer_to_uint8_ptr(
         x.untyped_storage().data_ptr() + x.storage_offset())

+
 def swig_ptr_from_HalfTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
@@ -43,27 +48,34 @@ def swig_ptr_from_HalfTensor(x):
     return faiss.cast_integer_to_void_ptr(
         x.untyped_storage().data_ptr() + x.storage_offset() * 2)

+
 def swig_ptr_from_FloatTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
     assert x.dtype == torch.float32
     return faiss.cast_integer_to_float_ptr(
         x.untyped_storage().data_ptr() + x.storage_offset() * 4)

+
 def swig_ptr_from_IntTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
     assert x.dtype == torch.int32, 'dtype=%s' % x.dtype
     return faiss.cast_integer_to_int_ptr(
         x.untyped_storage().data_ptr() + x.storage_offset() * 4)

+
 def swig_ptr_from_IndicesTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
     assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
     return faiss.cast_integer_to_idx_t_ptr(
         x.untyped_storage().data_ptr() + x.storage_offset() * 8)

+##################################################################
+# utilities
+##################################################################
+
 @contextlib.contextmanager
 def using_stream(res, pytorch_stream=None):
     """ Creates a scoping object to make Faiss GPU use the same stream
@@ -107,6 +119,10 @@ def torch_replace_method(the_class, name, replacement,
     setattr(the_class, name + '_numpy', orig_method)
     setattr(the_class, name, replacement)

+##################################################################
+# Setup wrappers
+##################################################################
+
 def handle_torch_Index(the_class):
     def torch_replacement_add(self, x):
         if type(x) is np.ndarray:
@@ -493,6 +509,52 @@ def torch_replacement_sa_decode(self, codes, x=None):
     handle_torch_Index(the_class)


+# allows torch tensor usage with knn
+def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0):
+    if type(xb) is np.ndarray:
+        # Forward to faiss __init__.py base method
+        return faiss.knn_numpy(xq, xb, k, metric=metric, metric_arg=metric_arg)
+
+    nb, d = xb.size()
+    assert xb.is_contiguous()
+    assert xb.dtype == torch.float32
+    assert not xb.is_cuda, "use knn_gpu for GPU tensors"
+
+    nq, d2 = xq.size()
+    assert d2 == d
+    assert xq.is_contiguous()
+    assert xq.dtype == torch.float32
+    assert not xq.is_cuda, "use knn_gpu for GPU tensors"
+
+    D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
+    I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
+    I_ptr = swig_ptr_from_IndicesTensor(I)
+    D_ptr = swig_ptr_from_FloatTensor(D)
+    xb_ptr = swig_ptr_from_FloatTensor(xb)
+    xq_ptr = swig_ptr_from_FloatTensor(xq)
+
+    if metric == faiss.METRIC_L2:
+        faiss.knn_L2sqr(
+            xq_ptr, xb_ptr,
+            d, nq, nb, k, D_ptr, I_ptr
+        )
+    elif metric == faiss.METRIC_INNER_PRODUCT:
+        faiss.knn_inner_product(
+            xq_ptr, xb_ptr,
+            d, nq, nb, k, D_ptr, I_ptr
+        )
+    else:
+        faiss.knn_extra_metrics(
+            xq_ptr, xb_ptr,
+            d, nq, nb, metric, metric_arg, k, D_ptr, I_ptr
+        )
+
+    return D, I
+
+
+torch_replace_method(faiss_module, 'knn', torch_replacement_knn, True, True)
+
+
 # allows torch tensor usage with bfKnn
 def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_raft=False):
     if type(xb) is np.ndarray:
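A hypothetical usage of the CPU wrapper registered above; importing `faiss.contrib.torch_utils` patches `faiss.knn` so it accepts torch tensors directly:

```python
# Illustrative usage: after the import, faiss.knn takes CPU torch tensors
# (numpy inputs still dispatch to the original implementation).
import torch
import faiss
import faiss.contrib.torch_utils  # noqa: F401  (installs the wrappers)

xb = torch.rand(10000, 64)
xq = torch.rand(100, 64)
D, I = faiss.knn(xq, xb, 5)        # METRIC_L2 by default
print(D.shape, I.dtype)            # torch.Size([100, 5]) torch.int64
```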