kmeans.py
import faiss
from faiss import IndexFlatL2
import numpy as np
import argparse
from collections import defaultdict
from tqdm import tqdm
import scipy.sparse as sp
import pickle
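
# This script clusters the keys of a kNN-LM datastore with faiss k-means and saves
# a sparse cluster -> members matrix that maps each cluster to the datastore entries
# assigned to it (see the paper linked in the --num-clusters help text).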
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dstore-size', type=int, default=103225485)
    parser.add_argument('--dstore', type=str, help='path to the datastore')
    parser.add_argument('--num-clusters', type=int, help='typically, around 1/100 of the datastore size. See also Figures 8 and 9 in the paper: https://arxiv.org/pdf/2201.12431.pdf')
    parser.add_argument('--sample', type=int, help='the number of data samples to use as the clustering data. If possible, use the entire datastore; if not, use as large a sample as memory allows.')
    parser.add_argument('--batch-size', type=int, default=500000)
    parser.add_argument('--dim', type=int, default=1024)
    parser.add_argument('--save', type=str, help='prefix for the output files')
    args = parser.parse_args()

    # Memory-map the datastore keys (float16, shape [dstore_size, dim]) and draw
    # the sample used to train k-means.
    keys = np.memmap(args.dstore + '_keys.npy',
                     dtype=np.float16, mode='r', shape=(args.dstore_size, args.dim))

    rs = np.random.RandomState(1)
    if args.sample > args.dstore_size:
        print('Taking all data for training')
        to_cluster = keys[:]
    else:
        to_cluster = np.zeros((args.sample, args.dim), dtype=np.float16)
        idx = rs.choice(np.arange(args.dstore_size), size=args.sample, replace=False)
        to_cluster[:] = keys[idx]
    to_cluster = to_cluster.astype(np.float32)

    # Train k-means over the sampled keys (on GPU, with a fixed seed).
    print('start cluster')
    niter = 20
    verbose = True
    kmeans = faiss.Kmeans(args.dim, args.num_clusters, niter=niter, verbose=verbose, gpu=True, seed=1)
    kmeans.train(to_cluster)

    # centroids_filename = f'{args.save}_s{args.sample}_k{args.num_clusters}_centroids.npy'
    # np.save(centroids_filename, kmeans.centroids)
    # print(f'Saved centroids to {centroids_filename}')

    # Finished training the k-means clustering;
    # now we assign each data point to its closest centroid.
    print('to add:', args.dstore_size)
    print('Creating index and adding centroids')
    index = IndexFlatL2(args.dim)
    index.add(kmeans.centroids)
    print('Index created, moving index to GPU')
    co = faiss.GpuClonerOptions()
    co.useFloat16 = True
    index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, index, co)
    print('Moved index to GPU')

    # Assign every datastore key to its nearest centroid, in batches.
    start = 0
    # dists = []
    centroid_ids = []
    print('Starting to add tokens')
    while start < args.dstore_size:
        end = min(args.dstore_size, start + args.batch_size)
        to_search = keys[start:end].copy()
        d, key_i = index.search(to_search.astype(np.float32), 1)
        # dists.append(d.squeeze())
        centroid_ids.append(key_i.squeeze())
        start += args.batch_size
        if (start % 1000000) == 0:
            print('Assigned %d tokens so far' % start)

    centroid_ids = np.concatenate(centroid_ids)
    # centroid_ids_filename = f'{args.centroids}_centroid_ids.npy'
    # np.save(centroid_ids_filename, centroid_ids)

    # Build the cluster -> members mapping and store it as a sparse binary matrix
    # (rows are clusters, columns are datastore indices).
    print('Processing the mapping of cluster->members')
    parent_cluster = centroid_ids
    cluster_to_members = defaultdict(set)
    for key_i, cluster in tqdm(enumerate(parent_cluster), total=args.dstore_size):
        cluster_to_members[cluster.item()].add(key_i)

    row_ind = [k for k, v in cluster_to_members.items() for _ in range(len(v))]
    col_ind = [i for ids in cluster_to_members.values() for i in ids]
    members_sp = sp.csr_matrix(([1] * len(row_ind), (row_ind, col_ind)))

    members_filename = f'{args.save}_s{args.sample}_k{args.num_clusters}_members.pkl'
    with open(members_filename, 'wb') as f:
        pickle.dump(members_sp, f)
    print(f'Done, found {len(cluster_to_members)} clusters, written to {members_filename}')
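
# A hypothetical usage sketch; the paths and values below are illustrative only and
# are not taken from the repository:
#
#   python kmeans.py \
#       --dstore checkpoints/dstore \
#       --dstore-size 103225485 \
#       --num-clusters 1000000 \
#       --sample 50000000 \
#       --dim 1024 \
#       --save checkpoints/clusters
#
# The saved pickle can then be loaded back as a scipy CSR matrix, e.g.:
#
#   import pickle
#   with open('checkpoints/clusters_s50000000_k1000000_members.pkl', 'rb') as f:
#       members_sp = pickle.load(f)
#   members_of_cluster_c = members_sp[c].indices  # datastore indices assigned to cluster c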