# demo_qinco.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This demonstrates how to reproduce the QINCo paper results using the Faiss
QINCo implementation. The code loads the reference model because training
is not implemented in Faiss.
Prepare the data with
cd /tmp
# get the reference qinco code
git clone https://github.com/facebookresearch/Qinco.git
# get the data
wget https://dl.fbaipublicfiles.com/QINCo/datasets/bigann/bigann1M.bvecs
# get the model
wget https://dl.fbaipublicfiles.com/QINCo/models/bigann_8x8_L2.pt
"""
import numpy as np
from faiss.contrib.vecs_io import bvecs_mmap
import sys
import time
import torch
import faiss
# make sure pickle deserialization will work
sys.path.append("/tmp/Qinco")
import model_qinco
# Set to False to let PyTorch and Faiss use their default thread counts.
# The timings quoted in the comments below were recorded in both modes.
SINGLE_THREAD = True

# Everything runs under no_grad: this is inference only, and the .numpy()
# conversions below are applied directly to the model outputs.
with torch.no_grad():

    # NOTE: weights_only=False unpickles a full model object (this is why
    # /tmp/Qinco is on sys.path above). Unpickling can execute arbitrary
    # code, so only load checkpoints from a trusted source.
    qinco = torch.load("/tmp/bigann_8x8_L2.pt", weights_only=False)
    qinco.eval()
    # print(qinco)

    if SINGLE_THREAD:
        torch.set_num_threads(1)
        faiss.omp_set_num_threads(1)

    # First 1000 database vectors as float32; the model is fed the vectors
    # divided by its db_scale, and decoded outputs are multiplied back.
    x_base = bvecs_mmap("/tmp/bigann1M.bvecs")[:1000].astype('float32')
    x_scaled = torch.from_numpy(x_base) / qinco.db_scale

    # --- reference PyTorch implementation ---
    # The timed section covers both encode and decode.
    t0 = time.perf_counter()
    codes, _ = qinco.encode(x_scaled)
    x_decoded_scaled = qinco.decode(codes)
    print(f"Pytorch encode {time.perf_counter() - t0:.3f} s")
    # multi-thread: 1.13s, single-thread: 7.744

    x_decoded = x_decoded_scaled.numpy() * qinco.db_scale

    # Squared reconstruction error summed per vector, averaged over vectors.
    err = ((x_decoded - x_base) ** 2).sum(1).mean()
    print("MSE=", err)  # = 14211.956, near the L=2 result in Fig 4 of the paper

    # --- Faiss implementation, built from the loaded PyTorch model ---
    qinco2 = faiss.QINCo(qinco)
    t0 = time.perf_counter()
    codes2 = qinco2.encode(faiss.Tensor2D(x_scaled))
    x_decoded2 = qinco2.decode(codes2).numpy() * qinco.db_scale
    print(f"Faiss encode {time.perf_counter() - t0:.3f} s")
    # multi-thread: 3.2s, single thread: 7.019

    # these tests don't work because there are outlier encodings
    # np.testing.assert_array_equal(codes.numpy(), codes2.numpy())
    # np.testing.assert_allclose(x_decoded, x_decoded2)

    # Instead of exact equality, require that fewer than 1% of the code
    # entries differ between the two implementations...
    ndiff = (codes.numpy() != codes2.numpy()).sum() / codes.numel()
    assert ndiff < 0.01, f"fraction of differing codes too high: {ndiff}"

    # ...and that fewer than 1% of the vectors decode differently
    # (squared distance between the two reconstructions above 1e-5).
    ndiff = (((x_decoded - x_decoded2) ** 2).sum(1) > 1e-5).sum()
    assert ndiff / len(x_base) < 0.01, (
        f"{ndiff} of {len(x_base)} vectors decode differently"
    )

    err = ((x_decoded2 - x_base) ** 2).sum(1).mean()
    print("MSE=", err)  # = 14213.551