From 1f2af8ed193ac0cae500357ec99a30eabaf174d7 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 23 May 2023 16:05:40 -0700 Subject: [PATCH] AVX2 version of faiss::HNSW::MinimaxHeap::pop_min() (#2874) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2874 Differential Revision: D46125506 fbshipit-source-id: c485b0a15521754d53ced0595d7acd78d4aef517 --- faiss/impl/HNSW.cpp | 92 +++++++++++++++++++++++++++++++++++ tests/CMakeLists.txt | 1 + tests/test_hnsw.cpp | 111 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+) create mode 100644 tests/test_hnsw.cpp diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index cd7f9a0d91..cbcd29c347 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -16,6 +16,13 @@ #include #include +#include + +#ifdef __AVX2__ +#include +#include +#endif + namespace faiss { /************************************************************** @@ -1010,6 +1017,90 @@ void HNSW::MinimaxHeap::clear() { nvalid = k = 0; } +#ifdef __AVX2__ +int HNSW::MinimaxHeap::pop_min(float* vmin_out) { + assert(k > 0); + static_assert( + std::is_same::value, + "This code expects storage_idx_t to be int32_t"); + + int32_t min_idx = -1; + float min_dis = std::numeric_limits::max(); + + size_t iii = 0; + + __m256i min_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + __m256 min_distances = _mm256_set1_ps(std::numeric_limits::max()); + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + __m256i offset = _mm256_set1_epi32(8); + + // The baseline version is available in non-AVX2 branch. + + // The following loop tracks the rightmost index with the min distance. + // -1 index values are ignored. + const int k8 = (k / 8) * 8; + for (; iii < k8; iii += 8) { + __m256i indices = + _mm256_loadu_si256((const __m256i*)(ids.data() + iii)); + __m256 distances = _mm256_loadu_ps(dis.data() + iii); + + // This mask filters out -1 values among indices. + __m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices); + + __m256i dmask = _mm256_castps_si256( + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS)); + __m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask)); + + const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + finalmask)); + + const __m256 min_distances_new = + _mm256_blendv_ps(distances, min_distances, finalmask); + + min_indices = min_indices_new; + min_distances = min_distances_new; + + current_indices = _mm256_add_epi32(current_indices, offset); + } + + // Vectorizing is doable, but is not practical + int32_t vidx8[8]; + float vdis8[8]; + _mm256_storeu_ps(vdis8, min_distances); + _mm256_storeu_si256((__m256i*)vidx8, min_indices); + + for (size_t j = 0; j < 8; j++) { + if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) { + min_idx = vidx8[j]; + min_dis = vdis8[j]; + } + } + + // process last values. Vectorizing is doable, but is not practical + for (; iii < k; iii++) { + if (ids[iii] != -1 && dis[iii] <= min_dis) { + min_dis = dis[iii]; + min_idx = iii; + } + } + + if (min_idx == -1) { + return -1; + } + + if (vmin_out) + *vmin_out = min_dis; + int ret = ids[min_idx]; + ids[min_idx] = -1; + --nvalid; + return ret; +} + +#else + +// baseline non-vectorized version int HNSW::MinimaxHeap::pop_min(float* vmin_out) { assert(k > 0); // returns min. This is an O(n) operation @@ -1039,6 +1130,7 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) { return ret; } +#endif int HNSW::MinimaxHeap::count_below(float thresh) { int n_below = 0; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index ecf45cde50..d5b6084432 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -28,6 +28,7 @@ set(FAISS_TEST_SRC test_distances_simd.cpp test_heap.cpp test_code_distance.cpp + test_hnsw.cpp ) add_executable(faiss_test ${FAISS_TEST_SRC}) diff --git a/tests/test_hnsw.cpp b/tests/test_hnsw.cpp new file mode 100644 index 0000000000..242199db4f --- /dev/null +++ b/tests/test_hnsw.cpp @@ -0,0 +1,111 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +int reference_pop_min(faiss::HNSW::MinimaxHeap& heap, float* vmin_out) { + assert(heap.k > 0); + // returns min. This is an O(n) operation + int i = heap.k - 1; + while (i >= 0) { + if (heap.ids[i] != -1) + break; + i--; + } + if (i == -1) + return -1; + int imin = i; + float vmin = heap.dis[i]; + i--; + while (i >= 0) { + if (heap.ids[i] != -1 && heap.dis[i] < vmin) { + vmin = heap.dis[i]; + imin = i; + } + i--; + } + if (vmin_out) + *vmin_out = vmin; + int ret = heap.ids[imin]; + heap.ids[imin] = -1; + --heap.nvalid; + + return ret; +} + +void test_popmin(int heap_size, int amount_to_put) { + // create a heap + faiss::HNSW::MinimaxHeap mm_heap(heap_size); + + using storage_idx_t = faiss::HNSW::storage_idx_t; + + std::default_random_engine rng(123 + heap_size * amount_to_put); + std::uniform_int_distribution u(0, 65536); + std::uniform_real_distribution uf(0, 1); + + // generate random unique indices + std::unordered_set indices; + while (indices.size() < amount_to_put) { + const storage_idx_t index = u(rng); + indices.insert(index); + } + + // put ones into the heap + for (const auto index : indices) { + mm_heap.push(index, uf(rng)); + } + + // clone the heap + faiss::HNSW::MinimaxHeap cloned_mm_heap = mm_heap; + + // takes ones out one by one + while (mm_heap.size() > 0) { + // compare heaps + ASSERT_EQ(mm_heap.n, cloned_mm_heap.n); + ASSERT_EQ(mm_heap.k, cloned_mm_heap.k); + ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid); + ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids); + ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis); + + // use the reference pop_min for the cloned heap + float cloned_vmin_dis = std::numeric_limits::quiet_NaN(); + storage_idx_t cloned_vmin_idx = + reference_pop_min(cloned_mm_heap, &cloned_vmin_dis); + + float vmin_dis = std::numeric_limits::quiet_NaN(); + storage_idx_t vmin_idx = mm_heap.pop_min(&vmin_dis); + + // compare returns + ASSERT_EQ(vmin_dis, cloned_vmin_dis); + ASSERT_EQ(vmin_idx, cloned_vmin_idx); + } + + // compare heaps again + ASSERT_EQ(mm_heap.n, cloned_mm_heap.n); + ASSERT_EQ(mm_heap.k, cloned_mm_heap.k); + ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid); + ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids); + ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis); +} + +TEST(HNSW, Test_popmin) { + std::vector sizes = {1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32}; + for (const size_t size : sizes) { + for (size_t amount = size; amount > 0; amount /= 2) { + test_popmin(size, amount); + } + } +}