-
Notifications
You must be signed in to change notification settings - Fork 3
/
faiss_utils.py
108 lines (91 loc) · 3.11 KB
/
faiss_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import numpy as np
import faiss
import torch
def swig_ptr_from_FloatTensor(x):
    """Return a Faiss SWIG ``float*`` pointing at the data of a PyTorch tensor.

    Works for CPU and GPU tensors. The pointer aliases the tensor's memory
    (nothing is copied), so ``x`` must stay alive while the pointer is used.

    Args:
        x: contiguous ``torch.float32`` tensor.

    Returns:
        SWIG float pointer to the first element of ``x``.

    Raises:
        AssertionError: if ``x`` is not contiguous or not float32.
    """
    assert x.is_contiguous(), 'tensor must be contiguous'
    assert x.dtype == torch.float32, 'dtype=%s' % x.dtype
    # Tensor.data_ptr() is the address of the first element, storage offset
    # included; this avoids the deprecated TypedStorage that Tensor.storage()
    # returns on recent PyTorch, and the hand-computed "offset * 4".
    return faiss.cast_integer_to_float_ptr(x.data_ptr())
def swig_ptr_from_LongTensor(x):
    """Return a Faiss SWIG 64-bit integer pointer to a PyTorch tensor's data.

    Works for CPU and GPU tensors. The pointer aliases the tensor's memory
    (nothing is copied), so ``x`` must stay alive while the pointer is used.

    Args:
        x: contiguous ``torch.int64`` tensor.

    Returns:
        SWIG long pointer to the first element of ``x``.

    Raises:
        AssertionError: if ``x`` is not contiguous or not int64.
    """
    assert x.is_contiguous(), 'tensor must be contiguous'
    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
    # Tensor.data_ptr() is the address of the first element, storage offset
    # included; this avoids the deprecated TypedStorage that Tensor.storage()
    # returns on recent PyTorch, and the hand-computed "offset * 8".
    return faiss.cast_integer_to_long_ptr(x.data_ptr())
def search_index_pytorch(index, x, k, D=None, I=None):
    """Call ``index.search_c`` with PyTorch tensors as input and output.

    CPU and GPU tensors are both supported. The tensors' memory is handed
    to Faiss directly, so no copies are made.

    Args:
        index: Faiss index exposing ``d`` and ``search_c``.
        x: contiguous float32 query matrix of shape (n, index.d).
        k: number of nearest neighbours to retrieve per query.
        D: optional preallocated float32 (n, k) tensor for the distances.
            When omitted it is allocated on ``x.device``.
        I: optional preallocated int64 (n, k) tensor for the ids.
            When omitted it is allocated on ``x.device``.

    Returns:
        Tuple ``(D, I)`` filled in by the search.

    Raises:
        AssertionError: on layout, dimensionality, shape or dtype mismatch.
    """
    assert x.is_contiguous()
    n, d = x.size()
    assert d == index.d

    if D is None:
        D = torch.empty((n, k), dtype=torch.float32, device=x.device)
    else:
        assert D.size() == (n, k)

    if I is None:
        I = torch.empty((n, k), dtype=torch.int64, device=x.device)
    else:
        assert I.size() == (n, k)

    # torch.cuda.synchronize() raises on installs/machines without CUDA, which
    # made the advertised CPU path unusable there. Only synchronise when CUDA
    # memory is involved. The barriers order PyTorch's pending work on x/D/I
    # with the search, which may run on a different CUDA stream.
    on_cuda = x.is_cuda or D.is_cuda or I.is_cuda
    if on_cuda:
        torch.cuda.synchronize()

    xptr = swig_ptr_from_FloatTensor(x)
    Iptr = swig_ptr_from_LongTensor(I)
    Dptr = swig_ptr_from_FloatTensor(D)
    index.search_c(n, xptr, k, Dptr, Iptr)

    if on_cuda:
        torch.cuda.synchronize()
    return D, I
def _as_row_or_col_major(mat):
    """Return ``(view, row_major)`` describing the memory layout of ``mat``.

    A C-contiguous matrix is returned as-is with ``row_major=True``. If only
    its transpose is contiguous (column-major storage), the transposed view is
    returned with ``row_major=False``, so its buffer can be passed as a dense
    array together with the layout flag.

    Raises:
        TypeError: if ``mat`` is neither row- nor column-major.
    """
    if mat.is_contiguous():
        return mat, True
    flipped = mat.t()
    if flipped.is_contiguous():
        return flipped, False
    raise TypeError('matrix should be row or column-major')


def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
                             metric=faiss.METRIC_L2):
    """Brute-force k-NN of ``xq`` against ``xb`` via ``faiss.bruteForceKnn``.

    ``xb`` and ``xq`` must be float32 2-D tensors on the same device. Each may
    be stored row-major or column-major; the layout is reported to Faiss
    rather than copied away. Faiss reads and writes the tensors' memory
    directly.

    Args:
        res: Faiss resources object forwarded to ``bruteForceKnn``
            (presumably a ``StandardGpuResources``; confirm with callers).
        xb: database vectors, shape (nb, d).
        xq: query vectors, shape (nq, d).
        k: number of neighbours per query.
        D: optional preallocated float32 (nq, k) distance tensor on xb's device.
        I: optional preallocated int64 (nq, k) id tensor on xb's device.
        metric: Faiss metric constant, ``faiss.METRIC_L2`` by default.

    Returns:
        Tuple ``(D, I)`` of shape (nq, k).

    Raises:
        TypeError: if ``xb`` or ``xq`` is neither row- nor column-major.
        AssertionError: on device, dimension, shape or dtype mismatch.
    """
    # NOTE(review): no CUDA stream synchronisation happens here. Callers
    # presumably ensure ordering between PyTorch's and Faiss's streams; confirm.
    assert xb.device == xq.device

    nq, d = xq.size()
    q_view, q_row_major = _as_row_or_col_major(xq)
    q_ptr = swig_ptr_from_FloatTensor(q_view)

    nb, db = xb.size()
    assert db == d
    b_view, b_row_major = _as_row_or_col_major(xb)
    b_ptr = swig_ptr_from_FloatTensor(b_view)

    out_device = b_view.device
    if D is not None:
        assert D.shape == (nq, k)
        assert D.device == out_device
    else:
        D = torch.empty(nq, k, device=out_device, dtype=torch.float32)
    if I is not None:
        assert I.shape == (nq, k)
        assert I.device == out_device
    else:
        I = torch.empty(nq, k, device=out_device, dtype=torch.int64)

    faiss.bruteForceKnn(
        res, metric,
        b_ptr, b_row_major, nb,
        q_ptr, q_row_major, nq,
        d, k,
        swig_ptr_from_FloatTensor(D), swig_ptr_from_LongTensor(I),
    )
    return D, I
def index_init_gpu(ngpus, feat_dim):
    """Build a flat L2 Faiss index sharded over GPUs ``0 .. ngpus - 1``.

    Each device gets its own ``StandardGpuResources`` and a ``GpuIndexFlatL2``
    configured with ``useFloat16`` disabled. All of them are attached as
    shards of one ``faiss.IndexShards``, which is then ``reset()``.

    Args:
        ngpus: number of GPUs to shard over (device ids 0 .. ngpus - 1).
        feat_dim: dimensionality of the indexed vectors.

    Returns:
        The ``faiss.IndexShards`` wrapping the per-GPU flat indexes.
    """
    def make_config(device_id):
        # One flat-index config per GPU, pinned to that device.
        config = faiss.GpuIndexFlatConfig()
        config.useFloat16 = False
        config.device = device_id
        return config

    device_ids = range(ngpus)
    configs = [make_config(dev) for dev in device_ids]
    resources = [faiss.StandardGpuResources() for _ in device_ids]
    # NOTE(review): `resources` goes out of scope on return; this relies on
    # faiss's Python wrapper keeping them alive via the sub-indexes. Confirm
    # for the faiss version in use.
    gpu_indexes = [
        faiss.GpuIndexFlatL2(gpu_res, feat_dim, config)
        for gpu_res, config in zip(resources, configs)
    ]

    sharded = faiss.IndexShards(feat_dim)
    for gpu_index in gpu_indexes:
        sharded.add_shard(gpu_index)
    sharded.reset()
    return sharded
def index_init_cpu(feat_dim):
    """Create a CPU ``faiss.IndexFlatL2`` for ``feat_dim``-dimensional vectors.

    Args:
        feat_dim: dimensionality of the vectors to be indexed.

    Returns:
        A newly constructed ``faiss.IndexFlatL2``.
    """
    cpu_index = faiss.IndexFlatL2(feat_dim)
    return cpu_index