From 6cc4523af999857e36a7a7811e08117a6c2a0883 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 25 Apr 2023 07:00:39 -0700 Subject: [PATCH] upgrade horizontal sum in distance_single_code for PQ/IVFPQ (#2830) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2830 17 cycles per AVX2 horizontal sum instead of 19 Reviewed By: mdouze Differential Revision: D45244153 fbshipit-source-id: 15accba2e8b4f12725dba41696c302e72f61c2db --- faiss/impl/code_distance/code_distance-avx2.h | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/faiss/impl/code_distance/code_distance-avx2.h b/faiss/impl/code_distance/code_distance-avx2.h index acf01cf541..3202025fb7 100644 --- a/faiss/impl/code_distance/code_distance-avx2.h +++ b/faiss/impl/code_distance/code_distance-avx2.h @@ -17,21 +17,19 @@ namespace { -// Computes a horizontal sum over an __m256 register -inline float horizontal_sum(const __m256 reg) { - const __m256 h0 = _mm256_hadd_ps(reg, reg); - const __m256 h1 = _mm256_hadd_ps(h0, h0); - - // extract high and low __m128 regs from __m256 - const __m128 h2 = _mm256_extractf128_ps(h1, 1); - const __m128 h3 = _mm256_castps256_ps128(h1); - - // get a final hsum into all 4 regs - const __m128 h4 = _mm_add_ss(h2, h3); +inline float horizontal_sum(const __m128 v) { + const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(v, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); +} - // extract f[0] from __m128 - const float hsum = _mm_cvtss_f32(h4); - return hsum; +// Computes a horizontal sum over an __m256 register +inline float horizontal_sum(const __m256 v) { + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + return horizontal_sum(v0); } } // namespace