Skip to content

Commit

Permalink
Relax IVFFlatDedup test
Browse files Browse the repository at this point in the history
  • Loading branch information
mdouze authored and facebook-github-bot committed Sep 28, 2023
1 parent cf90435 commit 6a6e094
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 20 deletions.
3 changes: 2 additions & 1 deletion contrib/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ def compute_PR_for(q):
# They are intended for use in tests

def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
""" test that knn search results are identical, raise if not """
""" test that knn search results are identical, with possible ties.
Raise if not. """
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
# here we have to be careful because of draws
testcase = unittest.TestCase() # because it makes nice error messages
Expand Down
23 changes: 4 additions & 19 deletions tests/test_index_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from common_faiss_tests import get_dataset_2
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.inspect_tools import make_LinearTransform_matrix

from faiss.contrib.evaluation import check_ref_knn_with_draws

class TestRemoveFastScan(unittest.TestCase):
def do_test(self, ntotal, removed):
Expand Down Expand Up @@ -430,12 +430,6 @@ def test_mmappedIO_pretrans(self):

class TestIVFFlatDedup(unittest.TestCase):

def normalize_res(self, D, I):
dmax = D[-1]
res = [(d, i) for d, i in zip(D, I) if d < dmax]
res.sort()
return res

def test_dedup(self):
d = 10
nb = 1000
Expand Down Expand Up @@ -471,10 +465,7 @@ def test_dedup(self):
Dref, Iref = index_ref.search(xq, 20)
Dnew, Inew = index_new.search(xq, 20)

for i in range(nq):
ref = self.normalize_res(Dref[i], Iref[i])
new = self.normalize_res(Dnew[i], Inew[i])
assert ref == new
check_ref_knn_with_draws(Dref, Iref, Dnew, Inew)

# test I/O
fd, tmpfile = tempfile.mkstemp()
Expand All @@ -487,10 +478,7 @@ def test_dedup(self):
os.unlink(tmpfile)
Dst, Ist = index_st.search(xq, 20)

for i in range(nq):
new = self.normalize_res(Dnew[i], Inew[i])
st = self.normalize_res(Dst[i], Ist[i])
assert st == new
check_ref_knn_with_draws(Dnew, Inew, Dst, Ist)

# test remove
toremove = np.hstack((np.arange(3, 1000, 5), np.arange(850, 950)))
Expand All @@ -501,10 +489,7 @@ def test_dedup(self):
Dref, Iref = index_ref.search(xq, 20)
Dnew, Inew = index_new.search(xq, 20)

for i in range(nq):
ref = self.normalize_res(Dref[i], Iref[i])
new = self.normalize_res(Dnew[i], Inew[i])
assert ref == new
check_ref_knn_with_draws(Dref, Iref, Dnew, Inew)


class TestSerialize(unittest.TestCase):
Expand Down

0 comments on commit 6a6e094

Please sign in to comment.