forked from Enet4/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
torch.distributed kmeans (facebookresearch#3876)
Summary: Pull Request resolved: facebookresearch#3876 Demo script for distributed kmeans. It provides a `DatasetAssign` object and shows how to run it with torch.distributed. Reviewed By: asadoughi, pankajsingh88 Differential Revision: D63013820 fbshipit-source-id: 22c959f3afdc04fd4aa8b9aeed309ea6290b1328
- Loading branch information
1 parent
8b0de63
commit b6d0237
Showing
9 changed files
with
217 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import numpy as np | ||
|
||
import torch | ||
import torch.distributed | ||
|
||
import faiss | ||
|
||
import faiss.contrib.torch_utils | ||
from faiss.contrib.torch import clustering | ||
from faiss.contrib import datasets | ||
|
||
|
||
class DatasetAssignDistributedGPU(clustering.DatasetAssign): | ||
""" | ||
There is one instance per worker, each worker has a dataset shard. | ||
The non-master workers do not run through the k-means function, so some | ||
code has run it to keep the workers in sync. | ||
""" | ||
|
||
def __init__(self, res, x, rank, nproc): | ||
clustering.DatasetAssign.__init__(self, x) | ||
self.res = res | ||
self.rank = rank | ||
self.nproc = nproc | ||
self.device = x.device | ||
|
||
n = len(x) | ||
sizes = torch.zeros(nproc, device=self.device, dtype=torch.int64) | ||
sizes[rank] = n | ||
torch.distributed.all_gather( | ||
[sizes[i:i + 1] for i in range(nproc)], sizes[rank:rank + 1]) | ||
self.sizes = sizes.cpu().numpy() | ||
|
||
# begin & end of each shard | ||
self.cs = np.zeros(nproc + 1, dtype='int64') | ||
self.cs[1:] = np.cumsum(self.sizes) | ||
|
||
def count(self): | ||
return int(self.sizes.sum()) | ||
|
||
def int_to_slaves(self, i): | ||
" broadcast an int to all workers " | ||
rank = self.rank | ||
tab = torch.zeros(1, device=self.device, dtype=torch.int64) | ||
if rank == 0: | ||
tab[0] = i | ||
else: | ||
assert i is None | ||
torch.distributed.broadcast(tab, 0) | ||
return tab.item() | ||
|
||
def get_subset(self, indices): | ||
rank = self.rank | ||
assert rank == 0 or indices is None | ||
|
||
len_indices = self.int_to_slaves(len(indices) if rank == 0 else None) | ||
|
||
if rank == 0: | ||
indices = torch.from_numpy(indices).to(self.device) | ||
else: | ||
indices = torch.zeros( | ||
len_indices, dtype=torch.int64, device=self.device) | ||
torch.distributed.broadcast(indices, 0) | ||
|
||
# select subset of indices | ||
|
||
i0, i1 = self.cs[rank], self.cs[rank + 1] | ||
|
||
mask = torch.logical_and(indices < i1, indices >= i0) | ||
output = torch.zeros( | ||
len_indices, self.x.shape[1], | ||
dtype=self.x.dtype, device=self.device) | ||
output[mask] = self.x[indices[mask] - i0] | ||
torch.distributed.reduce(output, 0) # sum | ||
if rank == 0: | ||
return output | ||
else: | ||
return None | ||
|
||
def perform_search(self, centroids): | ||
assert False, "shoudl not be called" | ||
|
||
def assign_to(self, centroids, weights=None): | ||
assert weights is None | ||
|
||
rank, nproc = self.rank, self.nproc | ||
assert rank == 0 or centroids is None | ||
nc = self.int_to_slaves(len(centroids) if rank == 0 else None) | ||
|
||
if rank != 0: | ||
centroids = torch.zeros( | ||
nc, self.x.shape[1], dtype=self.x.dtype, device=self.device) | ||
torch.distributed.broadcast(centroids, 0) | ||
|
||
# perform search | ||
D, I = faiss.knn_gpu( | ||
self.res, self.x, centroids, 1, device=self.device.index) | ||
|
||
I = I.ravel() | ||
D = D.ravel() | ||
|
||
sum_per_centroid = torch.zeros_like(centroids) | ||
if weights is None: | ||
sum_per_centroid.index_add_(0, I, self.x) | ||
else: | ||
sum_per_centroid.index_add_(0, I, self.x * weights[:, None]) | ||
|
||
torch.distributed.reduce(sum_per_centroid, 0) | ||
|
||
if rank == 0: | ||
# gather deos not support tensors of different sizes | ||
# should be implemented with point-to-point communication | ||
assert np.all(self.sizes == self.sizes[0]) | ||
device = self.device | ||
all_I = torch.zeros(self.count(), dtype=I.dtype, device=device) | ||
all_D = torch.zeros(self.count(), dtype=D.dtype, device=device) | ||
torch.distributed.gather( | ||
I, [all_I[self.cs[r]:self.cs[r + 1]] for r in range(nproc)], | ||
dst=0, | ||
) | ||
torch.distributed.gather( | ||
D, [all_D[self.cs[r]:self.cs[r + 1]] for r in range(nproc)], | ||
dst=0, | ||
) | ||
return all_I.cpu().numpy(), all_D, sum_per_centroid | ||
else: | ||
torch.distributed.gather(I, None, dst=0) | ||
torch.distributed.gather(D, None, dst=0) | ||
return None | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
torch.distributed.init_process_group( | ||
backend="nccl", | ||
) | ||
rank = torch.distributed.get_rank() | ||
nproc = torch.distributed.get_world_size() | ||
|
||
# current version does only support shards of the same size | ||
ds = datasets.SyntheticDataset(32, 10000, 0, 0, seed=1234 + rank) | ||
x = ds.get_train() | ||
|
||
device = torch.device(f"cuda:{rank}") | ||
|
||
torch.cuda.set_device(device) | ||
x = torch.from_numpy(x).to(device) | ||
res = faiss.StandardGpuResources() | ||
|
||
da = DatasetAssignDistributedGPU(res, x, rank, nproc) | ||
|
||
k = 1000 | ||
niter = 25 | ||
|
||
if rank == 0: | ||
print(f"sizes = {da.sizes}") | ||
centroids, iteration_stats = clustering.kmeans( | ||
k, da, niter=niter, return_stats=True) | ||
print("clusters:", centroids.cpu().numpy()) | ||
else: | ||
# make sure the iterations are aligned with master | ||
da.get_subset(None) | ||
|
||
for _ in range(niter): | ||
da.assign_to(None) | ||
|
||
torch.distributed.barrier() | ||
print("Done") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.