Skip to content

Commit

Permalink
AVX2 version of faiss::HNSW::MinimaxHeap::pop_min() (facebookresearch…
Browse files Browse the repository at this point in the history
…#2874)

Summary: Pull Request resolved: facebookresearch#2874

Differential Revision: D46125506

fbshipit-source-id: c485b0a15521754d53ced0595d7acd78d4aef517
  • Loading branch information
Alexandr Guzhva authored and facebook-github-bot committed May 23, 2023
1 parent 1668a6a commit 1f2af8e
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 0 deletions.
92 changes: 92 additions & 0 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/prefetch.h>

#include <faiss/impl/platform_macros.h>

#ifdef __AVX2__
#include <immintrin.h>
#include <type_traits>
#endif

namespace faiss {

/**************************************************************
Expand Down Expand Up @@ -1010,6 +1017,90 @@ void HNSW::MinimaxHeap::clear() {
nvalid = k = 0;
}

#ifdef __AVX2__
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
assert(k > 0);
static_assert(
std::is_same<storage_idx_t, int32_t>::value,
"This code expects storage_idx_t to be int32_t");

int32_t min_idx = -1;
float min_dis = std::numeric_limits<float>::max();

size_t iii = 0;

__m256i min_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
__m256 min_distances = _mm256_set1_ps(std::numeric_limits<float>::max());
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
__m256i offset = _mm256_set1_epi32(8);

// The baseline version is available in non-AVX2 branch.

// The following loop tracks the rightmost index with the min distance.
// -1 index values are ignored.
const int k8 = (k / 8) * 8;
for (; iii < k8; iii += 8) {
__m256i indices =
_mm256_loadu_si256((const __m256i*)(ids.data() + iii));
__m256 distances = _mm256_loadu_ps(dis.data() + iii);

// This mask filters out -1 values among indices.
__m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);

__m256i dmask = _mm256_castps_si256(
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
__m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));

const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
_mm256_castsi256_ps(current_indices),
_mm256_castsi256_ps(min_indices),
finalmask));

const __m256 min_distances_new =
_mm256_blendv_ps(distances, min_distances, finalmask);

min_indices = min_indices_new;
min_distances = min_distances_new;

current_indices = _mm256_add_epi32(current_indices, offset);
}

// Vectorizing is doable, but is not practical
int32_t vidx8[8];
float vdis8[8];
_mm256_storeu_ps(vdis8, min_distances);
_mm256_storeu_si256((__m256i*)vidx8, min_indices);

for (size_t j = 0; j < 8; j++) {
if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
min_idx = vidx8[j];
min_dis = vdis8[j];
}
}

// process last values. Vectorizing is doable, but is not practical
for (; iii < k; iii++) {
if (ids[iii] != -1 && dis[iii] <= min_dis) {
min_dis = dis[iii];
min_idx = iii;
}
}

if (min_idx == -1) {
return -1;
}

if (vmin_out)
*vmin_out = min_dis;
int ret = ids[min_idx];
ids[min_idx] = -1;
--nvalid;
return ret;
}

#else

// baseline non-vectorized version
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
assert(k > 0);
// returns min. This is an O(n) operation
Expand Down Expand Up @@ -1039,6 +1130,7 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {

return ret;
}
#endif

int HNSW::MinimaxHeap::count_below(float thresh) {
int n_below = 0;
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ set(FAISS_TEST_SRC
test_distances_simd.cpp
test_heap.cpp
test_code_distance.cpp
test_hnsw.cpp
)

add_executable(faiss_test ${FAISS_TEST_SRC})
Expand Down
111 changes: 111 additions & 0 deletions tests/test_hnsw.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>

#include <cstddef>
#include <cstdint>
#include <limits>
#include <random>
#include <unordered_set>
#include <vector>

#include <faiss/impl/HNSW.h>

int reference_pop_min(faiss::HNSW::MinimaxHeap& heap, float* vmin_out) {
assert(heap.k > 0);
// returns min. This is an O(n) operation
int i = heap.k - 1;
while (i >= 0) {
if (heap.ids[i] != -1)
break;
i--;
}
if (i == -1)
return -1;
int imin = i;
float vmin = heap.dis[i];
i--;
while (i >= 0) {
if (heap.ids[i] != -1 && heap.dis[i] < vmin) {
vmin = heap.dis[i];
imin = i;
}
i--;
}
if (vmin_out)
*vmin_out = vmin;
int ret = heap.ids[imin];
heap.ids[imin] = -1;
--heap.nvalid;

return ret;
}

void test_popmin(int heap_size, int amount_to_put) {
// create a heap
faiss::HNSW::MinimaxHeap mm_heap(heap_size);

using storage_idx_t = faiss::HNSW::storage_idx_t;

std::default_random_engine rng(123 + heap_size * amount_to_put);
std::uniform_int_distribution<storage_idx_t> u(0, 65536);
std::uniform_real_distribution<float> uf(0, 1);

// generate random unique indices
std::unordered_set<storage_idx_t> indices;
while (indices.size() < amount_to_put) {
const storage_idx_t index = u(rng);
indices.insert(index);
}

// put ones into the heap
for (const auto index : indices) {
mm_heap.push(index, uf(rng));
}

// clone the heap
faiss::HNSW::MinimaxHeap cloned_mm_heap = mm_heap;

// takes ones out one by one
while (mm_heap.size() > 0) {
// compare heaps
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);

// use the reference pop_min for the cloned heap
float cloned_vmin_dis = std::numeric_limits<float>::quiet_NaN();
storage_idx_t cloned_vmin_idx =
reference_pop_min(cloned_mm_heap, &cloned_vmin_dis);

float vmin_dis = std::numeric_limits<float>::quiet_NaN();
storage_idx_t vmin_idx = mm_heap.pop_min(&vmin_dis);

// compare returns
ASSERT_EQ(vmin_dis, cloned_vmin_dis);
ASSERT_EQ(vmin_idx, cloned_vmin_idx);
}

// compare heaps again
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
}

TEST(HNSW, Test_popmin) {
std::vector<size_t> sizes = {1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32};
for (const size_t size : sizes) {
for (size_t amount = size; amount > 0; amount /= 2) {
test_popmin(size, amount);
}
}
}

0 comments on commit 1f2af8e

Please sign in to comment.