Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax IVFFlatDedup test #3077

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions contrib/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,23 +226,41 @@ def compute_PR_for(q):
# Functions that compare search results with a reference result.
# They are intended for use in tests

def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
""" test that knn search results are identical, raise if not """
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
def _cluster_tables_with_tolerance(tab1, tab2, thr):
""" for two tables, cluster them by merging values closer than thr.
Returns the cluster ids for each table element """
tab = np.hstack([tab1, tab2])
tab.sort()
n = len(tab)
diffs = np.ones(n)
diffs[1:] = tab[1:] - tab[:-1]
unique_vals = tab[diffs > thr]
idx1 = np.searchsorted(unique_vals, tab1, side='right') - 1
idx2 = np.searchsorted(unique_vals, tab2, side='right') - 1
return idx1, idx2


def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5):
""" test that knn search results are identical, with possible ties.
Raise if not. """
np.testing.assert_allclose(Dref, Dnew, rtol=rtol)
# here we have to be careful because of draws
testcase = unittest.TestCase() # because it makes nice error messages
for i in range(len(Iref)):
if np.all(Iref[i] == Inew[i]): # easy case
continue
# we can deduce nothing about the latest line
skip_dis = Dref[i, -1]
for dis in np.unique(Dref):
if dis == skip_dis:

# otherwise collect elements per distance
r = rtol * Dref[i].max()

DrefC, DnewC = _cluster_tables_with_tolerance(Dref[i], Dnew[i], r)

for dis in np.unique(DrefC):
if dis == DrefC[-1]:
continue
mask = Dref[i, :] == dis
mask = DrefC == dis
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))


def check_ref_range_results(Lref, Dref, Iref,
Lnew, Dnew, Inew):
""" compare range search results wrt. a reference result,
Expand Down
23 changes: 4 additions & 19 deletions tests/test_index_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from common_faiss_tests import get_dataset_2
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.inspect_tools import make_LinearTransform_matrix

from faiss.contrib.evaluation import check_ref_knn_with_draws

class TestRemoveFastScan(unittest.TestCase):
def do_test(self, ntotal, removed):
Expand Down Expand Up @@ -430,12 +430,6 @@ def test_mmappedIO_pretrans(self):

class TestIVFFlatDedup(unittest.TestCase):

def normalize_res(self, D, I):
dmax = D[-1]
res = [(d, i) for d, i in zip(D, I) if d < dmax]
res.sort()
return res

def test_dedup(self):
d = 10
nb = 1000
Expand Down Expand Up @@ -471,10 +465,7 @@ def test_dedup(self):
Dref, Iref = index_ref.search(xq, 20)
Dnew, Inew = index_new.search(xq, 20)

for i in range(nq):
ref = self.normalize_res(Dref[i], Iref[i])
new = self.normalize_res(Dnew[i], Inew[i])
assert ref == new
check_ref_knn_with_draws(Dref, Iref, Dnew, Inew)

# test I/O
fd, tmpfile = tempfile.mkstemp()
Expand All @@ -487,10 +478,7 @@ def test_dedup(self):
os.unlink(tmpfile)
Dst, Ist = index_st.search(xq, 20)

for i in range(nq):
new = self.normalize_res(Dnew[i], Inew[i])
st = self.normalize_res(Dst[i], Ist[i])
assert st == new
check_ref_knn_with_draws(Dnew, Inew, Dst, Ist)

# test remove
toremove = np.hstack((np.arange(3, 1000, 5), np.arange(850, 950)))
Expand All @@ -501,10 +489,7 @@ def test_dedup(self):
Dref, Iref = index_ref.search(xq, 20)
Dnew, Inew = index_new.search(xq, 20)

for i in range(nq):
ref = self.normalize_res(Dref[i], Iref[i])
new = self.normalize_res(Dnew[i], Inew[i])
assert ref == new
check_ref_knn_with_draws(Dref, Iref, Dnew, Inew)


class TestSerialize(unittest.TestCase):
Expand Down