diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 436a1bfcc9..c3693fd906 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -470,105 +470,6 @@ void search_neighbors_to_add( vt.advance(); } -/************************************************************** - * Searching subroutines - **************************************************************/ - -/// greedily update a nearest vector at a given level -HNSWStats greedy_update_nearest( - const HNSW& hnsw, - DistanceComputer& qdis, - int level, - storage_idx_t& nearest, - float& d_nearest) { - // selects a version - const bool reference_version = false; - - HNSWStats stats; - - for (;;) { - storage_idx_t prev_nearest = nearest; - - size_t begin, end; - hnsw.neighbor_range(nearest, level, &begin, &end); - - size_t ndis = 0; - - // select a version, based on a flag - if (reference_version) { - // a reference version - for (size_t i = begin; i < end; i++) { - storage_idx_t v = hnsw.neighbors[i]; - if (v < 0) - break; - ndis += 1; - float dis = qdis(v); - if (dis < d_nearest) { - nearest = v; - d_nearest = dis; - } - } - } else { - // a faster version - - // the following version processes 4 neighbors at a time - auto update_with_candidate = [&](const storage_idx_t idx, - const float dis) { - if (dis < d_nearest) { - nearest = idx; - d_nearest = dis; - } - }; - - int n_buffered = 0; - storage_idx_t buffered_ids[4]; - - for (size_t j = begin; j < end; j++) { - storage_idx_t v = hnsw.neighbors[j]; - if (v < 0) - break; - ndis += 1; - - buffered_ids[n_buffered] = v; - n_buffered += 1; - - if (n_buffered == 4) { - float dis[4]; - qdis.distances_batch_4( - buffered_ids[0], - buffered_ids[1], - buffered_ids[2], - buffered_ids[3], - dis[0], - dis[1], - dis[2], - dis[3]); - - for (size_t id4 = 0; id4 < 4; id4++) { - update_with_candidate(buffered_ids[id4], dis[id4]); - } - - n_buffered = 0; - } - } - - // process leftovers - for (size_t icnt = 0; icnt < n_buffered; icnt++) { - float dis = qdis(buffered_ids[icnt]); - update_with_candidate(buffered_ids[icnt], dis); - } - } - - // update stats - stats.ndis += ndis; - stats.nhops += 1; - - if (nearest == prev_nearest) { - return stats; - } - } -} - } // namespace /// Finds neighbors and builds links with them, starting from an entry @@ -671,12 +572,10 @@ void HNSW::add_with_locks( * Searching **************************************************************/ -namespace { using MinimaxHeap = HNSW::MinimaxHeap; using Node = HNSW::Node; using C = HNSW::C; /** Do a BFS on the candidates list */ - int search_from_candidates( const HNSW& hnsw, DistanceComputer& qdis, @@ -685,11 +584,8 @@ int search_from_candidates( VisitedTable& vt, HNSWStats& stats, int level, - int nres_in = 0, - const SearchParametersHNSW* params = nullptr) { - // selects a version - const bool reference_version = false; - + int nres_in, + const SearchParametersHNSW* params) { int nres = nres_in; int ndis = 0; @@ -734,97 +630,70 @@ int search_from_candidates( size_t begin, end; hnsw.neighbor_range(v0, level, &begin, &end); - // select a version, based on a flag - if (reference_version) { - // a reference version - for (size_t j = begin; j < end; j++) { - int v1 = hnsw.neighbors[j]; - if (v1 < 0) - break; - if (vt.get(v1)) { - continue; - } - vt.set(v1); - ndis++; - float d = qdis(v1); - if (!sel || sel->is_member(v1)) { - if (d < threshold) { - if (res.add_result(d, v1)) { - threshold = res.threshold; - nres += 1; - } - } - } - - candidates.push(v1, d); - } - } else { - // a faster version - - // the following version processes 4 neighbors at a time - size_t jmax = begin; - for (size_t j = begin; j < end; j++) { - int v1 = hnsw.neighbors[j]; - if (v1 < 0) - break; + // a faster version: reference version in unit test test_hnsw.cpp + // the following version processes 4 neighbors at a time + size_t jmax = begin; + for (size_t j = begin; j < end; j++) { + int v1 = hnsw.neighbors[j]; + if (v1 < 0) + break; - prefetch_L2(vt.visited.data() + v1); - jmax += 1; - } + prefetch_L2(vt.visited.data() + v1); + jmax += 1; + } - int counter = 0; - size_t saved_j[4]; + int counter = 0; + size_t saved_j[4]; - threshold = res.threshold; + threshold = res.threshold; - auto add_to_heap = [&](const size_t idx, const float dis) { - if (!sel || sel->is_member(idx)) { - if (dis < threshold) { - if (res.add_result(dis, idx)) { - threshold = res.threshold; - nres += 1; - } + auto add_to_heap = [&](const size_t idx, const float dis) { + if (!sel || sel->is_member(idx)) { + if (dis < threshold) { + if (res.add_result(dis, idx)) { + threshold = res.threshold; + nres += 1; } } - candidates.push(idx, dis); - }; - - for (size_t j = begin; j < jmax; j++) { - int v1 = hnsw.neighbors[j]; - - bool vget = vt.get(v1); - vt.set(v1); - saved_j[counter] = v1; - counter += vget ? 0 : 1; - - if (counter == 4) { - float dis[4]; - qdis.distances_batch_4( - saved_j[0], - saved_j[1], - saved_j[2], - saved_j[3], - dis[0], - dis[1], - dis[2], - dis[3]); - - for (size_t id4 = 0; id4 < 4; id4++) { - add_to_heap(saved_j[id4], dis[id4]); - } + } + candidates.push(idx, dis); + }; + + for (size_t j = begin; j < jmax; j++) { + int v1 = hnsw.neighbors[j]; + + bool vget = vt.get(v1); + vt.set(v1); + saved_j[counter] = v1; + counter += vget ? 0 : 1; + + if (counter == 4) { + float dis[4]; + qdis.distances_batch_4( + saved_j[0], + saved_j[1], + saved_j[2], + saved_j[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + add_to_heap(saved_j[id4], dis[id4]); + } - ndis += 4; + ndis += 4; - counter = 0; - } + counter = 0; } + } - for (size_t icnt = 0; icnt < counter; icnt++) { - float dis = qdis(saved_j[icnt]); - add_to_heap(saved_j[icnt], dis); + for (size_t icnt = 0; icnt < counter; icnt++) { + float dis = qdis(saved_j[icnt]); + add_to_heap(saved_j[icnt], dis); - ndis += 1; - } + ndis += 1; } nstep++; @@ -852,9 +721,6 @@ std::priority_queue search_from_candidate_unbounded( int ef, VisitedTable* vt, HNSWStats& stats) { - // selects a version - const bool reference_version = false; - int ndis = 0; std::priority_queue top_candidates; std::priority_queue, std::greater> candidates; @@ -878,112 +744,162 @@ std::priority_queue search_from_candidate_unbounded( size_t begin, end; hnsw.neighbor_range(v0, 0, &begin, &end); - if (reference_version) { - // reference version - for (size_t j = begin; j < end; ++j) { - int v1 = hnsw.neighbors[j]; - - if (v1 < 0) { - break; - } - if (vt->get(v1)) { - continue; - } + // a faster version: reference version in unit test test_hnsw.cpp + // the following version processes 4 neighbors at a time + size_t jmax = begin; + for (size_t j = begin; j < end; j++) { + int v1 = hnsw.neighbors[j]; + if (v1 < 0) + break; - vt->set(v1); + prefetch_L2(vt->visited.data() + v1); + jmax += 1; + } - float d1 = qdis(v1); - ++ndis; + int counter = 0; + size_t saved_j[4]; - if (top_candidates.top().first > d1 || - top_candidates.size() < ef) { - candidates.emplace(d1, v1); - top_candidates.emplace(d1, v1); + auto add_to_heap = [&](const size_t idx, const float dis) { + if (top_candidates.top().first > dis || + top_candidates.size() < ef) { + candidates.emplace(dis, idx); + top_candidates.emplace(dis, idx); - if (top_candidates.size() > ef) { - top_candidates.pop(); - } + if (top_candidates.size() > ef) { + top_candidates.pop(); } } - } else { - // a faster version + }; + + for (size_t j = begin; j < jmax; j++) { + int v1 = hnsw.neighbors[j]; + + bool vget = vt->get(v1); + vt->set(v1); + saved_j[counter] = v1; + counter += vget ? 0 : 1; + + if (counter == 4) { + float dis[4]; + qdis.distances_batch_4( + saved_j[0], + saved_j[1], + saved_j[2], + saved_j[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + add_to_heap(saved_j[id4], dis[id4]); + } - // the following version processes 4 neighbors at a time - size_t jmax = begin; - for (size_t j = begin; j < end; j++) { - int v1 = hnsw.neighbors[j]; - if (v1 < 0) - break; + ndis += 4; - prefetch_L2(vt->visited.data() + v1); - jmax += 1; + counter = 0; } + } - int counter = 0; - size_t saved_j[4]; + for (size_t icnt = 0; icnt < counter; icnt++) { + float dis = qdis(saved_j[icnt]); + add_to_heap(saved_j[icnt], dis); - auto add_to_heap = [&](const size_t idx, const float dis) { - if (top_candidates.top().first > dis || - top_candidates.size() < ef) { - candidates.emplace(dis, idx); - top_candidates.emplace(dis, idx); + ndis += 1; + } - if (top_candidates.size() > ef) { - top_candidates.pop(); - } - } - }; + stats.nhops += 1; + } + + ++stats.n1; + if (candidates.size() == 0) { + ++stats.n2; + } + stats.ndis += ndis; - for (size_t j = begin; j < jmax; j++) { - int v1 = hnsw.neighbors[j]; + return top_candidates; +} - bool vget = vt->get(v1); - vt->set(v1); - saved_j[counter] = v1; - counter += vget ? 0 : 1; +/// greedily update a nearest vector at a given level +HNSWStats greedy_update_nearest( + const HNSW& hnsw, + DistanceComputer& qdis, + int level, + storage_idx_t& nearest, + float& d_nearest) { + HNSWStats stats; - if (counter == 4) { - float dis[4]; - qdis.distances_batch_4( - saved_j[0], - saved_j[1], - saved_j[2], - saved_j[3], - dis[0], - dis[1], - dis[2], - dis[3]); + for (;;) { + storage_idx_t prev_nearest = nearest; - for (size_t id4 = 0; id4 < 4; id4++) { - add_to_heap(saved_j[id4], dis[id4]); - } + size_t begin, end; + hnsw.neighbor_range(nearest, level, &begin, &end); - ndis += 4; + size_t ndis = 0; - counter = 0; - } + // a faster version: reference version in unit test test_hnsw.cpp + // the following version processes 4 neighbors at a time + auto update_with_candidate = [&](const storage_idx_t idx, + const float dis) { + if (dis < d_nearest) { + nearest = idx; + d_nearest = dis; } + }; - for (size_t icnt = 0; icnt < counter; icnt++) { - float dis = qdis(saved_j[icnt]); - add_to_heap(saved_j[icnt], dis); + int n_buffered = 0; + storage_idx_t buffered_ids[4]; - ndis += 1; + for (size_t j = begin; j < end; j++) { + storage_idx_t v = hnsw.neighbors[j]; + if (v < 0) + break; + ndis += 1; + + buffered_ids[n_buffered] = v; + n_buffered += 1; + + if (n_buffered == 4) { + float dis[4]; + qdis.distances_batch_4( + buffered_ids[0], + buffered_ids[1], + buffered_ids[2], + buffered_ids[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + update_with_candidate(buffered_ids[id4], dis[id4]); + } + + n_buffered = 0; } } + // process leftovers + for (size_t icnt = 0; icnt < n_buffered; icnt++) { + float dis = qdis(buffered_ids[icnt]); + update_with_candidate(buffered_ids[icnt], dis); + } + + // update stats + stats.ndis += ndis; stats.nhops += 1; - } - ++stats.n1; - if (candidates.size() == 0) { - ++stats.n2; + if (nearest == prev_nearest) { + return stats; + } } - stats.ndis += ndis; - - return top_candidates; } +namespace { +using MinimaxHeap = HNSW::MinimaxHeap; +using Node = HNSW::Node; +using C = HNSW::C; + // just used as a lower bound for the minmaxheap, but it is set for heap search int extract_k_from_ResultHandler(ResultHandler& res) { using RH = HeapBlockResultHandler; @@ -993,7 +909,7 @@ int extract_k_from_ResultHandler(ResultHandler& res) { return 1; } -} // anonymous namespace +} // namespace HNSWStats HNSW::search( DistanceComputer& qdis, diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index d2c974f384..dbe75d0b6e 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -255,4 +255,30 @@ struct HNSWStats { // global var that collects them all FAISS_API extern HNSWStats hnsw_stats; +int search_from_candidates( + const HNSW& hnsw, + DistanceComputer& qdis, + ResultHandler& res, + HNSW::MinimaxHeap& candidates, + VisitedTable& vt, + HNSWStats& stats, + int level, + int nres_in = 0, + const SearchParametersHNSW* params = nullptr); + +HNSWStats greedy_update_nearest( + const HNSW& hnsw, + DistanceComputer& qdis, + int level, + HNSW::storage_idx_t& nearest, + float& d_nearest); + +std::priority_queue search_from_candidate_unbounded( + const HNSW& hnsw, + const HNSW::Node& node, + DistanceComputer& qdis, + int ef, + VisitedTable* vt, + HNSWStats& stats); + } // namespace faiss diff --git a/tests/test_hnsw.cpp b/tests/test_hnsw.cpp index 9d90cf25e8..dbe1945017 100644 --- a/tests/test_hnsw.cpp +++ b/tests/test_hnsw.cpp @@ -8,13 +8,16 @@ #include #include -#include #include #include #include #include +#include #include +#include +#include +#include int reference_pop_min(faiss::HNSW::MinimaxHeap& heap, float* vmin_out) { assert(heap.k > 0); @@ -190,3 +193,353 @@ TEST(HNSW, Test_popmin_infinite_distances) { } } } + +class HNSWTest : public testing::Test { + protected: + HNSWTest() { + xb = std::make_unique>(d * nb); + xb->reserve(d * nb); + faiss::float_rand(xb->data(), d * nb, 12345); + index = std::make_unique(d, M); + index->add(nb, xb->data()); + xq = std::unique_ptr>( + new std::vector(d * nq)); + xq->reserve(d * nq); + faiss::float_rand(xq->data(), d * nq, 12345); + dis = std::unique_ptr( + index->storage->get_distance_computer()); + dis->set_query(xq->data() + 0 * index->d); + } + + const int d = 64; + const int nb = 2000; + const int M = 4; + const int nq = 10; + const int k = 10; + std::unique_ptr> xb; + std::unique_ptr> xq; + std::unique_ptr dis; + std::unique_ptr index; +}; + +/** Do a BFS on the candidates list */ +int reference_search_from_candidates( + const faiss::HNSW& hnsw, + faiss::DistanceComputer& qdis, + faiss::ResultHandler& res, + faiss::HNSW::MinimaxHeap& candidates, + faiss::VisitedTable& vt, + faiss::HNSWStats& stats, + int level, + int nres_in, + const faiss::SearchParametersHNSW* params) { + int nres = nres_in; + int ndis = 0; + + // can be overridden by search params + bool do_dis_check = params ? params->check_relative_distance + : hnsw.check_relative_distance; + int efSearch = params ? params->efSearch : hnsw.efSearch; + const faiss::IDSelector* sel = params ? params->sel : nullptr; + + faiss::HNSW::C::T threshold = res.threshold; + for (int i = 0; i < candidates.size(); i++) { + faiss::idx_t v1 = candidates.ids[i]; + float d = candidates.dis[i]; + FAISS_ASSERT(v1 >= 0); + if (!sel || sel->is_member(v1)) { + if (d < threshold) { + if (res.add_result(d, v1)) { + threshold = res.threshold; + } + } + } + vt.set(v1); + } + + int nstep = 0; + + while (candidates.size() > 0) { + float d0 = 0; + int v0 = candidates.pop_min(&d0); + + if (do_dis_check) { + // tricky stopping condition: there are more that ef + // distances that are processed already that are smaller + // than d0 + + int n_dis_below = candidates.count_below(d0); + if (n_dis_below >= efSearch) { + break; + } + } + + size_t begin, end; + hnsw.neighbor_range(v0, level, &begin, &end); + + // a reference version + for (size_t j = begin; j < end; j++) { + int v1 = hnsw.neighbors[j]; + if (v1 < 0) + break; + if (vt.get(v1)) { + continue; + } + vt.set(v1); + ndis++; + float d = qdis(v1); + if (!sel || sel->is_member(v1)) { + if (d < threshold) { + if (res.add_result(d, v1)) { + threshold = res.threshold; + nres += 1; + } + } + } + + candidates.push(v1, d); + } + + nstep++; + if (!do_dis_check && nstep > efSearch) { + break; + } + } + + if (level == 0) { + stats.n1++; + if (candidates.size() == 0) { + stats.n2++; + } + stats.ndis += ndis; + stats.nhops += nstep; + } + + return nres; +} + +faiss::HNSWStats reference_greedy_update_nearest( + const faiss::HNSW& hnsw, + faiss::DistanceComputer& qdis, + int level, + faiss::HNSW::storage_idx_t& nearest, + float& d_nearest) { + faiss::HNSWStats stats; + + for (;;) { + faiss::HNSW::storage_idx_t prev_nearest = nearest; + + size_t begin, end; + hnsw.neighbor_range(nearest, level, &begin, &end); + + size_t ndis = 0; + + for (size_t i = begin; i < end; i++) { + faiss::HNSW::storage_idx_t v = hnsw.neighbors[i]; + if (v < 0) + break; + ndis += 1; + float dis = qdis(v); + if (dis < d_nearest) { + nearest = v; + d_nearest = dis; + } + } + // update stats + stats.ndis += ndis; + stats.nhops += 1; + + if (nearest == prev_nearest) { + return stats; + } + } +} + +std::priority_queue reference_search_from_candidate_unbounded( + const faiss::HNSW& hnsw, + const faiss::HNSW::Node& node, + faiss::DistanceComputer& qdis, + int ef, + faiss::VisitedTable* vt, + faiss::HNSWStats& stats) { + int ndis = 0; + std::priority_queue top_candidates; + std::priority_queue< + faiss::HNSW::Node, + std::vector, + std::greater> + candidates; + + top_candidates.push(node); + candidates.push(node); + + vt->set(node.second); + + while (!candidates.empty()) { + float d0; + faiss::HNSW::storage_idx_t v0; + std::tie(d0, v0) = candidates.top(); + + if (d0 > top_candidates.top().first) { + break; + } + + candidates.pop(); + + size_t begin, end; + hnsw.neighbor_range(v0, 0, &begin, &end); + + for (size_t j = begin; j < end; ++j) { + int v1 = hnsw.neighbors[j]; + + if (v1 < 0) { + break; + } + if (vt->get(v1)) { + continue; + } + + vt->set(v1); + + float d1 = qdis(v1); + ++ndis; + + if (top_candidates.top().first > d1 || top_candidates.size() < ef) { + candidates.emplace(d1, v1); + top_candidates.emplace(d1, v1); + + if (top_candidates.size() > ef) { + top_candidates.pop(); + } + } + } + + stats.nhops += 1; + } + + ++stats.n1; + if (candidates.size() == 0) { + ++stats.n2; + } + stats.ndis += ndis; + + return top_candidates; +} + +TEST_F(HNSWTest, TEST_search_from_candidate_unbounded) { + omp_set_num_threads(1); + auto nearest = index->hnsw.entry_point; + float d_nearest = (*dis)(nearest); + auto node = faiss::HNSW::Node(d_nearest, nearest); + faiss::VisitedTable vt(index->ntotal); + faiss::HNSWStats stats; + + // actual version + auto top_candidates = faiss::search_from_candidate_unbounded( + index->hnsw, node, *dis, k, &vt, stats); + + auto reference_nearest = index->hnsw.entry_point; + float reference_d_nearest = (*dis)(nearest); + auto reference_node = + faiss::HNSW::Node(reference_d_nearest, reference_nearest); + faiss::VisitedTable reference_vt(index->ntotal); + faiss::HNSWStats reference_stats; + + // reference version + auto reference_top_candidates = reference_search_from_candidate_unbounded( + index->hnsw, + reference_node, + *dis, + k, + &reference_vt, + reference_stats); + EXPECT_EQ(stats.ndis, reference_stats.ndis); + EXPECT_EQ(stats.nhops, reference_stats.nhops); + EXPECT_EQ(stats.n1, reference_stats.n1); + EXPECT_EQ(stats.n2, reference_stats.n2); + while (!top_candidates.empty() && !reference_top_candidates.empty()) { + EXPECT_EQ(top_candidates.top(), reference_top_candidates.top()); + top_candidates.pop(); + reference_top_candidates.pop(); + } + EXPECT_TRUE(top_candidates.empty() && reference_top_candidates.empty()); +} + +TEST_F(HNSWTest, TEST_greedy_update_nearest) { + omp_set_num_threads(1); + + auto nearest = index->hnsw.entry_point; + float d_nearest = (*dis)(nearest); + auto reference_nearest = index->hnsw.entry_point; + float reference_d_nearest = (*dis)(reference_nearest); + + // actual version + auto stats = faiss::greedy_update_nearest( + index->hnsw, *dis, 0, nearest, d_nearest); + + // reference version + auto reference_stats = reference_greedy_update_nearest( + index->hnsw, *dis, 0, reference_nearest, reference_d_nearest); + EXPECT_EQ(stats.ndis, reference_stats.ndis); + EXPECT_EQ(stats.nhops, reference_stats.nhops); + EXPECT_EQ(stats.n1, reference_stats.n1); + EXPECT_EQ(stats.n2, reference_stats.n2); + EXPECT_EQ(d_nearest, reference_d_nearest); + EXPECT_EQ(nearest, reference_nearest); +} + +TEST_F(HNSWTest, TEST_search_from_candidates) { + omp_set_num_threads(1); + + std::vector I(k * nq); + std::vector D(k * nq); + std::vector reference_I(k * nq); + std::vector reference_D(k * nq); + using RH = faiss::HeapBlockResultHandler; + + faiss::VisitedTable vt(index->ntotal); + faiss::VisitedTable reference_vt(index->ntotal); + int num_candidates = 10; + faiss::HNSW::MinimaxHeap candidates(num_candidates); + faiss::HNSW::MinimaxHeap reference_candidates(num_candidates); + + for (int i = 0; i < num_candidates; i++) { + vt.set(i); + reference_vt.set(i); + candidates.push(i, (*dis)(i)); + reference_candidates.push(i, (*dis)(i)); + } + + faiss::HNSWStats stats; + RH bres(nq, D.data(), I.data(), k); + faiss::HeapBlockResultHandler::SingleResultHandler res( + bres); + + res.begin(0); + faiss::search_from_candidates( + index->hnsw, *dis, res, candidates, vt, stats, 0, 0, nullptr); + res.end(); + + faiss::HNSWStats reference_stats; + RH reference_bres(nq, reference_D.data(), reference_I.data(), k); + faiss::HeapBlockResultHandler::SingleResultHandler + reference_res(reference_bres); + reference_res.begin(0); + reference_search_from_candidates( + index->hnsw, + *dis, + reference_res, + reference_candidates, + reference_vt, + reference_stats, + 0, + 0, + nullptr); + reference_res.end(); + EXPECT_EQ(reference_D, D); + EXPECT_EQ(reference_I, I); + EXPECT_EQ(reference_stats.ndis, stats.ndis); + EXPECT_EQ(reference_stats.nhops, stats.nhops); + EXPECT_EQ(reference_stats.n1, stats.n1); + EXPECT_EQ(reference_stats.n2, stats.n2); +}