diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt
index fd3ddb30ec..f452e4a7bf 100644
--- a/faiss/CMakeLists.txt
+++ b/faiss/CMakeLists.txt
@@ -86,6 +86,9 @@ set(FAISS_SRC
   utils/quantize_lut.cpp
   utils/random.cpp
   utils/utils.cpp
+  utils/distances_fused/avx512.cpp
+  utils/distances_fused/distances_fused.cpp
+  utils/distances_fused/simdlib_based.cpp
 )
 
 set(FAISS_HEADERS
@@ -187,6 +190,9 @@ set(FAISS_HEADERS
   utils/simdlib_emulated.h
   utils/simdlib_neon.h
   utils/utils.h
+  utils/distances_fused/avx512.h
+  utils/distances_fused/distances_fused.h
+  utils/distances_fused/simdlib_based.h
 )
 
 if(NOT WIN32)
diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp
index 037f86b7af..80694c3bb6 100644
--- a/faiss/utils/distances.cpp
+++ b/faiss/utils/distances.cpp
@@ -26,6 +26,8 @@
 #include <faiss/impl/FaissAssert.h>
 #include <faiss/impl/ResultHandler.h>
 
+#include <faiss/utils/distances_fused/distances_fused.h>
+
 #ifndef FINTEGER
 #define FINTEGER long
 #endif
@@ -229,7 +231,7 @@ void exhaustive_inner_product_blas(
 
 // distance correction is an operator that can be applied to transform
 // the distances
 template <class ResultHandler>
-void exhaustive_L2sqr_blas(
+void exhaustive_L2sqr_blas_default_impl(
         const float* x,
         const float* y,
         size_t d,
         size_t nx,
         size_t ny,
         ResultHandler& res,
         const float* y_norms = nullptr) {
@@ -311,10 +313,20 @@
     }
 }
 
+template <class ResultHandler>
+void exhaustive_L2sqr_blas(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        ResultHandler& res,
+        const float* y_norms = nullptr) {
+    exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res, y_norms);
+}
+
 #ifdef __AVX2__
-// an override for AVX2 if only a single closest point is needed.
-template <>
-void exhaustive_L2sqr_blas<SingleBestResultHandler<CMin<float, int64_t>>>(
+void exhaustive_L2sqr_blas_cmax_avx2(
         const float* x,
         const float* y,
         size_t d,
         size_t nx,
         size_t ny,
         SingleBestResultHandler<CMin<float, int64_t>>& res,
         const float* y_norms) {
@@ -513,11 +525,53 @@
             res.add_result(i, current_min_distance, current_min_index);
         }
     }
+    // Does nothing for SingleBestResultHandler, but
+    // keep the call for consistency.
+    res.end_multiple();
     InterruptCallback::check();
 }
 }
 #endif
 
+// an override if only a single closest point is needed
+template <>
+void exhaustive_L2sqr_blas<SingleBestResultHandler<CMin<float, int64_t>>>(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* y_norms) {
+#if defined(__AVX2__)
+    // use a faster fused kernel if available
+    if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
+        // the kernel is available and it is complete, we're done.
+        return;
+    }
+
+    // run the specialized AVX2 implementation
+    exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms);
+
+#elif defined(__aarch64__)
+    // use a faster fused kernel if available
+    if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
+        // the kernel is available and it is complete, we're done.
+        return;
+    }
+
+    // run the default implementation
+    exhaustive_L2sqr_blas_default_impl<
+            SingleBestResultHandler<CMin<float, int64_t>>>(
+            x, y, d, nx, ny, res, y_norms);
+#else
+    // run the default implementation
+    exhaustive_L2sqr_blas_default_impl<
+            SingleBestResultHandler<CMin<float, int64_t>>>(
+            x, y, d, nx, ny, res, y_norms);
+#endif
+}
+
 template <class DistanceCorrection>
 void knn_L2sqr_select(
         const float* x,
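For reference, the contract that both the renamed BLAS path and the new fused path must satisfy is easiest to state in scalar form. The sketch below is illustrative only; the helper name is hypothetical and not part of the patch:

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>

// Reference (scalar) semantics of the fused "L2sqr + argmin" kernel:
// for every x[i], find the single closest y[j] under squared L2 distance.
void l2sqr_argmin_reference(
        const float* x, const float* y, size_t d, size_t nx, size_t ny,
        float* out_dis, int64_t* out_idx) {
    for (size_t i = 0; i < nx; i++) {
        float best = std::numeric_limits<float>::max();
        int64_t best_j = -1;
        for (size_t j = 0; j < ny; j++) {
            float dis = 0;
            for (size_t k = 0; k < d; k++) {
                const float diff = x[i * d + k] - y[j * d + k];
                dis += diff * diff;
            }
            // strict '<' keeps the smallest index on ties, matching the
            // tie-breaking behavior the SIMD kernels below preserve
            if (dis < best) {
                best = dis;
                best_j = (int64_t)j;
            }
        }
        out_dis[i] = best;
        out_idx[i] = best_j;
    }
}
```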
diff --git a/faiss/utils/distances_fused/avx512.cpp b/faiss/utils/distances_fused/avx512.cpp
new file mode 100644
index 0000000000..5c7906e79e
--- /dev/null
+++ b/faiss/utils/distances_fused/avx512.cpp
@@ -0,0 +1,320 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+#include <faiss/utils/distances_fused/avx512.h>
+
+#ifdef __AVX512__
+
+#include <immintrin.h>
+
+namespace faiss {
+
+namespace {
+
+// It makes sense to overload certain small-dimensionality cases because the
+// inner kernels need every AVX512 register they can get. So, tell the
+// compiler not to waste registers on slightly faster generic code here.
+template <size_t DIM>
+float l2_sqr(const float* const x) {
+    // the compiler should be smart enough to vectorize this
+    float output = x[0] * x[0];
+    for (size_t i = 1; i < DIM; i++) {
+        output += x[i] * x[i];
+    }
+
+    return output;
+}
+
+template <>
+float l2_sqr<4>(const float* const x) {
+    __m128 v = _mm_loadu_ps(x);
+    __m128 v2 = _mm_mul_ps(v, v);
+    v2 = _mm_hadd_ps(v2, v2);
+    v2 = _mm_hadd_ps(v2, v2);
+
+    return _mm_cvtss_f32(v2);
+}
+
+template <size_t DIM>
+float dot_product(
+        const float* const __restrict x,
+        const float* const __restrict y) {
+    // the compiler should be smart enough to vectorize this
+    float output = x[0] * y[0];
+    for (size_t i = 1; i < DIM; i++) {
+        output += x[i] * y[i];
+    }
+
+    return output;
+}
+
+// The kernel for low-dimensionality vectors.
+// Finds the closest point in y for each of the given NX_POINTS_PER_LOOP
+// points from x.
+//
+// DIM is the dimensionality of the data
+// NX_POINTS_PER_LOOP is the number of x points that get processed
+//   simultaneously.
+// NY_POINTS_PER_LOOP is the number of y points that get processed
+//   simultaneously.
+template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
+void kernel(
+        const float* const __restrict x,
+        const float* const __restrict y,
+        const float* const __restrict y_transposed,
+        size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* __restrict y_norms,
+        size_t i) {
+    const size_t ny_p =
+            (ny / (16 * NY_POINTS_PER_LOOP)) * (16 * NY_POINTS_PER_LOOP);
+
+    // pointer to the first of the NX_POINTS_PER_LOOP x points
+    const float* const __restrict xd_0 = x + i * DIM;
+
+    // prefetch the next point
+    _mm_prefetch(xd_0 + DIM * sizeof(float), _MM_HINT_NTA);
+
+    // load a single point from x into all lanes,
+    // already multiplied by -2
+    __m512 x_i[NX_POINTS_PER_LOOP][DIM];
+    for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+        for (size_t dd = 0; dd < DIM; dd++) {
+            x_i[nx_k][dd] = _mm512_set1_ps(-2 * *(xd_0 + nx_k * DIM + dd));
+        }
+    }
+
+    // compute x_norm
+    float x_norm_i[NX_POINTS_PER_LOOP];
+    for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+        x_norm_i[nx_k] = l2_sqr<DIM>(xd_0 + nx_k * DIM);
+    }
+
+    // running minimum distances and indices
+    __m512 min_distances_i[NX_POINTS_PER_LOOP];
+    for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+        min_distances_i[nx_k] =
+                _mm512_set1_ps(res.dis_tab[i + nx_k] - x_norm_i[nx_k]);
+    }
+
+    __m512i min_indices_i[NX_POINTS_PER_LOOP];
+    for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+        min_indices_i[nx_k] = _mm512_set1_epi32(0);
+    }
+
+    // lane indices for the current batch of y points
+    __m512i current_indices = _mm512_setr_epi32(
+            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+    const __m512i indices_delta = _mm512_set1_epi32(16);
+
+    // main loop
+    size_t j = 0;
+    for (; j < ny_p; j += NY_POINTS_PER_LOOP * 16) {
+        // compute dot products for NX_POINTS from x and NY_POINTS from y
+        // technically, we're multiplying -2x and y
+        __m512 dp_i[NX_POINTS_PER_LOOP][NY_POINTS_PER_LOOP];
+
+        // DIM 0 that uses MUL
+        for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+            __m512 y_i = _mm512_loadu_ps(y_transposed + j + ny_k * 16 + ny * 0);
+            for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                dp_i[nx_k][ny_k] = _mm512_mul_ps(x_i[nx_k][0], y_i);
+            }
+        }
+
+        // other DIMs that use FMA
+        for (size_t dd = 1; dd < DIM; dd++) {
+            for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+                __m512 y_i = _mm512_loadu_ps(
+                        y_transposed + j + ny_k * 16 + ny * dd);
+
+                for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                    dp_i[nx_k][ny_k] = _mm512_fmadd_ps(
+                            x_i[nx_k][dd], y_i, dp_i[nx_k][ny_k]);
+                }
+            }
+        }
+
+        // compute y^2 - 2 * (x,y)
+        for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+            __m512 y_l2_sqr = _mm512_loadu_ps(y_norms + j + ny_k * 16);
+
+            for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                dp_i[nx_k][ny_k] = _mm512_add_ps(dp_i[nx_k][ny_k], y_l2_sqr);
+            }
+        }
+
+        // do the comparisons and update the min indices
+        for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+            for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                const __mmask16 comparison = _mm512_cmp_ps_mask(
+                        min_distances_i[nx_k], dp_i[nx_k][ny_k], _CMP_LE_OS);
+                min_distances_i[nx_k] = _mm512_mask_blend_ps(
+                        comparison, dp_i[nx_k][ny_k], min_distances_i[nx_k]);
+                min_indices_i[nx_k] = _mm512_castps_si512(_mm512_mask_blend_ps(
+                        comparison,
+                        _mm512_castsi512_ps(current_indices),
+                        _mm512_castsi512_ps(min_indices_i[nx_k])));
+            }
+
+            current_indices = _mm512_add_epi32(current_indices, indices_delta);
+        }
+    }
+
+    // dump values and find the minimum distance / minimum index
+    for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+        float min_distances_scalar[16];
+        uint32_t min_indices_scalar[16];
+        _mm512_storeu_ps(min_distances_scalar, min_distances_i[nx_k]);
+        _mm512_storeu_si512(
+                (__m512i*)(min_indices_scalar), min_indices_i[nx_k]);
+
+        float current_min_distance = res.dis_tab[i + nx_k];
+        uint32_t current_min_index = res.ids_tab[i + nx_k];
+
+        // This unusual comparison is needed to maintain the behavior
+        // of the original implementation: if two indices are
+        // represented with equal distance values, then
+        // the index with the min value is returned.
+        for (size_t jv = 0; jv < 16; jv++) {
+            // add the missing x_norms[i]
+            float distance_candidate =
+                    min_distances_scalar[jv] + x_norm_i[nx_k];
+
+            // negative values can occur for identical vectors
+            // due to roundoff errors.
+            if (distance_candidate < 0) {
+                distance_candidate = 0;
+            }
+
+            const int64_t index_candidate = min_indices_scalar[jv];
+
+            if (current_min_distance > distance_candidate) {
+                current_min_distance = distance_candidate;
+                current_min_index = index_candidate;
+            } else if (
+                    current_min_distance == distance_candidate &&
+                    current_min_index > index_candidate) {
+                current_min_index = index_candidate;
+            }
+        }
+
+        // process leftovers
+        for (size_t j0 = j; j0 < ny; j0++) {
+            const float dp =
+                    dot_product<DIM>(x + (i + nx_k) * DIM, y + j0 * DIM);
+            float dis = x_norm_i[nx_k] + y_norms[j0] - 2 * dp;
+            // negative values can occur for identical vectors
+            // due to roundoff errors.
+            if (dis < 0) {
+                dis = 0;
+            }
+
+            if (current_min_distance > dis) {
+                current_min_distance = dis;
+                current_min_index = j0;
+            }
+        }
+
+        // done
+        res.add_result(i + nx_k, current_min_distance, current_min_index);
+    }
+}
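The initialization of `min_distances_i` with `res.dis_tab[i + nx_k] - x_norm_i[nx_k]` is a norm-folding trick: since ||x − y||² = ||x||² − 2⟨x,y⟩ + ||y||² and ||x||² is constant for a fixed query, the lanes only need to track ||y||² − 2⟨x,y⟩, and the constant is restored once during the scalar reduction. A scalar model of the same idea (hypothetical helper, with the same clamping of roundoff negatives):

```cpp
#include <cstddef>

// Scalar view of the norm-folding used above (illustrative only).
// The lanes track partial(j) = ||y_j||^2 - 2 * <x, y_j>; the constant
// ||x||^2 is added back exactly once, at the end.
float fold_then_restore(
        const float* x, const float* y, size_t d, size_t ny, size_t* best_j) {
    float x_norm = 0;
    for (size_t k = 0; k < d; k++) {
        x_norm += x[k] * x[k];
    }

    float best_partial = 1e30f; // stands in for dis_tab[i] - x_norm
    *best_j = 0;
    for (size_t j = 0; j < ny; j++) {
        float dp = 0, y_norm = 0;
        for (size_t k = 0; k < d; k++) {
            dp += x[k] * y[j * d + k];
            y_norm += y[j * d + k] * y[j * d + k];
        }
        const float partial = y_norm - 2 * dp;
        if (partial < best_partial) {
            best_partial = partial;
            *best_j = j;
        }
    }
    // restore the folded ||x||^2; clamp at 0 against roundoff
    const float dis = best_partial + x_norm;
    return dis < 0 ? 0 : dis;
}
```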
+
+template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
+void exhaustive_L2sqr_fused_cmax(
+        const float* const __restrict x,
+        const float* const __restrict y,
+        size_t nx,
+        size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* __restrict y_norms) {
+    // BLAS does not like empty matrices
+    if (nx == 0 || ny == 0) {
+        return;
+    }
+
+    // compute norms for y if they were not provided
+    std::unique_ptr<float[]> del2;
+    if (!y_norms) {
+        float* y_norms2 = new float[ny];
+        del2.reset(y_norms2);
+
+        for (size_t i = 0; i < ny; i++) {
+            y_norms2[i] = l2_sqr<DIM>(y + i * DIM);
+        }
+
+        y_norms = y_norms2;
+    }
+
+    // initialize res
+    res.begin_multiple(0, nx);
+
+    // transpose y into a DIM-major layout
+    std::vector<float> y_transposed(DIM * ny);
+    for (size_t j = 0; j < DIM; j++) {
+        for (size_t i = 0; i < ny; i++) {
+            y_transposed[j * ny + i] = y[j + i * DIM];
+        }
+    }
+
+    const size_t nx_p = (nx / NX_POINTS_PER_LOOP) * NX_POINTS_PER_LOOP;
+
+    // the main loop.
+#pragma omp parallel for schedule(dynamic)
+    for (size_t i = 0; i < nx_p; i += NX_POINTS_PER_LOOP) {
+        kernel<DIM, NX_POINTS_PER_LOOP, NY_POINTS_PER_LOOP>(
+                x, y, y_transposed.data(), ny, res, y_norms, i);
+    }
+
+    // process leftover x points one at a time
+    for (size_t i = nx_p; i < nx; i++) {
+        kernel<DIM, 1, NY_POINTS_PER_LOOP>(
+                x, y, y_transposed.data(), ny, res, y_norms, i);
+    }
+
+    // Does nothing for SingleBestResultHandler, but
+    // keep the call for consistency.
+    res.end_multiple();
+    InterruptCallback::check();
+}
+
+} // namespace
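`y_transposed` stores y in dimension-major order, so that each vector load in the kernel pulls 16 consecutive y points of a single dimension. A standalone sketch of that re-layout (hypothetical helper name):

```cpp
#include <cstddef>
#include <vector>

// Dimension-major re-layout of y, as done above: element (dim d, point j)
// lands at y_t[d * ny + j], so consecutive y points of one dimension can
// be fetched with a single wide vector load.
std::vector<float> transpose_to_dim_major(
        const float* y, size_t dim, size_t ny) {
    std::vector<float> y_t(dim * ny);
    for (size_t d = 0; d < dim; d++) {
        for (size_t j = 0; j < ny; j++) {
            y_t[d * ny + j] = y[j * dim + d];
        }
    }
    return y_t;
}
```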
+
+bool exhaustive_L2sqr_fused_cmax_AVX512(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* y_norms) {
+    // process only cases with certain dimensionalities
+    if (d == 1) {
+        exhaustive_L2sqr_fused_cmax<1, 4, 2>(x, y, nx, ny, res, y_norms);
+        return true;
+    } else if (d == 2) {
+        exhaustive_L2sqr_fused_cmax<2, 4, 2>(x, y, nx, ny, res, y_norms);
+        return true;
+    } else if (d == 4) {
+        exhaustive_L2sqr_fused_cmax<4, 4, 1>(x, y, nx, ny, res, y_norms);
+        return true;
+    } else if (d == 8) {
+        exhaustive_L2sqr_fused_cmax<8, 2, 1>(x, y, nx, ny, res, y_norms);
+        return true;
+    } else if (d == 12) {
+        exhaustive_L2sqr_fused_cmax<12, 1, 1>(x, y, nx, ny, res, y_norms);
+        return true;
+    } else if (d == 16) {
+        exhaustive_L2sqr_fused_cmax<16, 1, 1>(x, y, nx, ny, res, y_norms);
+        return true;
+    }
+
+    return false;
+}
+
+} // namespace faiss
+
+#endif
diff --git a/faiss/utils/distances_fused/avx512.h b/faiss/utils/distances_fused/avx512.h
new file mode 100644
index 0000000000..d730e3b61c
--- /dev/null
+++ b/faiss/utils/distances_fused/avx512.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// AVX512 might not always be used, but this version provides a ~2x speedup
+// over the AVX2 kernel, say, for training PQx10 or PQx12, and it speeds up
+// additional cases with larger dimensionalities.
+
+#pragma once
+
+#include <faiss/impl/ResultHandler.h>
+#include <faiss/impl/platform_macros.h>
+
+#include <faiss/utils/Heap.h>
+
+#ifdef __AVX512__
+
+namespace faiss {
+
+// Returns true if the fused kernel is available and the data was processed.
+// Returns false if the fused kernel is not available.
+bool exhaustive_L2sqr_fused_cmax_AVX512(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* y_norms);
+
+} // namespace faiss
+
+#endif
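A hedged sketch of how this entry point is meant to be consumed; the wrapper name is hypothetical, and the real call site is the `exhaustive_L2sqr_blas` specialization in distances.cpp above:

```cpp
#include <faiss/utils/distances_fused/avx512.h>

#include <cstddef>
#include <cstdint>

#ifdef __AVX512__
// Try the fused kernel first; fall back when the dimensionality is not
// one of the specialized cases (1, 2, 4, 8, 12, 16).
void closest_point_avx512_or_fallback(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        faiss::SingleBestResultHandler<faiss::CMin<float, int64_t>>& res,
        const float* y_norms) {
    if (faiss::exhaustive_L2sqr_fused_cmax_AVX512(
                x, y, d, nx, ny, res, y_norms)) {
        return; // the fused kernel handled everything
    }
    // otherwise: run a general-purpose kernel (e.g. the BLAS-based path)
}
#endif
```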
diff --git a/faiss/utils/distances_fused/distances_fused.cpp b/faiss/utils/distances_fused/distances_fused.cpp
new file mode 100644
index 0000000000..a676e17144
--- /dev/null
+++ b/faiss/utils/distances_fused/distances_fused.cpp
@@ -0,0 +1,35 @@
+#include <faiss/utils/distances_fused/distances_fused.h>
+
+#include <faiss/impl/platform_macros.h>
+
+#include <faiss/utils/distances_fused/avx512.h>
+#include <faiss/utils/distances_fused/simdlib_based.h>
+
+namespace faiss {
+
+bool exhaustive_L2sqr_fused_cmax(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* y_norms) {
+    if (nx == 0 || ny == 0) {
+        // nothing to do
+        return true;
+    }
+
+#ifdef __AVX512__
+    // AVX512 kernel
+    return exhaustive_L2sqr_fused_cmax_AVX512(x, y, d, nx, ny, res, y_norms);
+#elif defined(__AVX2__) || defined(__aarch64__)
+    // AVX2 or ARM NEON kernel
+    return exhaustive_L2sqr_fused_cmax_simdlib(x, y, d, nx, ny, res, y_norms);
+#else
+    // not supported, please use a general-purpose kernel
+    return false;
+#endif
+}
+
+} // namespace faiss
diff --git a/faiss/utils/distances_fused/distances_fused.h b/faiss/utils/distances_fused/distances_fused.h
new file mode 100644
index 0000000000..e6e35c209e
--- /dev/null
+++ b/faiss/utils/distances_fused/distances_fused.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// This file contains a fused kernel that combines distance computation
+// and the search for the CLOSEST point. Currently, this is done for small
+// dimensionality vectors, where it is beneficial to avoid storing temporary
+// dot product information in RAM. This is particularly effective
+// when training PQx10 or PQx12 with the default parameters.
+//
+// InterruptCallback::check() is not used, because it is assumed that the
+// kernel takes little time, owing to the tiny dimensionality.
+//
+// Later on, a similar optimization can be implemented for larger vectors,
+// but a different kernel is needed.
+//
+
+#pragma once
+
+#include <faiss/impl/ResultHandler.h>
+
+#include <faiss/utils/Heap.h>
+
+namespace faiss {
+
+// Returns true if the fused kernel is available and the data was processed.
+// Returns false if the fused kernel is not available.
+bool exhaustive_L2sqr_fused_cmax(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* y_norms);
+
+} // namespace faiss
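Usage sketch for the dispatcher. It assumes `SingleBestResultHandler`'s `(nq, dis_tab, ids_tab)` constructor from faiss/impl/ResultHandler.h at the time of this patch; the wrapper name is hypothetical:

```cpp
#include <faiss/utils/distances_fused/distances_fused.h>

#include <cstddef>
#include <cstdint>
#include <vector>

// Find the closest y point for each of the nx query points, preferring the
// fused kernel and noting when a caller must fall back.
void closest_with_fused_fallback(
        const float* x, const float* y, size_t d, size_t nx, size_t ny) {
    std::vector<float> distances(nx);
    std::vector<int64_t> labels(nx);

    faiss::SingleBestResultHandler<faiss::CMin<float, int64_t>> res(
            nx, distances.data(), labels.data());

    if (!faiss::exhaustive_L2sqr_fused_cmax(
                x, y, d, nx, ny, res, nullptr)) {
        // unsupported dimensionality: fall back to a general kernel,
        // e.g. the BLAS-based implementation in faiss/utils/distances.cpp
    }
}
```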
diff --git a/faiss/utils/distances_fused/simdlib_based.cpp b/faiss/utils/distances_fused/simdlib_based.cpp
new file mode 100644
index 0000000000..434f0dfb3c
--- /dev/null
+++ b/faiss/utils/distances_fused/simdlib_based.cpp
@@ -0,0 +1,312 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+#include <faiss/utils/distances_fused/simdlib_based.h>
+
+#if defined(__AVX2__) || defined(__aarch64__)
+
+#include <faiss/utils/simdlib.h>
+
+#if defined(__AVX2__)
+#include <immintrin.h>
+#endif
+
+namespace faiss {
+
+namespace {
+
+// It makes sense to overload certain small-dimensionality cases because the
+// inner kernels need every SIMD register they can get. So, tell the
+// compiler not to waste registers on slightly faster generic code here.
+template <size_t DIM>
+float l2_sqr(const float* const x) {
+    // the compiler should be smart enough to vectorize this
+    float output = x[0] * x[0];
+    for (size_t i = 1; i < DIM; i++) {
+        output += x[i] * x[i];
+    }
+
+    return output;
+}
+
+template <size_t DIM>
+float dot_product(
+        const float* const __restrict x,
+        const float* const __restrict y) {
+    // the compiler should be smart enough to vectorize this
+    float output = x[0] * y[0];
+    for (size_t i = 1; i < DIM; i++) {
+        output += x[i] * y[i];
+    }
+
+    return output;
+}
+
+// The kernel for low-dimensionality vectors.
+// Finds the closest point in y for each of the given NX_POINTS_PER_LOOP
+// points from x.
+//
+// DIM is the dimensionality of the data
+// NX_POINTS_PER_LOOP is the number of x points that get processed
+//   simultaneously.
+// NY_POINTS_PER_LOOP is the number of y points that get processed
+//   simultaneously.
+template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
+void kernel(
+        const float* const __restrict x,
+        const float* const __restrict y,
+        const float* const __restrict y_transposed,
+        const size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* __restrict y_norms,
+        const size_t i) {
+    const size_t ny_p =
+            (ny / (8 * NY_POINTS_PER_LOOP)) * (8 * NY_POINTS_PER_LOOP);
+
+    // pointer to the first of the NX_POINTS_PER_LOOP x points
+    const float* const __restrict xd_0 = x + i * DIM;
+
+    // prefetch the next point
+#if defined(__AVX2__)
+    _mm_prefetch(xd_0 + DIM * sizeof(float), _MM_HINT_NTA);
+#endif
+
+    // load a single point from x into all lanes,
+    // already multiplied by -2
+    simd8float32 x_i[NX_POINTS_PER_LOOP][DIM];
+    for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+        for (size_t dd = 0; dd < DIM; dd++) {
+            x_i[nx_k][dd] = simd8float32(-2 * *(xd_0 + nx_k * DIM + dd));
+        }
+    }
+
+    // compute x_norm
+    float x_norm_i[NX_POINTS_PER_LOOP];
+    for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+        x_norm_i[nx_k] = l2_sqr<DIM>(xd_0 + nx_k * DIM);
+    }
+
+    // running minimum distances and indices
+    simd8float32 min_distances_i[NX_POINTS_PER_LOOP];
+    for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+        min_distances_i[nx_k] =
+                simd8float32(res.dis_tab[i + nx_k] - x_norm_i[nx_k]);
+    }
+
+    simd8uint32 min_indices_i[NX_POINTS_PER_LOOP];
+    for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+        min_indices_i[nx_k] = simd8uint32((uint32_t)0);
+    }
+
+    // lane indices for the current batch of y points
+    simd8uint32 current_indices = simd8uint32(0, 1, 2, 3, 4, 5, 6, 7);
+    const simd8uint32 indices_delta = simd8uint32(8);
+
+    // main loop
+    size_t j = 0;
+    for (; j < ny_p; j += NY_POINTS_PER_LOOP * 8) {
+        // compute dot products for NX_POINTS from x and NY_POINTS from y
+        // technically, we're multiplying -2x and y
+        simd8float32 dp_i[NX_POINTS_PER_LOOP][NY_POINTS_PER_LOOP];
+
+        // DIM 0 that uses MUL
+        for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+            simd8float32 y_i =
+                    simd8float32(y_transposed + j + ny_k * 8 + ny * 0);
+            for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                dp_i[nx_k][ny_k] = x_i[nx_k][0] * y_i;
+            }
+        }
+
+        // other DIMs that use FMA
+        for (size_t dd = 1; dd < DIM; dd++) {
+            for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+                simd8float32 y_i =
+                        simd8float32(y_transposed + j + ny_k * 8 + ny * dd);
+
+                for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                    dp_i[nx_k][ny_k] =
+                            fmadd(x_i[nx_k][dd], y_i, dp_i[nx_k][ny_k]);
+                }
+            }
+        }
+
+        // compute y^2 + (-2x, y)
+        for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+            simd8float32 y_l2_sqr = simd8float32(y_norms + j + ny_k * 8);
+
+            for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                dp_i[nx_k][ny_k] = dp_i[nx_k][ny_k] + y_l2_sqr;
+            }
+        }
+
+        // do the comparisons and update the min indices
+        for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+            for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                // cmpps
+                const simd8float32 comparison =
+                        cmple(min_distances_i[nx_k], dp_i[nx_k][ny_k]);
+                min_distances_i[nx_k] = blend(
+                        dp_i[nx_k][ny_k], min_distances_i[nx_k], comparison);
+                min_indices_i[nx_k] =
+                        blend(current_indices, min_indices_i[nx_k], comparison);
+            }
+
+            current_indices = current_indices + indices_delta;
+        }
+    }
+
+    // dump values and find the minimum distance / minimum index
+    for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+        float min_distances_scalar[8];
+        uint32_t min_indices_scalar[8];
+
+        min_distances_i[nx_k].storeu(min_distances_scalar);
+        min_indices_i[nx_k].storeu(min_indices_scalar);
+
+        float current_min_distance = res.dis_tab[i + nx_k];
+        uint32_t current_min_index = res.ids_tab[i + nx_k];
+
+        // This unusual comparison is needed to maintain the behavior
+        // of the original implementation: if two indices are
+        // represented with equal distance values, then
+        // the index with the min value is returned.
+        for (size_t jv = 0; jv < 8; jv++) {
+            // add the missing x_norms[i]
+            float distance_candidate =
+                    min_distances_scalar[jv] + x_norm_i[nx_k];
+
+            // negative values can occur for identical vectors
+            // due to roundoff errors.
+            if (distance_candidate < 0) {
+                distance_candidate = 0;
+            }
+
+            const int64_t index_candidate = min_indices_scalar[jv];
+
+            if (current_min_distance > distance_candidate) {
+                current_min_distance = distance_candidate;
+                current_min_index = index_candidate;
+            } else if (
+                    current_min_distance == distance_candidate &&
+                    current_min_index > index_candidate) {
+                current_min_index = index_candidate;
+            }
+        }
+
+        // process leftovers
+        for (size_t j0 = j; j0 < ny; j0++) {
+            const float dp =
+                    dot_product<DIM>(x + (i + nx_k) * DIM, y + j0 * DIM);
+            float dis = x_norm_i[nx_k] + y_norms[j0] - 2 * dp;
+            // negative values can occur for identical vectors
+            // due to roundoff errors.
+            if (dis < 0) {
+                dis = 0;
+            }
+
+            if (current_min_distance > dis) {
+                current_min_distance = dis;
+                current_min_index = j0;
+            }
+        }
+
+        // done
+        res.add_result(i + nx_k, current_min_distance, current_min_index);
+    }
+}
+
+template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
+void exhaustive_L2sqr_fused_cmax(
+        const float* const __restrict x,
+        const float* const __restrict y,
+        size_t nx,
+        size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* __restrict y_norms) {
+    // BLAS does not like empty matrices
+    if (nx == 0 || ny == 0) {
+        return;
+    }
+
+    // compute norms for y if they were not provided
+    std::unique_ptr<float[]> del2;
+    if (!y_norms) {
+        float* y_norms2 = new float[ny];
+        del2.reset(y_norms2);
+
+        for (size_t i = 0; i < ny; i++) {
+            y_norms2[i] = l2_sqr<DIM>(y + i * DIM);
+        }
+
+        y_norms = y_norms2;
+    }
+
+    // initialize res
+    res.begin_multiple(0, nx);
+
+    // transpose y into a DIM-major layout
+    std::vector<float> y_transposed(DIM * ny);
+    for (size_t j = 0; j < DIM; j++) {
+        for (size_t i = 0; i < ny; i++) {
+            y_transposed[j * ny + i] = y[j + i * DIM];
+        }
+    }
+
+    const size_t nx_p = (nx / NX_POINTS_PER_LOOP) * NX_POINTS_PER_LOOP;
+
+    // the main loop.
+#pragma omp parallel for schedule(dynamic)
+    for (size_t i = 0; i < nx_p; i += NX_POINTS_PER_LOOP) {
+        kernel<DIM, NX_POINTS_PER_LOOP, NY_POINTS_PER_LOOP>(
+                x, y, y_transposed.data(), ny, res, y_norms, i);
+    }
+
+    // process leftover x points one at a time
+    for (size_t i = nx_p; i < nx; i++) {
+        kernel<DIM, 1, NY_POINTS_PER_LOOP>(
+                x, y, y_transposed.data(), ny, res, y_norms, i);
+    }
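The comparison/blend idiom above is the portable core of the kernel. Below is a minimal standalone sketch using only the simdlib primitives added by this patch (hypothetical helper; n is assumed to be a non-zero multiple of 8):

```cpp
#include <faiss/utils/simdlib.h>

#include <cstddef>
#include <cstdint>

// Vectorized argmin over n floats using the cmple/blend primitives.
uint32_t argmin8(const float* values, size_t n /* multiple of 8 */) {
    using namespace faiss;

    simd8float32 min_val(values); // first 8 elements
    simd8uint32 min_idx(0, 1, 2, 3, 4, 5, 6, 7);
    simd8uint32 cur_idx = min_idx;
    const simd8uint32 step(8);

    for (size_t j = 8; j < n; j += 8) {
        cur_idx = cur_idx + step;
        const simd8float32 v(values + j);
        // lanes where the running min is still <= v keep their old values
        const simd8float32 cmp = cmple(min_val, v);
        min_val = blend(v, min_val, cmp);
        min_idx = blend(cur_idx, min_idx, cmp);
    }

    // horizontal reduction over the 8 lanes
    float vals[8];
    uint32_t idxs[8];
    min_val.storeu(vals);
    min_idx.storeu(idxs);

    uint32_t best = idxs[0];
    float best_val = vals[0];
    for (int k = 1; k < 8; k++) {
        // prefer the smaller index on equal values, like the kernels do
        if (vals[k] < best_val || (vals[k] == best_val && idxs[k] < best)) {
            best = idxs[k];
            best_val = vals[k];
        }
    }
    return best;
}
```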
+
+    // Does nothing for SingleBestResultHandler, but
+    // keep the call for consistency.
+    res.end_multiple();
+    InterruptCallback::check();
+}
+
+} // namespace
+
+bool exhaustive_L2sqr_fused_cmax_simdlib(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* y_norms) {
+    // Process only cases with certain dimensionalities.
+    // An acceptable dimensionality value is limited by the number of
+    // available registers.
+    if (d == 1) {
+        exhaustive_L2sqr_fused_cmax<1, 2, 2>(x, y, nx, ny, res, y_norms);
+        return true;
+    } else if (d == 2) {
+        exhaustive_L2sqr_fused_cmax<2, 2, 2>(x, y, nx, ny, res, y_norms);
+        return true;
+    } else if (d == 4) {
+        exhaustive_L2sqr_fused_cmax<4, 2, 1>(x, y, nx, ny, res, y_norms);
+        return true;
+    } else if (d == 8) {
+        exhaustive_L2sqr_fused_cmax<8, 1, 1>(x, y, nx, ny, res, y_norms);
+        return true;
+    }
+
+    return false;
+}
+
+} // namespace faiss
+
+#endif
diff --git a/faiss/utils/distances_fused/simdlib_based.h b/faiss/utils/distances_fused/simdlib_based.h
new file mode 100644
index 0000000000..b60da7b193
--- /dev/null
+++ b/faiss/utils/distances_fused/simdlib_based.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <faiss/impl/ResultHandler.h>
+#include <faiss/impl/platform_macros.h>
+
+#include <faiss/utils/Heap.h>
+
+#if defined(__AVX2__) || defined(__aarch64__)
+
+namespace faiss {
+
+// Returns true if the fused kernel is available and the data was processed.
+// Returns false if the fused kernel is not available.
+bool exhaustive_L2sqr_fused_cmax_simdlib(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        SingleBestResultHandler<CMin<float, int64_t>>& res,
+        const float* y_norms);
+
+} // namespace faiss
+
+#endif
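A possible self-check for the simdlib kernel against a scalar reference (hypothetical test helper; assumes the `l2sqr_argmin_reference` sketch given earlier and the `SingleBestResultHandler` constructor noted above; only meaningful on builds where the `__AVX2__`/`__aarch64__` guard is active):

```cpp
#include <faiss/utils/distances_fused/simdlib_based.h>

#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

#if defined(__AVX2__) || defined(__aarch64__)
bool check_fused_simdlib(size_t d, size_t nx, size_t ny) {
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> u(0.f, 1.f);

    std::vector<float> x(nx * d), y(ny * d);
    for (auto& v : x) v = u(rng);
    for (auto& v : y) v = u(rng);

    std::vector<float> dis(nx);
    std::vector<int64_t> ids(nx);
    faiss::SingleBestResultHandler<faiss::CMin<float, int64_t>> res(
            nx, dis.data(), ids.data());

    if (!faiss::exhaustive_L2sqr_fused_cmax_simdlib(
                x.data(), y.data(), d, nx, ny, res, nullptr)) {
        return true; // dimensionality not covered; nothing to check
    }

    std::vector<float> ref_dis(nx);
    std::vector<int64_t> ref_ids(nx);
    l2sqr_argmin_reference(
            x.data(), y.data(), d, nx, ny, ref_dis.data(), ref_ids.data());

    for (size_t i = 0; i < nx; i++) {
        if (ids[i] != ref_ids[i]) {
            return false;
        }
    }
    return true;
}
#endif
```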
diff --git a/faiss/utils/simdlib_avx2.h b/faiss/utils/simdlib_avx2.h
index 6eee3fa0e1..b108c6bd8a 100644
--- a/faiss/utils/simdlib_avx2.h
+++ b/faiss/utils/simdlib_avx2.h
@@ -359,6 +359,25 @@ struct simd8uint32 : simd256bit {
 
     explicit simd8uint32(const uint8_t* x) : simd256bit((const void*)x) {}
 
+    explicit simd8uint32(
+            uint32_t u0,
+            uint32_t u1,
+            uint32_t u2,
+            uint32_t u3,
+            uint32_t u4,
+            uint32_t u5,
+            uint32_t u6,
+            uint32_t u7)
+            : simd256bit(_mm256_setr_epi32(u0, u1, u2, u3, u4, u5, u6, u7)) {}
+
+    simd8uint32 operator+(simd8uint32 other) const {
+        return simd8uint32(_mm256_add_epi32(i, other.i));
+    }
+
+    simd8uint32 operator-(simd8uint32 other) const {
+        return simd8uint32(_mm256_sub_epi32(i, other.i));
+    }
+
     std::string elements_to_string(const char* fmt) const {
         uint32_t bytes[8];
         storeu((void*)bytes);
@@ -394,7 +413,18 @@ struct simd8float32 : simd256bit {
 
     explicit simd8float32(float x) : simd256bit(_mm256_set1_ps(x)) {}
 
-    explicit simd8float32(const float* x) : simd256bit(_mm256_load_ps(x)) {}
+    explicit simd8float32(const float* x) : simd256bit(_mm256_loadu_ps(x)) {}
+
+    explicit simd8float32(
+            float f0,
+            float f1,
+            float f2,
+            float f3,
+            float f4,
+            float f5,
+            float f6,
+            float f7)
+            : simd256bit(_mm256_setr_ps(f0, f1, f2, f3, f4, f5, f6, f7)) {}
 
     simd8float32 operator*(simd8float32 other) const {
         return simd8float32(_mm256_mul_ps(f, other.f));
@@ -439,6 +469,25 @@ inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
     return simd8float32(_mm256_fmadd_ps(a.f, b.f, c.f));
 }
 
+inline simd8float32 cmple(simd8float32 a, simd8float32 b) {
+    return simd8float32(_mm256_cmp_ps(a.f, b.f, _CMP_LE_OS));
+}
+
+inline simd8float32 blend(
+        simd8float32 a,
+        simd8float32 b,
+        simd8float32 comparison) {
+    return simd8float32(_mm256_blendv_ps(a.f, b.f, comparison.f));
+}
+
+inline simd8uint32 blend(
+        simd8uint32 a,
+        simd8uint32 b,
+        simd8float32 comparison) {
+    return simd8uint32(_mm256_castps_si256(_mm256_blendv_ps(
+            _mm256_castsi256_ps(a.i),
+            _mm256_castsi256_ps(b.i),
+            comparison.f)));
+}
+
 namespace {
 
 // get even float32's of a and b, interleaved
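Note the `_mm256_load_ps` to `_mm256_loadu_ps` change in the pointer constructor: the fused kernels read `y_transposed` and `y_norms` at offsets such as `j + ny_k * 8 + ny * dd`, which are not guaranteed to be 32-byte aligned. A small illustration of why the unaligned load is required (illustrative helper; on modern x86 an unaligned load typically costs the same as an aligned one when the address happens to be aligned):

```cpp
#include <immintrin.h>

#include <cstddef>
#include <vector>

float sum8_from(const std::vector<float>& v, size_t offset) {
    // offset may be arbitrary, so an unaligned load is mandatory here;
    // _mm256_load_ps would fault on a non-32-byte-aligned address
    const __m256 x = _mm256_loadu_ps(v.data() + offset);

    // horizontal add, for illustration
    float tmp[8];
    _mm256_storeu_ps(tmp, x);
    float s = 0;
    for (int i = 0; i < 8; i++) {
        s += tmp[i];
    }
    return s;
}
```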
diff --git a/faiss/utils/simdlib_emulated.h b/faiss/utils/simdlib_emulated.h
index 9267ef1b46..50b38eb814 100644
--- a/faiss/utils/simdlib_emulated.h
+++ b/faiss/utils/simdlib_emulated.h
@@ -440,6 +440,41 @@ struct simd8uint32 : simd256bit {
 
     explicit simd8uint32(const uint32_t* x) : simd256bit((const void*)x) {}
 
+    explicit simd8uint32(
+            uint32_t u0,
+            uint32_t u1,
+            uint32_t u2,
+            uint32_t u3,
+            uint32_t u4,
+            uint32_t u5,
+            uint32_t u6,
+            uint32_t u7) {
+        u32[0] = u0;
+        u32[1] = u1;
+        u32[2] = u2;
+        u32[3] = u3;
+        u32[4] = u4;
+        u32[5] = u5;
+        u32[6] = u6;
+        u32[7] = u7;
+    }
+
+    simd8uint32 operator+(simd8uint32 other) const {
+        simd8uint32 result;
+        for (int i = 0; i < 8; i++) {
+            result.u32[i] = u32[i] + other.u32[i];
+        }
+        return result;
+    }
+
+    simd8uint32 operator-(simd8uint32 other) const {
+        simd8uint32 result;
+        for (int i = 0; i < 8; i++) {
+            result.u32[i] = u32[i] - other.u32[i];
+        }
+        return result;
+    }
+
     std::string elements_to_string(const char* fmt) const {
         char res[1000], *ptr = res;
         for (int i = 0; i < 8; i++) {
@@ -484,6 +519,25 @@ struct simd8float32 : simd256bit {
         }
     }
 
+    explicit simd8float32(
+            float f0,
+            float f1,
+            float f2,
+            float f3,
+            float f4,
+            float f5,
+            float f6,
+            float f7) {
+        f32[0] = f0;
+        f32[1] = f1;
+        f32[2] = f2;
+        f32[3] = f3;
+        f32[4] = f4;
+        f32[5] = f5;
+        f32[6] = f6;
+        f32[7] = f7;
+    }
+
     template <typename F>
     static simd8float32 binary_func(
             const simd8float32& a,
@@ -650,6 +704,39 @@ simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
     return c;
 }
 
+inline simd8float32 cmple(simd8float32 a, simd8float32 b) {
+    simd8float32 result;
+    for (size_t j = 0; j < 8; j++) {
+        result.u32[j] = (a.f32[j] <= b.f32[j]) ? 1 : 0;
+    }
+
+    return result;
+}
+
+inline simd8float32 blend(
+        simd8float32 a,
+        simd8float32 b,
+        simd8float32 comparison) {
+    simd8float32 result;
+    for (size_t j = 0; j < 8; j++) {
+        result.f32[j] = (comparison.u32[j] == 0) ? a.f32[j] : b.f32[j];
+    }
+
+    return result;
+}
+
+inline simd8uint32 blend(
+        simd8uint32 a,
+        simd8uint32 b,
+        simd8float32 comparison) {
+    simd8uint32 result;
+    for (size_t j = 0; j < 8; j++) {
+        result.u32[j] = (comparison.u32[j] == 0) ? a.u32[j] : b.u32[j];
+    }
+
+    return result;
+}
+
 } // namespace
 
 } // namespace faiss
diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h
index 737e948927..77884a05da 100644
--- a/faiss/utils/simdlib_neon.h
+++ b/faiss/utils/simdlib_neon.h
@@ -671,6 +671,32 @@ struct simd8uint32 {
 
     explicit simd8uint32(const uint8_t* x) : simd8uint32(simd32uint8(x)) {}
 
+    explicit simd8uint32(
+            uint32_t u0,
+            uint32_t u1,
+            uint32_t u2,
+            uint32_t u3,
+            uint32_t u4,
+            uint32_t u5,
+            uint32_t u6,
+            uint32_t u7) {
+        uint32_t alignas(16) temp[8] = {u0, u1, u2, u3, u4, u5, u6, u7};
+        data.val[0] = vld1q_u32(temp);
+        data.val[1] = vld1q_u32(temp + 4);
+    }
+
+    simd8uint32 operator+(simd8uint32 other) const {
+        return simd8uint32{uint32x4x2_t{
+                vaddq_u32(data.val[0], other.data.val[0]),
+                vaddq_u32(data.val[1], other.data.val[1])}};
+    }
+
+    simd8uint32 operator-(simd8uint32 other) const {
+        return simd8uint32{uint32x4x2_t{
+                vsubq_u32(data.val[0], other.data.val[0]),
+                vsubq_u32(data.val[1], other.data.val[1])}};
+    }
+
     void clear() {
         detail::simdlib::set1(data, &vdupq_n_u32, static_cast<uint32_t>(0));
     }
@@ -734,6 +760,20 @@ struct simd8float32 {
     explicit simd8float32(const float* x)
             : data{vld1q_f32(x), vld1q_f32(x + 4)} {}
 
+    explicit simd8float32(
+            float f0,
+            float f1,
+            float f2,
+            float f3,
+            float f4,
+            float f5,
+            float f6,
+            float f7) {
+        float alignas(16) temp[8] = {f0, f1, f2, f3, f4, f5, f6, f7};
+        data.val[0] = vld1q_f32(temp);
+        data.val[1] = vld1q_f32(temp + 4);
+    }
+
     void clear() {
         detail::simdlib::set1(data, &vdupq_n_f32, 0.f);
     }
@@ -806,6 +846,36 @@ inline simd8float32 fmadd(
             vfmaq_f32(c.data.val[1], a.data.val[1], b.data.val[1])}};
 }
 
+inline simd8float32 cmple(simd8float32 a, simd8float32 b) {
+    return simd8float32{float32x4x2_t{
+            vreinterpretq_f32_u32(vcleq_f32(a.data.val[0], b.data.val[0])),
+            vreinterpretq_f32_u32(vcleq_f32(a.data.val[1], b.data.val[1]))}};
+}
+
+inline simd8float32 blend(
+        simd8float32 a,
+        simd8float32 b,
+        simd8float32 comparison) {
+    // select lanes of b where the comparison mask is set, matching the
+    // semantics of _mm256_blendv_ps in the AVX2 version
+    return simd8float32{float32x4x2_t{
+            vbslq_f32(
+                    vreinterpretq_u32_f32(comparison.data.val[0]),
+                    b.data.val[0],
+                    a.data.val[0]),
+            vbslq_f32(
+                    vreinterpretq_u32_f32(comparison.data.val[1]),
+                    b.data.val[1],
+                    a.data.val[1])}};
+}
+
+inline simd8uint32 blend(
+        simd8uint32 a,
+        simd8uint32 b,
+        simd8float32 comparison) {
+    return simd8uint32{uint32x4x2_t{
+            vbslq_u32(
+                    vreinterpretq_u32_f32(comparison.data.val[0]),
+                    b.data.val[0],
+                    a.data.val[0]),
+            vbslq_u32(
+                    vreinterpretq_u32_f32(comparison.data.val[1]),
+                    b.data.val[1],
+                    a.data.val[1])}};
+}
+
 namespace {
 
 // get even float32's of a and b, interleaved
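Across all three backends (AVX2 `blendv`, NEON `bsl`, and the emulated loop), `blend` follows the `_mm256_blendv_ps` convention: lanes where `comparison` is true take the value from `b`, the rest from `a`. A one-lane scalar model of the contract the kernels rely on:

```cpp
// Scalar model of one blend() lane, shared by all backends.
inline float blend_lane(float a, float b, bool comparison) {
    return comparison ? b : a;
}

// Consequently the kernels' update step
//   cmp = cmple(min_dist, candidate);  // true -> the running min survives
//   min_dist = blend(candidate, min_dist, cmp);
// always keeps the smaller distance, and the earlier index on ties.
```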