Skip to content

Commit

Permalink
Speedup exhaustive_L2sqr_blas for AVX2, ARM NEON and AVX512 (facebook…
Browse files Browse the repository at this point in the history
…research#2568)

Summary:
Pull Request resolved: facebookresearch#2568

Add a fused kernel for exhaustive_L2sqr_blas() call that combines a computation of dot product and the search for the nearest centroid. As a result, no temporary dot product values are written and read in RAM.

Speeds up the training of PQx[1] indices for dsub = 1, 2, 4, 8, and the effect is higher for higher values of [1]. AVX512 version provides additional overloads for dsub = 12, 16.

The speedup is also beneficial for higher values of pq.cp.max_points_per_centroid (which is 256 by default).

Speeds up IVFPQ training as well.

AVX512 kernel is not enabled, but I've seen it speeding up the training TWICE versus AVX2 version. So, please feel free to use it by enabling AVX512 manually.

Reviewed By: mdouze

Differential Revision: D41166766

fbshipit-source-id: 9ce681ef360daea11c3aa411fc19c415b6896b3c
  • Loading branch information
Alexandr Guzhva authored and facebook-github-bot committed Nov 14, 2022
1 parent ab13122 commit 484cfb2
Show file tree
Hide file tree
Showing 13 changed files with 1,201 additions and 5 deletions.
6 changes: 6 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ set(FAISS_SRC
utils/quantize_lut.cpp
utils/random.cpp
utils/utils.cpp
utils/distances_fused/avx512.cpp
utils/distances_fused/distances_fused.cpp
utils/distances_fused/simdlib_based.cpp
)

set(FAISS_HEADERS
Expand Down Expand Up @@ -187,6 +190,9 @@ set(FAISS_HEADERS
utils/simdlib_emulated.h
utils/simdlib_neon.h
utils/utils.h
utils/distances_fused/avx512.h
utils/distances_fused/distances_fused.h
utils/distances_fused/simdlib_based.h
)

if(NOT WIN32)
Expand Down
62 changes: 58 additions & 4 deletions faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/ResultHandler.h>

#include <faiss/utils/distances_fused/distances_fused.h>

#ifndef FINTEGER
#define FINTEGER long
#endif
Expand Down Expand Up @@ -229,7 +231,7 @@ void exhaustive_inner_product_blas(
// distance correction is an operator that can be applied to transform
// the distances
template <class ResultHandler>
void exhaustive_L2sqr_blas(
void exhaustive_L2sqr_blas_default_impl(
const float* x,
const float* y,
size_t d,
Expand Down Expand Up @@ -311,10 +313,20 @@ void exhaustive_L2sqr_blas(
}
}

template <class ResultHandler>
void exhaustive_L2sqr_blas(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
ResultHandler& res,
const float* y_norms = nullptr) {
exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
}

#ifdef __AVX2__
// an override for AVX2 if only a single closest point is needed.
template <>
void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
void exhaustive_L2sqr_blas_cmax_avx2(
const float* x,
const float* y,
size_t d,
Expand Down Expand Up @@ -513,11 +525,53 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
res.add_result(i, current_min_distance, current_min_index);
}
}
// Does nothing for SingleBestResultHandler, but
// keeping the call for the consistency.
res.end_multiple();
InterruptCallback::check();
}
}
#endif

// an override if only a single closest point is needed
template <>
void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
SingleBestResultHandler<CMax<float, int64_t>>& res,
const float* y_norms) {
#if defined(__AVX2__)
// use a faster fused kernel if available
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
// the kernel is available and it is complete, we're done.
return;
}

// run the specialized AVX2 implementation
exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms);

#elif defined(__aarch64__)
// use a faster fused kernel if available
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
// the kernel is available and it is complete, we're done.
return;
}

// run the default implementation
exhaustive_L2sqr_blas_default_impl<
SingleBestResultHandler<CMax<float, int64_t>>>(
x, y, d, nx, ny, res, y_norms);
#else
// run the default implementation
exhaustive_L2sqr_blas_default_impl<
SingleBestResultHandler<CMax<float, int64_t>>>(
x, y, d, nx, ny, res, y_norms);
#endif
}

template <class ResultHandler>
void knn_L2sqr_select(
const float* x,
Expand Down
Loading

0 comments on commit 484cfb2

Please sign in to comment.