Skip to content

Commit

Permalink
fix ACCESS VIOLATION error when searching using IDSelectorArray
Browse files Browse the repository at this point in the history
Summary:
Fixes facebookresearch#3156

Metamate says: "This diff fixes an ACCESS VIOLATION error that occurs when searching using IDSelectorArray. The code changes include adding a new parameter to the knn_inner_products_by_idx and knn_L2sqr_by_idx functions in the distances.cpp file, as well as modifying the test_search_params.py file to test the bounds of the IDSelectorArray."

Reviewed By: mdouze

Differential Revision: D53185461

fbshipit-source-id: c7ec4783f77455684c078bba3aace160078f6c27
  • Loading branch information
algoriddle authored and facebook-github-bot committed Jan 30, 2024
1 parent 67c6a19 commit 2817344
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
17 changes: 12 additions & 5 deletions faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>

Expand Down Expand Up @@ -670,7 +671,7 @@ void knn_inner_product(
}
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
knn_inner_products_by_idx(
x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
return;
}

Expand Down Expand Up @@ -726,7 +727,7 @@ void knn_L2sqr(
sel = nullptr;
}
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
return;
}
if (k == 1) {
Expand Down Expand Up @@ -904,6 +905,7 @@ void knn_inner_products_by_idx(
size_t d,
size_t nx,
size_t ny,
size_t nsubset,
size_t k,
float* res_vals,
int64_t* res_ids,
Expand All @@ -921,9 +923,10 @@ void knn_inner_products_by_idx(
int64_t* __restrict idxi = res_ids + i * k;
minheap_heapify(k, simi, idxi);

for (j = 0; j < ny; j++) {
if (idsi[j] < 0)
for (j = 0; j < nsubset; j++) {
if (idsi[j] < 0 || idsi[j] >= ny) {
break;
}
float ip = fvec_inner_product(x_, y + d * idsi[j], d);

if (ip > simi[0]) {
Expand All @@ -941,6 +944,7 @@ void knn_L2sqr_by_idx(
size_t d,
size_t nx,
size_t ny,
size_t nsubset,
size_t k,
float* res_vals,
int64_t* res_ids,
Expand All @@ -955,7 +959,10 @@ void knn_L2sqr_by_idx(
float* __restrict simi = res_vals + i * k;
int64_t* __restrict idxi = res_ids + i * k;
maxheap_heapify(k, simi, idxi);
for (size_t j = 0; j < ny; j++) {
for (size_t j = 0; j < nsubset; j++) {
if (idsi[j] < 0 || idsi[j] >= ny) {
break;
}
float disij = fvec_L2sqr(x_, y + d * idsi[j], d);

if (disij < simi[0]) {
Expand Down
2 changes: 2 additions & 0 deletions faiss/utils/distances.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ void knn_inner_products_by_idx(
const int64_t* subset,
size_t d,
size_t nx,
size_t ny,
size_t nsubset,
size_t k,
float* vals,
Expand All @@ -398,6 +399,7 @@ void knn_L2sqr_by_idx(
const int64_t* subset,
size_t d,
size_t nx,
size_t ny,
size_t nsubset,
size_t k,
float* vals,
Expand Down
18 changes: 18 additions & 0 deletions tests/test_search_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,24 @@ def test_idmap(self):
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)

def test_bounds(self):
# https://github.com/facebookresearch/faiss/issues/3156
d = 64 # dimension
nb = 100000 # database size
xb = np.random.random((nb, d))
index_ip = faiss.IndexFlatIP(d)
index_ip.add(xb)
index_l2 = faiss.IndexFlatIP(d)
index_l2.add(xb)

out_of_bounds_id = nb + 15 # + 14 or lower will work fine
id_selector = faiss.IDSelectorArray([out_of_bounds_id])
search_params = faiss.SearchParameters(sel=id_selector)

# ignores out of bound, does not crash
distances, indices = index_ip.search(xb[:2], k=3, params=search_params)
distances, indices = index_l2.search(xb[:2], k=3, params=search_params)


class TestSearchParams(unittest.TestCase):

Expand Down

0 comments on commit 2817344

Please sign in to comment.