From 7af9538a0b57a8b574e2f74a6f2cb8f5616339cd Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 10 Nov 2022 06:48:10 -0800 Subject: [PATCH] Speedup exhaustive_L2sqr_blas for AVX2, ARM NEON and AVX512 (#2568) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2568 Add a fused kernel for exhaustive_L2sqr_blas() call that combines a computation of dot product and the search for the nearest centroid. As a result, no temporary dot product values are written and read in RAM. Significantly speeds up the training of PQx[1] indices for low-dimensional PQ vectors ( 1, 2, 4, 8 ), and the effect is higher for higher values of [1]. AVX512 provides additional overloads for dimensionality of 12 and 16. The speedup is also beneficial for higher values of pq.cp.max_points_per_centroid (which is 256 by default). Speeds up IVFPQ training as well. AVX512 kernel is not enabled, but I've seen it speeding up the training TWICE versus AVX2 version. So, please feel free to use it by enabling AVX512 manually. Differential Revision: D41166766 fbshipit-source-id: 4db53e95397db6f5f90ca07258f24266cbd1ef9e --- faiss/CMakeLists.txt | 6 + faiss/utils/distances.cpp | 62 +++- faiss/utils/distances_fused/avx512.cpp | 320 ++++++++++++++++++ faiss/utils/distances_fused/avx512.h | 36 ++ .../utils/distances_fused/distances_fused.cpp | 35 ++ faiss/utils/distances_fused/distances_fused.h | 40 +++ faiss/utils/distances_fused/simdlib_based.cpp | 312 +++++++++++++++++ faiss/utils/distances_fused/simdlib_based.h | 32 ++ faiss/utils/simdlib_avx2.h | 51 ++- faiss/utils/simdlib_emulated.h | 87 +++++ faiss/utils/simdlib_neon.h | 70 ++++ 11 files changed, 1046 insertions(+), 5 deletions(-) create mode 100644 faiss/utils/distances_fused/avx512.cpp create mode 100644 faiss/utils/distances_fused/avx512.h create mode 100644 faiss/utils/distances_fused/distances_fused.cpp create mode 100644 faiss/utils/distances_fused/distances_fused.h create mode 100644 faiss/utils/distances_fused/simdlib_based.cpp create mode 100644 faiss/utils/distances_fused/simdlib_based.h diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index fd3ddb30ec..f452e4a7bf 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -86,6 +86,9 @@ set(FAISS_SRC utils/quantize_lut.cpp utils/random.cpp utils/utils.cpp + utils/distances_fused/avx512.cpp + utils/distances_fused/distances_fused.cpp + utils/distances_fused/simdlib_based.cpp ) set(FAISS_HEADERS @@ -187,6 +190,9 @@ set(FAISS_HEADERS utils/simdlib_emulated.h utils/simdlib_neon.h utils/utils.h + utils/distances_fused/avx512.h + utils/distances_fused/distances_fused.h + utils/distances_fused/simdlib_based.h ) if(NOT WIN32) diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index 037f86b7af..80694c3bb6 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -26,6 +26,8 @@ #include #include +#include + #ifndef FINTEGER #define FINTEGER long #endif @@ -229,7 +231,7 @@ void exhaustive_inner_product_blas( // distance correction is an operator that can be applied to transform // the distances template -void exhaustive_L2sqr_blas( +void exhaustive_L2sqr_blas_default_impl( const float* x, const float* y, size_t d, @@ -311,10 +313,20 @@ void exhaustive_L2sqr_blas( } } +template +void exhaustive_L2sqr_blas( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + ResultHandler& res, + const float* y_norms = nullptr) { + exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res); +} + #ifdef __AVX2__ -// an override for AVX2 if only a single closest point is needed. -template <> -void exhaustive_L2sqr_blas>>( +void exhaustive_L2sqr_blas_cmax_avx2( const float* x, const float* y, size_t d, @@ -513,11 +525,53 @@ void exhaustive_L2sqr_blas>>( res.add_result(i, current_min_distance, current_min_index); } } + // Does nothing for SingleBestResultHandler, but + // keeping the call for the consistency. + res.end_multiple(); InterruptCallback::check(); } } #endif +// an override if only a single closest point is needed +template <> +void exhaustive_L2sqr_blas>>( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + SingleBestResultHandler>& res, + const float* y_norms) { +#if defined(__AVX2__) + // use a faster fused kernel if available + if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) { + // the kernel is available and it is complete, we're done. + return; + } + + // run the specialized AVX2 implementation + exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms); + +#elif defined(__aarch64__) + // use a faster fused kernel if available + if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) { + // the kernel is available and it is complete, we're done. + return; + } + + // run the default implementation + exhaustive_L2sqr_blas_default_impl< + SingleBestResultHandler>>( + x, y, d, nx, ny, res, y_norms); +#else + // run the default implementation + exhaustive_L2sqr_blas_default_impl< + SingleBestResultHandler>>( + x, y, d, nx, ny, res, y_norms); +#endif +} + template void knn_L2sqr_select( const float* x, diff --git a/faiss/utils/distances_fused/avx512.cpp b/faiss/utils/distances_fused/avx512.cpp new file mode 100644 index 0000000000..5c7906e79e --- /dev/null +++ b/faiss/utils/distances_fused/avx512.cpp @@ -0,0 +1,320 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#ifdef __AVX512__ + +#include + +namespace faiss { + +namespace { + +// It makes sense to like to overload certain cases because the further +// kernels are in need of AVX512 registers. So, let's tell compiler +// not to waste registers on a bit faster code, if needed. +template +float l2_sqr(const float* const x) { + // compiler should be smart enough to handle that + float output = x[0] * x[0]; + for (size_t i = 1; i < DIM; i++) { + output += x[i] * x[i]; + } + + return output; +} + +template <> +float l2_sqr<4>(const float* const x) { + __m128 v = _mm_loadu_ps(x); + __m128 v2 = _mm_mul_ps(v, v); + v2 = _mm_hadd_ps(v2, v2); + v2 = _mm_hadd_ps(v2, v2); + + return _mm_cvtss_f32(v2); +} + +template +float dot_product( + const float* const __restrict x, + const float* const __restrict y) { + // compiler should be smart enough to handle that + float output = x[0] * y[0]; + for (size_t i = 1; i < DIM; i++) { + output += x[i] * y[i]; + } + + return output; +} + +// The kernel for low dimensionality vectors. +// Finds the closest one from y for every given NX_POINTS_PER_LOOP points from x +// +// DIM is the dimensionality of the data +// NX_POINTS_PER_LOOP is the number of x points that get processed +// simultaneously. +// NY_POINTS_PER_LOOP is the number of y points that get processed +// simultaneously. +template +void kernel( + const float* const __restrict x, + const float* const __restrict y, + const float* const __restrict y_transposed, + size_t ny, + SingleBestResultHandler>& res, + const float* __restrict y_norms, + size_t i) { + const size_t ny_p = + (ny / (16 * NY_POINTS_PER_LOOP)) * (16 * NY_POINTS_PER_LOOP); + + // compute + const float* const __restrict xd_0 = x + i * DIM; + + // prefetch the next point + _mm_prefetch(xd_0 + DIM * sizeof(float), _MM_HINT_NTA); + + // load a single point from x + // load -2 * value + __m512 x_i[NX_POINTS_PER_LOOP][DIM]; + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + for (size_t dd = 0; dd < DIM; dd++) { + x_i[nx_k][dd] = _mm512_set1_ps(-2 * *(xd_0 + nx_k * DIM + dd)); + } + } + + // compute x_norm + float x_norm_i[NX_POINTS_PER_LOOP]; + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + x_norm_i[nx_k] = l2_sqr(xd_0 + nx_k * DIM); + } + + // distances and indices + __m512 min_distances_i[NX_POINTS_PER_LOOP]; + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + min_distances_i[nx_k] = + _mm512_set1_ps(res.dis_tab[i + nx_k] - x_norm_i[nx_k]); + } + + __m512i min_indices_i[NX_POINTS_PER_LOOP]; + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + min_indices_i[nx_k] = _mm512_set1_epi32(0); + } + + // + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_delta = _mm512_set1_epi32(16); + + // main loop + size_t j = 0; + for (; j < ny_p; j += NY_POINTS_PER_LOOP * 16) { + // compute dot products for NX_POINTS from x and NY_POINTS from y + // technically, we're multiplying -2x and y + __m512 dp_i[NX_POINTS_PER_LOOP][NY_POINTS_PER_LOOP]; + + // DIM 0 that uses MUL + for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) { + __m512 y_i = _mm512_loadu_ps(y_transposed + j + ny_k * 16 + ny * 0); + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + dp_i[nx_k][ny_k] = _mm512_mul_ps(x_i[nx_k][0], y_i); + } + } + + // other DIMs that use FMA + for (size_t dd = 1; dd < DIM; dd++) { + for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) { + __m512 y_i = + _mm512_loadu_ps(y_transposed + j + ny_k * 16 + ny * dd); + + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + dp_i[nx_k][ny_k] = _mm512_fmadd_ps( + x_i[nx_k][dd], y_i, dp_i[nx_k][ny_k]); + } + } + } + + // compute y^2 - 2 * (x,y) + for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) { + __m512 y_l2_sqr = _mm512_loadu_ps(y_norms + j + ny_k * 16); + + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + dp_i[nx_k][ny_k] = _mm512_add_ps(dp_i[nx_k][ny_k], y_l2_sqr); + } + } + + // do the comparisons and alter the min indices + for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) { + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + const __mmask16 comparison = _mm512_cmp_ps_mask( + min_distances_i[nx_k], dp_i[nx_k][ny_k], _CMP_LE_OS); + min_distances_i[nx_k] = _mm512_mask_blend_ps( + comparison, dp_i[nx_k][ny_k], min_distances_i[nx_k]); + min_indices_i[nx_k] = _mm512_castps_si512(_mm512_mask_blend_ps( + comparison, + _mm512_castsi512_ps(current_indices), + _mm512_castsi512_ps(min_indices_i[nx_k]))); + } + + current_indices = _mm512_add_epi32(current_indices, indices_delta); + } + } + + // dump values and find the minimum distance / minimum index + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + float min_distances_scalar[16]; + uint32_t min_indices_scalar[16]; + _mm512_storeu_ps(min_distances_scalar, min_distances_i[nx_k]); + _mm512_storeu_si512( + (__m512i*)(min_indices_scalar), min_indices_i[nx_k]); + + float current_min_distance = res.dis_tab[i + nx_k]; + uint32_t current_min_index = res.ids_tab[i + nx_k]; + + // This unusual comparison is needed to maintain the behavior + // of the original implementation: if two indices are + // represented with equal distance values, then + // the index with the min value is returned. + for (size_t jv = 0; jv < 16; jv++) { + // add missing x_norms[i] + float distance_candidate = + min_distances_scalar[jv] + x_norm_i[nx_k]; + + // negative values can occur for identical vectors + // due to roundoff errors. + if (distance_candidate < 0) + distance_candidate = 0; + + const int64_t index_candidate = min_indices_scalar[jv]; + + if (current_min_distance > distance_candidate) { + current_min_distance = distance_candidate; + current_min_index = index_candidate; + } else if ( + current_min_distance == distance_candidate && + current_min_index > index_candidate) { + current_min_index = index_candidate; + } + } + + // process leftovers + for (size_t j0 = j; j0 < ny; j0++) { + const float dp = + dot_product(x + (i + nx_k) * DIM, y + j0 * DIM); + float dis = x_norm_i[nx_k] + y_norms[j0] - 2 * dp; + // negative values can occur for identical vectors + // due to roundoff errors. + if (dis < 0) { + dis = 0; + } + + if (current_min_distance > dis) { + current_min_distance = dis; + current_min_index = j0; + } + } + + // done + res.add_result(i + nx_k, current_min_distance, current_min_index); + } +} + +template +void exhaustive_L2sqr_fused_cmax( + const float* const __restrict x, + const float* const __restrict y, + size_t nx, + size_t ny, + SingleBestResultHandler>& res, + const float* __restrict y_norms) { + // BLAS does not like empty matrices + if (nx == 0 || ny == 0) { + return; + } + + // compute norms for y + std::unique_ptr del2; + if (!y_norms) { + float* y_norms2 = new float[ny]; + del2.reset(y_norms2); + + for (size_t i = 0; i < ny; i++) { + y_norms2[i] = l2_sqr(y + i * DIM); + } + + y_norms = y_norms2; + } + + // initialize res + res.begin_multiple(0, nx); + + // transpose y + std::vector y_transposed(DIM * ny); + for (size_t j = 0; j < DIM; j++) { + for (size_t i = 0; i < ny; i++) { + y_transposed[j * ny + i] = y[j + i * DIM]; + } + } + + const size_t nx_p = (nx / NX_POINTS_PER_LOOP) * NX_POINTS_PER_LOOP; + // the main loop. +#pragma omp parallel for schedule(dynamic) + for (size_t i = 0; i < nx_p; i += NX_POINTS_PER_LOOP) { + kernel( + x, y, y_transposed.data(), ny, res, y_norms, i); + } + + for (size_t i = nx_p; i < nx; i++) { + kernel( + x, y, y_transposed.data(), ny, res, y_norms, i); + } + + // Does nothing for SingleBestResultHandler, but + // keeping the call for the consistency. + res.end_multiple(); + InterruptCallback::check(); +} + +} // namespace + +bool exhaustive_L2sqr_fused_cmax_AVX512( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + SingleBestResultHandler>& res, + const float* y_norms) { + // process only cases with certain dimensionalities + if (d == 1) { + exhaustive_L2sqr_fused_cmax<1, 4, 2>(x, y, nx, ny, res, y_norms); + return true; + } else if (d == 2) { + exhaustive_L2sqr_fused_cmax<2, 4, 2>(x, y, nx, ny, res, y_norms); + return true; + } else if (d == 4) { + exhaustive_L2sqr_fused_cmax<4, 4, 1>(x, y, nx, ny, res, y_norms); + return true; + } else if (d == 8) { + exhaustive_L2sqr_fused_cmax<8, 2, 1>(x, y, nx, ny, res, y_norms); + return true; + } else if (d == 12) { + exhaustive_L2sqr_fused_cmax<12, 1, 1>(x, y, nx, ny, res, y_norms); + return true; + } else if (d == 16) { + exhaustive_L2sqr_fused_cmax<16, 1, 1>(x, y, nx, ny, res, y_norms); + return true; + } + + return false; +} + +} // namespace faiss + +#endif diff --git a/faiss/utils/distances_fused/avx512.h b/faiss/utils/distances_fused/avx512.h new file mode 100644 index 0000000000..d730e3b61c --- /dev/null +++ b/faiss/utils/distances_fused/avx512.h @@ -0,0 +1,36 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// AVX512 might be not used, but this version provides ~2x speedup +// over AVX2 kernel, say, for training PQx10 or PQx12, and speeds up +// additional cases with larger dimensionalities. + +#pragma once + +#include +#include + +#include + +#ifdef __AVX512__ + +namespace faiss { + +// Returns true if the fused kernel is available and the data was processed. +// Returns false if the fused kernel is not available. +bool exhaustive_L2sqr_fused_cmax_AVX512( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + SingleBestResultHandler>& res, + const float* y_norms); + +} // namespace faiss + +#endif diff --git a/faiss/utils/distances_fused/distances_fused.cpp b/faiss/utils/distances_fused/distances_fused.cpp new file mode 100644 index 0000000000..a676e17144 --- /dev/null +++ b/faiss/utils/distances_fused/distances_fused.cpp @@ -0,0 +1,35 @@ +#include + +#include + +#include +#include + +namespace faiss { + +bool exhaustive_L2sqr_fused_cmax( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + SingleBestResultHandler>& res, + const float* y_norms) { + if (nx == 0 || ny == 0) { + // nothing to do + return true; + } + +#ifdef __AVX512__ + // avx512 kernel + return exhaustive_L2sqr_fused_cmax_AVX512(x, y, d, nx, ny, res, y_norms); +#elif defined(__AVX2__) || defined(__aarch64__) + // avx2 or arm neon kernel + return exhaustive_L2sqr_fused_cmax_simdlib(x, y, d, nx, ny, res, y_norms); +#else + // not supported, please use a general-purpose kernel + return false; +#endif +} + +} // namespace faiss diff --git a/faiss/utils/distances_fused/distances_fused.h b/faiss/utils/distances_fused/distances_fused.h new file mode 100644 index 0000000000..e6e35c209e --- /dev/null +++ b/faiss/utils/distances_fused/distances_fused.h @@ -0,0 +1,40 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// This file contains a fused kernel that combines distance computation +// and the search for the CLOSEST point. Currently, this is done for small +// dimensionality vectors when it is beneficial to avoid storing temporary +// dot product information in RAM. This is particularly effective +// when training PQx10 or PQx12 with the default parameters. +// +// InterruptCallback::check() is not used, because it is assumed that the +// kernel takes a little time because of a tiny dimensionality. +// +// Later on, similar optimization can be implemented for large size vectors, +// but a different kernel is needed. +// + +#pragma once + +#include + +#include + +namespace faiss { + +// Returns true if the fused kernel is available and the data was processed. +// Returns false if the fused kernel is not available. +bool exhaustive_L2sqr_fused_cmax( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + SingleBestResultHandler>& res, + const float* y_norms); + +} // namespace faiss diff --git a/faiss/utils/distances_fused/simdlib_based.cpp b/faiss/utils/distances_fused/simdlib_based.cpp new file mode 100644 index 0000000000..434f0dfb3c --- /dev/null +++ b/faiss/utils/distances_fused/simdlib_based.cpp @@ -0,0 +1,312 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#if defined(__AVX2__) || defined(__aarch64__) + +#include + +#if defined(__AVX2__) +#include +#endif + +namespace faiss { + +namespace { + +// It makes sense to like to overload certain cases because the further +// kernels are in need of registers. So, let's tell compiler +// not to waste registers on a bit faster code, if needed. +template +float l2_sqr(const float* const x) { + // compiler should be smart enough to handle that + float output = x[0] * x[0]; + for (size_t i = 1; i < DIM; i++) { + output += x[i] * x[i]; + } + + return output; +} + +template +float dot_product( + const float* const __restrict x, + const float* const __restrict y) { + // compiler should be smart enough to handle that + float output = x[0] * y[0]; + for (size_t i = 1; i < DIM; i++) { + output += x[i] * y[i]; + } + + return output; +} + +// The kernel for low dimensionality vectors. +// Finds the closest one from y for every given NX_POINTS_PER_LOOP points from x +// +// DIM is the dimensionality of the data +// NX_POINTS_PER_LOOP is the number of x points that get processed +// simultaneously. +// NY_POINTS_PER_LOOP is the number of y points that get processed +// simultaneously. +template +void kernel( + const float* const __restrict x, + const float* const __restrict y, + const float* const __restrict y_transposed, + const size_t ny, + SingleBestResultHandler>& res, + const float* __restrict y_norms, + const size_t i) { + const size_t ny_p = + (ny / (8 * NY_POINTS_PER_LOOP)) * (8 * NY_POINTS_PER_LOOP); + + // compute + const float* const __restrict xd_0 = x + i * DIM; + + // prefetch the next point +#if defined(__AVX2__) + _mm_prefetch(xd_0 + DIM * sizeof(float), _MM_HINT_NTA); +#endif + + // load a single point from x + // load -2 * value + simd8float32 x_i[NX_POINTS_PER_LOOP][DIM]; + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + for (size_t dd = 0; dd < DIM; dd++) { + x_i[nx_k][dd] = simd8float32(-2 * *(xd_0 + nx_k * DIM + dd)); + } + } + + // compute x_norm + float x_norm_i[NX_POINTS_PER_LOOP]; + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + x_norm_i[nx_k] = l2_sqr(xd_0 + nx_k * DIM); + } + + // distances and indices + simd8float32 min_distances_i[NX_POINTS_PER_LOOP]; + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + min_distances_i[nx_k] = + simd8float32(res.dis_tab[i + nx_k] - x_norm_i[nx_k]); + } + + simd8uint32 min_indices_i[NX_POINTS_PER_LOOP]; + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + min_indices_i[nx_k] = simd8uint32((uint32_t)0); + } + + // + simd8uint32 current_indices = simd8uint32(0, 1, 2, 3, 4, 5, 6, 7); + const simd8uint32 indices_delta = simd8uint32(8); + + // main loop + size_t j = 0; + for (; j < ny_p; j += NY_POINTS_PER_LOOP * 8) { + // compute dot products for NX_POINTS from x and NY_POINTS from y + // technically, we're multiplying -2x and y + simd8float32 dp_i[NX_POINTS_PER_LOOP][NY_POINTS_PER_LOOP]; + + // DIM 0 that uses MUL + for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) { + simd8float32 y_i = + simd8float32(y_transposed + j + ny_k * 8 + ny * 0); + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + dp_i[nx_k][ny_k] = x_i[nx_k][0] * y_i; + } + } + + // other DIMs that use FMA + for (size_t dd = 1; dd < DIM; dd++) { + for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) { + simd8float32 y_i = + simd8float32(y_transposed + j + ny_k * 8 + ny * dd); + + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + dp_i[nx_k][ny_k] = + fmadd(x_i[nx_k][dd], y_i, dp_i[nx_k][ny_k]); + } + } + } + + // compute y^2 + (-2x,y) + for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) { + simd8float32 y_l2_sqr = simd8float32(y_norms + j + ny_k * 8); + + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + dp_i[nx_k][ny_k] = dp_i[nx_k][ny_k] + y_l2_sqr; + } + } + + // do the comparisons and alter the min indices + for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) { + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + // cmpps + const simd8float32 comparison = + cmple(min_distances_i[nx_k], dp_i[nx_k][ny_k]); + min_distances_i[nx_k] = blend( + dp_i[nx_k][ny_k], min_distances_i[nx_k], comparison); + min_indices_i[nx_k] = + blend(current_indices, min_indices_i[nx_k], comparison); + } + + current_indices = current_indices + indices_delta; + } + } + + // dump values and find the minimum distance / minimum index + for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) { + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + + min_distances_i[nx_k].storeu(min_distances_scalar); + min_indices_i[nx_k].storeu(min_indices_scalar); + + float current_min_distance = res.dis_tab[i + nx_k]; + uint32_t current_min_index = res.ids_tab[i + nx_k]; + + // This unusual comparison is needed to maintain the behavior + // of the original implementation: if two indices are + // represented with equal distance values, then + // the index with the min value is returned. + for (size_t jv = 0; jv < 8; jv++) { + // add missing x_norms[i] + float distance_candidate = + min_distances_scalar[jv] + x_norm_i[nx_k]; + + // negative values can occur for identical vectors + // due to roundoff errors. + if (distance_candidate < 0) { + distance_candidate = 0; + } + + const int64_t index_candidate = min_indices_scalar[jv]; + + if (current_min_distance > distance_candidate) { + current_min_distance = distance_candidate; + current_min_index = index_candidate; + } else if ( + current_min_distance == distance_candidate && + current_min_index > index_candidate) { + current_min_index = index_candidate; + } + } + + // process leftovers + for (size_t j0 = j; j0 < ny; j0++) { + const float dp = + dot_product(x + (i + nx_k) * DIM, y + j0 * DIM); + float dis = x_norm_i[nx_k] + y_norms[j0] - 2 * dp; + // negative values can occur for identical vectors + // due to roundoff errors. + if (dis < 0) { + dis = 0; + } + + if (current_min_distance > dis) { + current_min_distance = dis; + current_min_index = j0; + } + } + + // done + res.add_result(i + nx_k, current_min_distance, current_min_index); + } +} + +template +void exhaustive_L2sqr_fused_cmax( + const float* const __restrict x, + const float* const __restrict y, + size_t nx, + size_t ny, + SingleBestResultHandler>& res, + const float* __restrict y_norms) { + // BLAS does not like empty matrices + if (nx == 0 || ny == 0) { + return; + } + + // compute norms for y + std::unique_ptr del2; + if (!y_norms) { + float* y_norms2 = new float[ny]; + del2.reset(y_norms2); + + for (size_t i = 0; i < ny; i++) { + y_norms2[i] = l2_sqr(y + i * DIM); + } + + y_norms = y_norms2; + } + + // initialize res + res.begin_multiple(0, nx); + + // transpose y + std::vector y_transposed(DIM * ny); + for (size_t j = 0; j < DIM; j++) { + for (size_t i = 0; i < ny; i++) { + y_transposed[j * ny + i] = y[j + i * DIM]; + } + } + + const size_t nx_p = (nx / NX_POINTS_PER_LOOP) * NX_POINTS_PER_LOOP; + // the main loop. +#pragma omp parallel for schedule(dynamic) + for (size_t i = 0; i < nx_p; i += NX_POINTS_PER_LOOP) { + kernel( + x, y, y_transposed.data(), ny, res, y_norms, i); + } + + for (size_t i = nx_p; i < nx; i++) { + kernel( + x, y, y_transposed.data(), ny, res, y_norms, i); + } + + // Does nothing for SingleBestResultHandler, but + // keeping the call for the consistency. + res.end_multiple(); + InterruptCallback::check(); +} + +} // namespace + +bool exhaustive_L2sqr_fused_cmax_simdlib( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + SingleBestResultHandler>& res, + const float* y_norms) { + // Process only cases with certain dimensionalities. + // An acceptable dimensionality value is limited by the number of + // available registers. + if (d == 1) { + exhaustive_L2sqr_fused_cmax<1, 2, 2>(x, y, nx, ny, res, y_norms); + return true; + } else if (d == 2) { + exhaustive_L2sqr_fused_cmax<2, 2, 2>(x, y, nx, ny, res, y_norms); + return true; + } else if (d == 4) { + exhaustive_L2sqr_fused_cmax<4, 2, 1>(x, y, nx, ny, res, y_norms); + return true; + } else if (d == 8) { + exhaustive_L2sqr_fused_cmax<8, 1, 1>(x, y, nx, ny, res, y_norms); + return true; + } + + return false; +} + +} // namespace faiss + +#endif diff --git a/faiss/utils/distances_fused/simdlib_based.h b/faiss/utils/distances_fused/simdlib_based.h new file mode 100644 index 0000000000..b60da7b193 --- /dev/null +++ b/faiss/utils/distances_fused/simdlib_based.h @@ -0,0 +1,32 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +#if defined(__AVX2__) || defined(__aarch64__) + +namespace faiss { + +// Returns true if the fused kernel is available and the data was processed. +// Returns false if the fused kernel is not available. +bool exhaustive_L2sqr_fused_cmax_simdlib( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + SingleBestResultHandler>& res, + const float* y_norms); + +} // namespace faiss + +#endif diff --git a/faiss/utils/simdlib_avx2.h b/faiss/utils/simdlib_avx2.h index 6eee3fa0e1..b108c6bd8a 100644 --- a/faiss/utils/simdlib_avx2.h +++ b/faiss/utils/simdlib_avx2.h @@ -359,6 +359,25 @@ struct simd8uint32 : simd256bit { explicit simd8uint32(const uint8_t* x) : simd256bit((const void*)x) {} + explicit simd8uint32( + uint32_t u0, + uint32_t u1, + uint32_t u2, + uint32_t u3, + uint32_t u4, + uint32_t u5, + uint32_t u6, + uint32_t u7) + : simd256bit(_mm256_setr_epi32(u0, u1, u2, u3, u4, u5, u6, u7)) {} + + simd8uint32 operator+(simd8uint32 other) const { + return simd8uint32(_mm256_add_epi32(i, other.i)); + } + + simd8uint32 operator-(simd8uint32 other) const { + return simd8uint32(_mm256_sub_epi32(i, other.i)); + } + std::string elements_to_string(const char* fmt) const { uint32_t bytes[8]; storeu((void*)bytes); @@ -394,7 +413,18 @@ struct simd8float32 : simd256bit { explicit simd8float32(float x) : simd256bit(_mm256_set1_ps(x)) {} - explicit simd8float32(const float* x) : simd256bit(_mm256_load_ps(x)) {} + explicit simd8float32(const float* x) : simd256bit(_mm256_loadu_ps(x)) {} + + explicit simd8float32( + float f0, + float f1, + float f2, + float f3, + float f4, + float f5, + float f6, + float f7) + : simd256bit(_mm256_setr_ps(f0, f1, f2, f3, f4, f5, f6, f7)) {} simd8float32 operator*(simd8float32 other) const { return simd8float32(_mm256_mul_ps(f, other.f)); @@ -439,6 +469,25 @@ inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) { return simd8float32(_mm256_fmadd_ps(a.f, b.f, c.f)); } +inline simd8float32 cmple(simd8float32 a, simd8float32 b) { + return simd8float32(_mm256_cmp_ps(a.f, b.f, _CMP_LE_OS)); +} + +inline simd8float32 blend( + simd8float32 a, + simd8float32 b, + simd8float32 comparison) { + return simd8float32(_mm256_blendv_ps(a.f, b.f, comparison.f)); +} + +inline simd8uint32 blend( + simd8uint32 a, + simd8uint32 b, + simd8float32 comparison) { + return simd8uint32(_mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(a.i), _mm256_castsi256_ps(b.i), comparison.f))); +} + namespace { // get even float32's of a and b, interleaved diff --git a/faiss/utils/simdlib_emulated.h b/faiss/utils/simdlib_emulated.h index 9267ef1b46..50b38eb814 100644 --- a/faiss/utils/simdlib_emulated.h +++ b/faiss/utils/simdlib_emulated.h @@ -440,6 +440,41 @@ struct simd8uint32 : simd256bit { explicit simd8uint32(const uint32_t* x) : simd256bit((const void*)x) {} + explicit simd8uint32( + uint32_t u0, + uint32_t u1, + uint32_t u2, + uint32_t u3, + uint32_t u4, + uint32_t u5, + uint32_t u6, + uint32_t u7) { + u32[0] = u0; + u32[1] = u1; + u32[2] = u2; + u32[3] = u3; + u32[4] = u4; + u32[5] = u5; + u32[6] = u6; + u32[7] = u7; + } + + simd8uint32 operator+(simd8uint32 other) const { + simd8uint32 result; + for (int i = 0; i < 8; i++) { + result.u32[i] = u32[i] + other.u32[i]; + } + return result; + } + + simd8uint32 operator-(simd8uint32 other) const { + simd8uint32 result; + for (int i = 0; i < 8; i++) { + result.u32[i] = u32[i] - other.u32[i]; + } + return result; + } + std::string elements_to_string(const char* fmt) const { char res[1000], *ptr = res; for (int i = 0; i < 8; i++) { @@ -484,6 +519,25 @@ struct simd8float32 : simd256bit { } } + explicit simd8float32( + float f0, + float f1, + float f2, + float f3, + float f4, + float f5, + float f6, + float f7) { + f32[0] = f0; + f32[1] = f1; + f32[2] = f2; + f32[3] = f3; + f32[4] = f4; + f32[5] = f5; + f32[6] = f6; + f32[7] = f7; + } + template static simd8float32 binary_func( const simd8float32& a, @@ -650,6 +704,39 @@ simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) { return c; } +inline simd8float32 cmple(simd8float32 a, simd8float32 b) { + simd8float32 result; + for (size_t j = 0; j < 8; j++) { + result.u32[j] = (a.f32[j] <= b.f32[j]) ? 1 : 0; + } + + return result; +} + +inline simd8float32 blend( + simd8float32 a, + simd8float32 b, + simd8float32 comparison) { + simd8float32 result; + for (size_t j = 0; j < 8; j++) { + result.f32[j] = (comparison.u32[j] == 0) ? a.f32[j] : b.f32[j]; + } + + return result; +} + +inline simd8uint32 blend( + simd8uint32 a, + simd8uint32 b, + simd8float32 comparison) { + simd8uint32 result; + for (size_t j = 0; j < 8; j++) { + result.u32[j] = (comparison.u32[j] == 0) ? a.u32[j] : b.u32[j]; + } + + return result; +} + } // namespace } // namespace faiss diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h index 737e948927..7c4c2e14d2 100644 --- a/faiss/utils/simdlib_neon.h +++ b/faiss/utils/simdlib_neon.h @@ -671,6 +671,32 @@ struct simd8uint32 { explicit simd8uint32(const uint8_t* x) : simd8uint32(simd32uint8(x)) {} + explicit simd8uint32( + uint32_t u0, + uint32_t u1, + uint32_t u2, + uint32_t u3, + uint32_t u4, + uint32_t u5, + uint32_t u6, + uint32_t u7) { + uint32_t alignas(16) temp[8] = {u0, u1, u2, u3, u4, u5, u6, u7}; + data.val[0] = vld1q_u32(temp); + data.val[1] = vld1q_u32(temp + 4); + } + + simd8uint32 operator+(simd8uint32 other) const { + return simd8uint32{uint32x4x2_t{ + vaddq_u32(data.val[0], other.data.val[0]), + vaddq_u32(data.val[1], other.data.val[1])}}; + } + + simd8uint32 operator-(simd8uint32 other) const { + return simd8uint32{uint32x4x2_t{ + vsubq_u32(data.val[0], other.data.val[0]), + vsubq_u32(data.val[1], other.data.val[1])}}; + } + void clear() { detail::simdlib::set1(data, &vdupq_n_u32, static_cast(0)); } @@ -734,6 +760,20 @@ struct simd8float32 { explicit simd8float32(const float* x) : data{vld1q_f32(x), vld1q_f32(x + 4)} {} + explicit simd8float32( + float f0, + float f1, + float f2, + float f3, + float f4, + float f5, + float f6, + float f7) { + float alignas(16) temp[8] = {f0, f1, f2, f3, f4, f5, f6, f7}; + data.val[0] = vld1q_f32(temp); + data.val[1] = vld1q_f32(temp + 4); + } + void clear() { detail::simdlib::set1(data, &vdupq_n_f32, 0.f); } @@ -806,6 +846,36 @@ inline simd8float32 fmadd( vfmaq_f32(c.data.val[1], a.data.val[1], b.data.val[1])}}; } +inline simd8float32 cmple(simd8float32 a, simd8float32 b) { + return simd8float32{float32x4x2_t{ + vcleq_f32(a.data.val[0], b.data.val[0]), + vcleq_f32(a.data.val[1], b.data.val[1])}}; +} + +inline simd8float32 blend( + simd8float32 a, + simd8float32 b, + simd8float32 comparison) { + return simd8float32{float32x4x2_t{ + vbslq_f32(comparison.data.val[0], a.data.val[0], b.data.val[0]), + vbslq_f32(comparison.data.val[1], a.data.val[1], b.data.val[1])}}; +} + +inline simd8uint32 blend( + simd8uint32 a, + simd8uint32 b, + simd8float32 comparison) { + return simd8float32{uint32x4x2_t{ + vbslq_u32( + vreinterpretq_u32_f32(comparison).data.val[0], + a.data.val[0], + b.data.val[0]), + vbslq_u32( + vreinterpretq_u32_f32(comparison).data.val[1], + a.data.val[1], + b.data.val[1])}}; +} + namespace { // get even float32's of a and b, interleaved