From d7ccdc4d8eeb5ca7e41aeaca4653f49f1e9a0447 Mon Sep 17 00:00:00 2001
From: Alexandr Guzhva
Date: Fri, 7 Apr 2023 12:15:49 -0700
Subject: [PATCH] Faster versions of fvec_op_ny_Dx for AVX2 (#2811)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2811

Use transpose_AxB kernels to speed up computations.

Differential Revision: D44726814

fbshipit-source-id: 566b1179e4deea53cce24f45a5106e7b52b23641
---
 faiss/utils/distances_simd.cpp | 392 +++++++++++++++++++++++++++++----
 1 file changed, 346 insertions(+), 46 deletions(-)

diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp
index 39195d62a5..99562fc73c 100644
--- a/faiss/utils/distances_simd.cpp
+++ b/faiss/utils/distances_simd.cpp
@@ -247,6 +247,32 @@ static inline __m128 masked_read(int d, const float* x) {
 
 namespace {
 
+/// helper function
+inline float horizontal_sum(const __m128 v) {
+    __m128 accu = _mm_hadd_ps(v, v);
+    accu = _mm_hadd_ps(accu, accu);
+    return _mm_cvtss_f32(accu);
+}
+
+#ifdef __AVX2__
+/// helper function for AVX2
+inline float horizontal_sum(const __m256 v) {
+    // horizontal sum
+    const __m256 h0 = _mm256_hadd_ps(v, v);
+    const __m256 h1 = _mm256_hadd_ps(h0, h0);
+
+    // extract high and low __m128 regs from __m256
+    const __m128 h2 = _mm256_extractf128_ps(h1, 1);
+    const __m128 h3 = _mm256_castps256_ps128(h1);
+
+    // get a final hsum into all 4 regs
+    const __m128 h4 = _mm_add_ss(h2, h3);
+
+    // extract f[0] from __m128
+    return _mm_cvtss_f32(h4);
+}
+#endif
+
 /// Function that does a component-wise operation between x and y
 /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
 /// functions below
@@ -260,6 +286,13 @@ struct ElementOpL2 {
         __m128 tmp = _mm_sub_ps(x, y);
         return _mm_mul_ps(tmp, tmp);
     }
+
+#ifdef __AVX2__
+    static __m256 op(__m256 x, __m256 y) {
+        __m256 tmp = _mm256_sub_ps(x, y);
+        return _mm256_mul_ps(tmp, tmp);
+    }
+#endif
 };
 
 /// Function that does a component-wise operation between x and y
@@ -272,6 +305,12 @@ struct ElementOpIP {
     static __m128 op(__m128 x, __m128 y) {
         return _mm_mul_ps(x, y);
    }
+
+#ifdef __AVX2__
+    static __m256 op(__m256 x, __m256 y) {
+        return _mm256_mul_ps(x, y);
+    }
+#endif
 };
 
 template <class ElementOp>
@@ -314,6 +353,131 @@ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
     }
 }
 
+#ifdef __AVX2__
+
+template <>
+void fvec_op_ny_D2<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D2-vectors per loop.
+        _mm_prefetch(y, _MM_HINT_T0);
+        _mm_prefetch(y + 16, _MM_HINT_T0);
+
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_T0);
+
+            // load 8x2 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m256 v0;
+            __m256 v1;
+
+            transpose_8x2(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    v0,
+                    v1);
+
+            // compute distances
+            __m256 distances = _mm256_mul_ps(m0, v0);
+            distances = _mm256_fmadd_ps(m1, v1, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 16;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float distance = x0 * y[0] + x1 * y[1];
+            y += 2;
+            dis[i] = distance;
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D2<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D2-vectors per loop.
+        _mm_prefetch(y, _MM_HINT_T0);
+        _mm_prefetch(y + 16, _MM_HINT_T0);
+
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_T0);
+
+            // load 8x2 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m256 v0;
+            __m256 v1;
+
+            transpose_8x2(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    v0,
+                    v1);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 16;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float sub0 = x0 - y[0];
+            float sub1 = x1 - y[1];
+            float distance = sub0 * sub0 + sub1 * sub1;
+
+            y += 2;
+            dis[i] = distance;
+        }
+    }
+}
+
+#endif
+
 template <class ElementOp>
 void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
     __m128 x0 = _mm_loadu_ps(x);
@@ -321,17 +485,12 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
     for (size_t i = 0; i < ny; i++) {
         __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
         y += 4;
-        accu = _mm_hadd_ps(accu, accu);
-        accu = _mm_hadd_ps(accu, accu);
-        dis[i] = _mm_cvtss_f32(accu);
+        dis[i] = horizontal_sum(accu);
     }
 }
 
 #ifdef __AVX2__
 
-// Specialized versions for AVX2 for any CPUs that support gather/scatter.
-// Todo: implement fvec_op_ny_Dxxx in the same way.
-
 template <>
 void fvec_op_ny_D4<ElementOpIP>(
         float* dis,
         const float* x,
         const float* y,
         size_t ny) {
@@ -343,16 +502,9 @@ void fvec_op_ny_D4<ElementOpIP>(
 
     if (ny8 > 0) {
         // process 8 D4-vectors per loop.
-        _mm_prefetch(y, _MM_HINT_NTA);
-        _mm_prefetch(y + 16, _MM_HINT_NTA);
-
-        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
         const __m256 m0 = _mm256_set1_ps(x[0]);
-        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
         const __m256 m1 = _mm256_set1_ps(x[1]);
-        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
         const __m256 m2 = _mm256_set1_ps(x[2]);
-        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
         const __m256 m3 = _mm256_set1_ps(x[3]);
 
         for (i = 0; i < ny8 * 8; i += 8) {
@@ -395,9 +547,7 @@ void fvec_op_ny_D4<ElementOpIP>(
         for (; i < ny; i++) {
             __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
             y += 4;
-            accu = _mm_hadd_ps(accu, accu);
-            accu = _mm_hadd_ps(accu, accu);
-            dis[i] = _mm_cvtss_f32(accu);
+            dis[i] = horizontal_sum(accu);
         }
     }
 }
@@ -413,16 +563,9 @@ void fvec_op_ny_D4<ElementOpL2>(
 
     if (ny8 > 0) {
         // process 8 D4-vectors per loop.
-        _mm_prefetch(y, _MM_HINT_NTA);
-        _mm_prefetch(y + 16, _MM_HINT_NTA);
-
-        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
         const __m256 m0 = _mm256_set1_ps(x[0]);
-        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
         const __m256 m1 = _mm256_set1_ps(x[1]);
-        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
         const __m256 m2 = _mm256_set1_ps(x[2]);
-        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
         const __m256 m3 = _mm256_set1_ps(x[3]);
 
         for (i = 0; i < ny8 * 8; i += 8) {
@@ -471,9 +614,7 @@ void fvec_op_ny_D4<ElementOpL2>(
         for (; i < ny; i++) {
             __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
             y += 4;
-            accu = _mm_hadd_ps(accu, accu);
-            accu = _mm_hadd_ps(accu, accu);
-            dis[i] = _mm_cvtss_f32(accu);
+            dis[i] = horizontal_sum(accu);
         }
     }
 }
@@ -496,6 +637,182 @@ void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
     }
 }
 
+#ifdef __AVX2__
+
+template <>
+void fvec_op_ny_D8<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D8-vectors per loop.
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+        const __m256 m4 = _mm256_set1_ps(x[4]);
+        const __m256 m5 = _mm256_set1_ps(x[5]);
+        const __m256 m6 = _mm256_set1_ps(x[6]);
+        const __m256 m7 = _mm256_set1_ps(x[7]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            // load 8x8 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m256 v0;
+            __m256 v1;
+            __m256 v2;
+            __m256 v3;
+            __m256 v4;
+            __m256 v5;
+            __m256 v6;
+            __m256 v7;
+
+            transpose_8x8(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    _mm256_loadu_ps(y + 2 * 8),
+                    _mm256_loadu_ps(y + 3 * 8),
+                    _mm256_loadu_ps(y + 4 * 8),
+                    _mm256_loadu_ps(y + 5 * 8),
+                    _mm256_loadu_ps(y + 6 * 8),
+                    _mm256_loadu_ps(y + 7 * 8),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute distances
+            __m256 distances = _mm256_mul_ps(m0, v0);
+            distances = _mm256_fmadd_ps(m1, v1, distances);
+            distances = _mm256_fmadd_ps(m2, v2, distances);
+            distances = _mm256_fmadd_ps(m3, v3, distances);
+            distances = _mm256_fmadd_ps(m4, v4, distances);
+            distances = _mm256_fmadd_ps(m5, v5, distances);
+            distances = _mm256_fmadd_ps(m6, v6, distances);
+            distances = _mm256_fmadd_ps(m7, v7, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 64;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D8<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D8-vectors per loop.
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+        const __m256 m4 = _mm256_set1_ps(x[4]);
+        const __m256 m5 = _mm256_set1_ps(x[5]);
+        const __m256 m6 = _mm256_set1_ps(x[6]);
+        const __m256 m7 = _mm256_set1_ps(x[7]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            // load 8x8 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m256 v0;
+            __m256 v1;
+            __m256 v2;
+            __m256 v3;
+            __m256 v4;
+            __m256 v5;
+            __m256 v6;
+            __m256 v7;
+
+            transpose_8x8(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    _mm256_loadu_ps(y + 2 * 8),
+                    _mm256_loadu_ps(y + 3 * 8),
+                    _mm256_loadu_ps(y + 4 * 8),
+                    _mm256_loadu_ps(y + 5 * 8),
+                    _mm256_loadu_ps(y + 6 * 8),
+                    _mm256_loadu_ps(y + 7 * 8),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+            const __m256 d2 = _mm256_sub_ps(m2, v2);
+            const __m256 d3 = _mm256_sub_ps(m3, v3);
+            const __m256 d4 = _mm256_sub_ps(m4, v4);
+            const __m256 d5 = _mm256_sub_ps(m5, v5);
+            const __m256 d6 = _mm256_sub_ps(m6, v6);
+            const __m256 d7 = _mm256_sub_ps(m7, v7);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+            distances = _mm256_fmadd_ps(d2, d2, distances);
+            distances = _mm256_fmadd_ps(d3, d3, distances);
+            distances = _mm256_fmadd_ps(d4, d4, distances);
+            distances = _mm256_fmadd_ps(d5, d5, distances);
+            distances = _mm256_fmadd_ps(d6, d6, distances);
+            distances = _mm256_fmadd_ps(d7, d7, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 64;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+#endif
+
 template <class ElementOp>
 void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
     __m128 x0 = _mm_loadu_ps(x);
@@ -892,10 +1209,7 @@ size_t fvec_L2sqr_ny_nearest_D4(
         for (; i < ny; i++) {
             __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
             y += 4;
-            accu = _mm_hadd_ps(accu, accu);
-            accu = _mm_hadd_ps(accu, accu);
-
-            const auto distance = _mm_cvtss_f32(accu);
+            const float distance = horizontal_sum(accu);
 
             if (current_min_distance > distance) {
                 current_min_distance = distance;
@@ -1031,23 +1345,9 @@ size_t fvec_L2sqr_ny_nearest_D8(
         __m256 x0 = _mm256_loadu_ps(x);
 
         for (; i < ny; i++) {
-            __m256 sub = _mm256_sub_ps(x0, _mm256_loadu_ps(y));
-            __m256 accu = _mm256_mul_ps(sub, sub);
+            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
             y += 8;
-
-            // horitontal sum
-            const __m256 h0 = _mm256_hadd_ps(accu, accu);
-            const __m256 h1 = _mm256_hadd_ps(h0, h0);
-
-            // extract high and low __m128 regs from __m256
-            const __m128 h2 = _mm256_extractf128_ps(h1, 1);
-            const __m128 h3 = _mm256_castps256_ps128(h1);
-
-            // get a final hsum into all 4 regs
-            const __m128 h4 = _mm_add_ss(h2, h3);
-
-            // extract f[0] from __m128
-            const float distance = _mm_cvtss_f32(h4);
+            const float distance = horizontal_sum(accu);
 
             if (current_min_distance > distance) {
                 current_min_distance = distance;
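
Note on the transpose_AxB kernels: transpose_8x2 and transpose_8x8 are only
called by this diff; they are defined elsewhere in FAISS. The sketch below is
therefore an illustration of the assumed contract, not the library's
implementation: transpose_8x2_ref is a hypothetical scalar stand-in that
regroups eight interleaved D2 vectors (array-of-structs) into two arrays of
per-dimension components (struct-of-arrays). Once the j-th components of eight
vectors sit together, one vertical multiply or FMA per dimension yields eight
distances at once, which is what lets the kernels above drop the per-vector
horizontal reduction.

    #include <cstdio>

    // Hypothetical scalar stand-in for the AVX2 transpose_8x2 kernel:
    // regroups 8 interleaved 2-d vectors (y0[0], y0[1], y1[0], y1[1], ...)
    // into v0 = (y0[0], ..., y7[0]) and v1 = (y0[1], ..., y7[1]).
    static void transpose_8x2_ref(const float* y, float v0[8], float v1[8]) {
        for (int j = 0; j < 8; j++) {
            v0[j] = y[2 * j + 0];
            v1[j] = y[2 * j + 1];
        }
    }

    // Eight D2 inner products at once: one multiply-add per dimension on the
    // transposed data, no horizontal sum required.
    static void ip_8x2_ref(const float* x, const float* y, float* dis) {
        float v0[8];
        float v1[8];
        transpose_8x2_ref(y, v0, v1);
        for (int j = 0; j < 8; j++) {
            dis[j] = x[0] * v0[j] + x[1] * v1[j];
        }
    }

    int main() {
        const float x[2] = {1.0f, 2.0f};
        float y[16];
        for (int j = 0; j < 16; j++) {
            y[j] = static_cast<float>(j);
        }
        float dis[8];
        ip_8x2_ref(x, y, dis);
        for (int j = 0; j < 8; j++) {
            std::printf("%.1f\n", dis[j]); // prints y[2j] + 2 * y[2j+1]
        }
        return 0;
    }

The same layout trick drives the D8 kernels: transpose_8x8 regroups eight
D8-vectors so the main loop issues one FMA per dimension, and horizontal_sum
is only needed for the leftover vectors that do not fill a complete group of
eight.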